1919
2020TracedRArray {T,N} (x:: TracedRArray{T,N} ) where {T,N} = x
2121
22- mutable struct TracedRNumber{T} <: RNumber{T}
23- paths:: Tuple
24- mlir_data:: Union{Nothing,MLIR.IR.Value}
25-
26- function TracedRNumber {T} (
27- paths:: Tuple , mlir_data:: Union{Nothing,MLIR.IR.Value}
28- ) where {T}
29- if ! isnothing (mlir_data)
30- @assert size (MLIR. IR. type (mlir_data)) == ()
31- end
32- return new {T} (paths, mlir_data)
33- end
34- end
35-
36- Base. eltype (:: Type{TracedRNumber{T}} ) where {T} = T
37-
3822const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}}
3923const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}}
4024const AnyTracedRVector{T} = AnyTracedRArray{T,1 }
@@ -55,15 +39,6 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...)
5539 return get_ancestor_indices (parent (x), Base. reindex (parentindices (x), indices)... )
5640end
5741
58- Base. getindex (a:: TracedRNumber{T} ) where {T} = a
59-
60- Base. zero (:: TracedRNumber{T} ) where {T} = promote_to (TracedRNumber{T}, zero (T))
61- Base. one (:: TracedRNumber{T} ) where {T} = promote_to (TracedRNumber{T}, one (T))
62-
63- function Base. convert (:: Type{<:TracedRNumber{T}} , x:: Number ) where {T}
64- return promote_to (TracedRNumber{T}, T (x))
65- end
66-
6742function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
6843 @warn (
6944 """ Performing scalar indexing on task $(current_task ()) .
@@ -148,12 +123,6 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
148123 # return print(io, X.mlir_data, ")")
149124end
150125
151- function Base. show (io:: IOty , X:: TracedRNumber{T} ) where {T,IOty<: Union{IO,IOContext} }
152- return print (io, " TracedRNumber{" , T, " }(" , X. paths, " )" )
153- end
154-
155- Base. only (A:: TracedRNumber{T} ) where {T} = A
156-
157126function Base. reshape (A:: AnyTracedRArray{T,N} , dims:: NTuple{NT,Int} ) where {T,N,NT}
158127 if prod (dims) != prod (size (A))
159128 throw (
@@ -214,18 +183,6 @@ function Base.transpose(A::AnyTracedRVecOrMat)
214183end
215184Base. adjoint (A:: AnyTracedRVecOrMat{<:Real} ) = transpose (A)
216185
217- function Base. promote_rule (:: Type{TracedRNumber{T}} , :: Type{TracedRNumber{S}} ) where {T,S}
218- return TracedRNumber{Base. promote_type (T, S)}
219- end
220-
221- function Base. promote_rule (:: Type{T} , :: Type{TracedRNumber{S}} ) where {T,S}
222- return TracedRNumber{Base. promote_type (T, S)}
223- end
224-
225- function Base. convert (:: Type{TracedRNumber{T}} , x:: Number ) where {T}
226- return promote_to (TracedRNumber{T}, x)
227- end
228-
229186function promote_to (:: Type{TracedRArray{T,N}} , rhs) where {T,N}
230187 if isa (rhs, TracedRArray)
231188 rhs isa TracedRArray{T,N} && return rhs
@@ -254,103 +211,10 @@ function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
254211 )
255212end
256213
257- function promote_to (:: Type{TracedRNumber{T}} , rhs) where {T}
258- if isa (rhs, TracedRNumber)
259- rhs isa TracedRNumber{T} && return rhs
260- return TracedRNumber {T} (
261- (),
262- MLIR. IR. result (
263- MLIR. Dialects. stablehlo. convert (
264- rhs. mlir_data; result= mlir_type (TracedRNumber{T})
265- ),
266- 1 ,
267- ),
268- )
269- end
270- if isa (rhs, Number)
271- attr = fill (MLIR. IR. Attribute (T (rhs)), mlir_type (TracedRNumber{T}))
272- return TracedRNumber {T} (
273- (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
274- )
275- end
276- T0 = eltype (rhs)
277- attr = MLIR. IR. DenseElementsAttribute (collect (rhs))
278- return promote_to (
279- TracedRNumber{T},
280- TracedRNumber {T0} (
281- (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
282- ),
283- )
284- end
285-
286214promote_to (:: TracedRArray{T,N} , rhs) where {T,N} = promote_to (TracedRArray{T,N}, rhs)
287- promote_to (:: TracedRNumber{T} , rhs) where {T} = promote_to (TracedRNumber{T}, rhs)
288-
289- for (jlop, hloop) in (
290- (:(Base. min), :minimum ),
291- (:(Base. max), :maximum ),
292- (:(Base.:+ ), :add ),
293- (:(Base.:- ), :subtract ),
294- (:(Base.:* ), :multiply ),
295- (:(Base.:/ ), :divide ),
296- (:(Base.:^ ), :power ),
297- )
298- @eval function $ (jlop)(
299- @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: TracedRNumber{T} )
300- ) where {T}
301- return TracedRNumber {T} (
302- (),
303- MLIR. IR. result (
304- MLIR. Dialects. stablehlo.$ (hloop)(lhs. mlir_data, rhs. mlir_data), 1
305- ),
306- )
307- end
308- end
309-
310- function Base. ifelse (
311- @nospecialize (pred:: TracedRNumber{Bool} ),
312- @nospecialize (x:: TracedRNumber{T1} ),
313- @nospecialize (y:: TracedRNumber{T2} )
314- ) where {T1,T2}
315- return TracedRNumber {promote_type(T1, T2)} (
316- (),
317- MLIR. IR. result (
318- MLIR. Dialects. stablehlo. select (pred. mlir_data, x. mlir_data, y. mlir_data), 1
319- ),
320- )
321- end
322-
323- function Base. literal_pow (
324- :: Base.RefValue{typeof(^)} , x:: TracedRNumber{T} , :: Base.RefValue{Val{P}}
325- ) where {T,P}
326- return Base. literal_pow (^ , x, Val (P))
327- end
328-
329- for (jlop, hloop) in (
330- (:(Base. abs), :abs ),
331- (:(Base.:- ), :negate ),
332- (:(Base. sin), :sine ),
333- (:(Base. cos), :cosine ),
334- (:(Base. tanh), :tanh ),
335- (:(Base. FastMath. tanh_fast), :tanh ),
336- (:(Base. exp), :exponential ),
337- (:(Base. FastMath. exp_fast), :exponential ),
338- (:(Base. log), :log ),
339- (:(Base. sqrt), :sqrt ),
340- )
341- @eval function $ (jlop)(@nospecialize (lhs:: TracedRNumber{T} )) where {T}
342- return TracedRNumber {T} (
343- (), MLIR. IR. result (MLIR. Dialects. stablehlo.$ hloop (lhs. mlir_data), 1 )
344- )
345- end
346- end
347215
348216struct TypeCast{T<: Number } <: Function end
349217
350- function (:: TypeCast{T} )(x:: TracedRNumber{T2} ) where {T,T2}
351- return promote_to (TracedRNumber{T}, x)
352- end
353-
354218elem_apply (:: Type{T} , x:: TracedRArray{T} ) where {T<: Number } = x
355219function elem_apply (:: Type{T} , x:: TracedRArray{T2} ) where {T<: Number ,T2<: Number }
356220 # Special Path to prevent going down a despecialized path
@@ -435,41 +299,13 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
435299 return traced2_result
436300end
437301
438- for (jlop, hloop, hlocomp, merge) in (
439- (:(Base.:(== )), :compare , " EQ" , :all ),
440- (:(Base.:(!= )), :compare , " NE" , :any ),
441- (:(Base.:(>= )), :compare , " GE" , nothing ),
442- (:(Base.:(> )), :compare , " GT" , nothing ),
443- (:(Base.:(<= )), :compare , " LE" , nothing ),
444- (:(Base.:(< )), :compare , " LT" , nothing ),
445- )
446- @eval function $ (jlop)(
447- @nospecialize (lhs:: TracedRNumber{T} ), @nospecialize (rhs:: TracedRNumber{T} )
448- ) where {T}
449- return TracedRNumber {Bool} (
450- (),
451- MLIR. IR. result (
452- MLIR. Dialects. stablehlo.$ (hloop)(
453- lhs. mlir_data,
454- rhs. mlir_data;
455- comparison_direction= MLIR. API. stablehloComparisonDirectionAttrGet (
456- MLIR. IR. context (), $ hlocomp
457- ),
458- ),
459- 1 ,
460- ),
461- )
462- end
463-
464- if merge != = nothing
465- @eval begin
466- function $jlop (
467- @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs:: TracedRArray{T,N} )
468- ) where {T,N}
469- elems = $ (jlop). (lhs, rhs)
470- return N == 0 ? elems : $ (merge)(elems)
471- end
472- end
302+ for (jlop, hloop, hlocomp, merge) in
303+ ((:(Base.:(== )), :compare , " EQ" , :all ), (:(Base.:(!= )), :compare , " NE" , :any ))
304+ @eval function $jlop (
305+ @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs:: TracedRArray{T,N} )
306+ ) where {T,N}
307+ elems = $ (jlop). (lhs, rhs)
308+ return N == 0 ? elems : $ (merge)(elems)
473309 end
474310end
475311
0 commit comments