@@ -26,11 +26,11 @@ function Base.setproperty!(x::TracedRArray, f::Symbol, v)
2626 return setfield! (x, f, v)
2727end
2828
29- mutable struct TracedRScalar {T} <: RScalar {T}
29+ mutable struct TracedRNumber {T} <: RNumber {T}
3030 paths:: Tuple
3131 mlir_data:: Union{Nothing,MLIR.IR.Value}
3232
33- function TracedRScalar {T} (
33+ function TracedRNumber {T} (
3434 paths:: Tuple , mlir_data:: Union{Nothing,MLIR.IR.Value}
3535 ) where {T}
3636 if ! isnothing (mlir_data)
@@ -40,14 +40,14 @@ mutable struct TracedRScalar{T} <: RScalar{T}
4040 end
4141end
4242
43- function Base. setproperty! (x:: TracedRScalar , f:: Symbol , v)
43+ function Base. setproperty! (x:: TracedRNumber , f:: Symbol , v)
4444 if f === :mlir_data && ! isnothing (v)
4545 @assert size (MLIR. IR. type (v)) == ()
4646 end
4747 return setfield! (x, f, v)
4848end
4949
50- Base. eltype (:: Type{TracedRScalar {T}} ) where {T} = T
50+ Base. eltype (:: Type{TracedRNumber {T}} ) where {T} = T
5151
5252const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
5353const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
@@ -69,13 +69,13 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
6969 return get_ancestor_indices (parent (x), Base. reindex (parentindices (x), indices)... )
7070end
7171
72- Base. getindex (a:: TracedRScalar {T} ) where {T} = a
72+ Base. getindex (a:: TracedRNumber {T} ) where {T} = a
7373
74- Base. zero (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar {T}, zero (T))
75- Base. one (:: TracedRScalar {T} ) where {T} = promote_to (TracedRScalar {T}, one (T))
74+ Base. zero (:: TracedRNumber {T} ) where {T} = promote_to (TracedRNumber {T}, zero (T))
75+ Base. one (:: TracedRNumber {T} ) where {T} = promote_to (TracedRNumber {T}, one (T))
7676
77- function Base. convert (:: Type{<:TracedRScalar {T}} , x:: Number ) where {T}
78- return promote_to (TracedRScalar {T}, T (x))
77+ function Base. convert (:: Type{<:TracedRNumber {T}} , x:: Number ) where {T}
78+ return promote_to (TracedRNumber {T}, T (x))
7979end
8080
8181function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
@@ -102,7 +102,7 @@ and require expensive copies and synchronization each time and therefore should
102102 ),
103103 1 ,
104104 )
105- return TracedRScalar {T} ((), res2)
105+ return TracedRNumber {T} ((), res2)
106106end
107107
108108function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
@@ -137,7 +137,7 @@ function Base.setindex!(
137137 a:: TracedRArray{T,N} , v, indices:: Vararg{Union{Base.AbstractUnitRange,Colon},N}
138138) where {T,N}
139139 indices = [
140- (promote_to (TracedRScalar {Int}, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
140+ (promote_to (TracedRNumber {Int}, i isa Colon ? 1 : first (i)) - 1 ). mlir_data for
141141 i in indices
142142 ]
143143 v = promote_to (TracedRArray{T,N}, v)
@@ -162,11 +162,11 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
162162 # return print(io, X.mlir_data, ")")
163163end
164164
165- function Base. show (io:: IOty , X:: TracedRScalar {T} ) where {T,IOty<: Union{IO,IOContext} }
166- return print (io, " TracedRScalar {" , T, " }(" , X. paths, " )" )
165+ function Base. show (io:: IOty , X:: TracedRNumber {T} ) where {T,IOty<: Union{IO,IOContext} }
166+ return print (io, " TracedRNumber {" , T, " }(" , X. paths, " )" )
167167end
168168
169- Base. only (A:: TracedRScalar {T} ) where {T} = A
169+ Base. only (A:: TracedRNumber {T} ) where {T} = A
170170
171171function Base. reshape (A:: AnyTracedRArray{T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
172172 if prod (dims) != prod (size (A))
@@ -238,12 +238,12 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
238238 return TracedRArray{Base. promote_type (T, S),N}
239239end
240240
241- function Base. promote_rule (:: Type{T} , :: Type{TracedRScalar {S}} ) where {T,S}
242- return TracedRScalar {Base. promote_type (T, S)}
241+ function Base. promote_rule (:: Type{T} , :: Type{TracedRNumber {S}} ) where {T,S}
242+ return TracedRNumber {Base. promote_type (T, S)}
243243end
244244
245- function Base. convert (:: Type{TracedRScalar {T}} , x:: Number ) where {T}
246- return promote_to (TracedRScalar {T}, x)
245+ function Base. convert (:: Type{TracedRNumber {T}} , x:: Number ) where {T}
246+ return promote_to (TracedRNumber {T}, x)
247247end
248248
249249function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
@@ -262,7 +262,7 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
262262 end
263263 if isa (rhs, Number)
264264 throw (ArgumentError (" Cannot promote number to `TracedRArray`. Use \
265- `TracedRScalar ` instead." ))
265+ `TracedRNumber ` instead." ))
266266 end
267267 T0 = eltype (rhs)
268268 attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
@@ -274,37 +274,37 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
274274 )
275275end
276276
277- function promote_to (:: Type{TracedRScalar {T}} , rhs) where {T}
278- if isa (rhs, TracedRScalar )
279- rhs isa TracedRScalar {T} && return rhs
280- return TracedRScalar {T} (
277+ function promote_to (:: Type{TracedRNumber {T}} , rhs) where {T}
278+ if isa (rhs, TracedRNumber )
279+ rhs isa TracedRNumber {T} && return rhs
280+ return TracedRNumber {T} (
281281 (),
282282 MLIR. IR. result (
283283 MLIR. Dialects. stablehlo. convert (
284- rhs. mlir_data; result= mlir_type (TracedRScalar {T})
284+ rhs. mlir_data; result= mlir_type (TracedRNumber {T})
285285 ),
286286 1 ,
287287 ),
288288 )
289289 end
290290 if isa (rhs, Number)
291- attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRScalar {T}))
292- return TracedRScalar {T} (
291+ attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRNumber {T}))
292+ return TracedRNumber {T} (
293293 (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
294294 )
295295 end
296296 T0 = eltype (rhs)
297297 attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
298298 return promote_to (
299- TracedRScalar {T},
300- TracedRScalar {T0} (
299+ TracedRNumber {T},
300+ TracedRNumber {T0} (
301301 (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
302302 ),
303303 )
304304end
305305
306306promote_to (:: TracedRArray{T,N} , rhs) where {T,N} = promote_to (TracedRArray{T,N}, rhs)
307- promote_to (:: TracedRScalar {T} , rhs) where {T} = promote_to (TracedRScalar {T}, rhs)
307+ promote_to (:: TracedRNumber {T} , rhs) where {T} = promote_to (TracedRNumber {T}, rhs)
308308
309309for (jlop, hloop) in (
310310 (:(Base. min), :minimum ),
@@ -316,7 +316,7 @@ for (jlop, hloop) in (
316316 (:(Base.:^ ), :power ),
317317)
318318 @eval function $ (jlop)(
319- @nospecialize (lhs:: TracedRScalar {T} ), @nospecialize (rhs:: TracedRScalar {T} )
319+ @nospecialize (lhs:: TracedRNumber {T} ), @nospecialize (rhs:: TracedRNumber {T} )
320320 ) where {T}
321321 return TracedRArray {T} (
322322 (),
@@ -328,22 +328,22 @@ for (jlop, hloop) in (
328328end
329329
330330function Base. ifelse (
331- @nospecialize (pred:: TracedRScalar {Bool} ),
332- @nospecialize (x:: TracedRScalar {T1} ),
333- @nospecialize (y:: TracedRScalar {T2} )
331+ @nospecialize (pred:: TracedRNumber {Bool} ),
332+ @nospecialize (x:: TracedRNumber {T1} ),
333+ @nospecialize (y:: TracedRNumber {T2} )
334334) where {T1,T2}
335- return TracedRScalar {promote_type(T1, T2)} (
335+ return TracedRNumber {promote_type(T1, T2)} (
336336 (),
337337 MLIR. IR. result (
338338 MLIR. Dialects. stablehlo. select (pred. mlir_data, x. mlir_data, y. mlir_data), 1
339339 ),
340340 )
341341end
342342
343- Base. abs2 (x:: Reactant.TracedRScalar {T} ) where {T} = x * conj (x)
343+ Base. abs2 (x:: Reactant.TracedRNumber {T} ) where {T} = x * conj (x)
344344
345345function Base. literal_pow (
346- :: Base.RefValue{typeof(^)} , x:: TracedRScalar {T} , :: Base.RefValue{Val{P}}
346+ :: Base.RefValue{typeof(^)} , x:: TracedRNumber {T} , :: Base.RefValue{Val{P}}
347347) where {T,P}
348348 return Base. literal_pow (^ , x, Val (P))
349349end
@@ -360,8 +360,8 @@ for (jlop, hloop) in (
360360 (:(Base. log), :log ),
361361 (:(Base. sqrt), :sqrt ),
362362)
363- @eval function $ (jlop)(@nospecialize (lhs:: TracedRScalar {T} )) where {T}
364- return TracedRScalar {T} (
363+ @eval function $ (jlop)(@nospecialize (lhs:: TracedRNumber {T} )) where {T}
364+ return TracedRNumber {T} (
365365 (), MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 )
366366 )
367367 end
@@ -467,9 +467,9 @@ for (jlop, hloop, hlocomp, merge) in (
467467 (:(Base.:(< )), :compare , " LT" , nothing ),
468468)
469469 @eval function $ (jlop)(
470- @nospecialize (lhs:: TracedRScalar {T} ), @nospecialize (rhs:: TracedRScalar {T} )
470+ @nospecialize (lhs:: TracedRNumber {T} ), @nospecialize (rhs:: TracedRNumber {T} )
471471 ) where {T}
472- return TracedRScalar {Bool} (
472+ return TracedRNumber {Bool} (
473473 (),
474474 MLIR. IR. result (
475475 MLIR. Dialects. stablehlo.$ (hloop)(
@@ -571,7 +571,7 @@ function Base.mapreduce(
571571 fnbody = MLIR. IR. Block (in_tys, [MLIR. IR. Location () for arg in in_tys])
572572
573573 args = (
574- TracedRScalar {T} ((), MLIR. IR. argument (fnbody, i), ()) for
574+ TracedRNumber {T} ((), MLIR. IR. argument (fnbody, i), ()) for
575575 (i, ty) in enumerate (in_tys)
576576 )
577577
0 commit comments