From 2cdbb83d6377b41b4aef033498ec1c069aed8fc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 4 Oct 2024 14:57:47 -0400 Subject: [PATCH 01/19] Remove `Val` constraint on `Base._cat` signature --- src/TracedRArray.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index bc61862680..252f59ce67 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -759,10 +759,14 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted) return dest end -function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D} - @assert D isa Integer "Support for non-integer dimensions is not implemented yet." +dispatch_val(x) = x +dispatch_val(::Val{D}) where {D} = D - # MLIR expects the dimension `D` to be ≤ the rank of the input tensors +function Base._cat(dims, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N} + dims = dispatch_val(dims) + @assert dims isa Integer "Support for non-integer dimensions is not implemented yet." + + # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors A = maybe_expand_dims(A, dims) Bs = maybe_expand_dims.(Bs, (dims,)) @@ -775,7 +779,7 @@ function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) wher MLIR.Dialects.stablehlo.concatenate( [A.mlir_data, [B.mlir_data for B in Bs]...]; result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), - dimension=D - 1, # stablehlo expects this to be zero-indexed + dimension=dims - 1, # stablehlo expects this to be zero-indexed ), 1, ), From 9daca04529971db095cdc8ed0ee77ed6a94785c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 4 Oct 2024 15:57:00 -0400 Subject: [PATCH 02/19] Remove `Val` constraint on `maybe_expand_dims` --- src/TracedRArray.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 252f59ce67..8c0c4a944b 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -788,7 +788,8 @@ function Base._cat(dims, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N} return Res end -function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D} - D ≤ N && return x - return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, Val(D))) +function maybe_expand_dims(x::AbstractArray, dims) where {T,N} + dims = dispatch_val(dims) + dims ≤ N && return x + return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, dims)) end From 1ea20e0863ede5ce931ca9fa4e32b0e490c638a7 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 16:03:10 -0400 Subject: [PATCH 03/19] fix: update src/TracedRArray.jl --- src/TracedRArray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 8c0c4a944b..7ba42931dd 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -788,7 +788,7 @@ function Base._cat(dims, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N} return Res end -function maybe_expand_dims(x::AbstractArray, dims) where {T,N} +function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N} dims = dispatch_val(dims) dims ≤ N && return x return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, dims)) From b4c3127227064293e4ce9049a678a90536a5ad2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 4 Oct 2024 17:49:33 -0400 Subject: [PATCH 04/19] Generalize `Base._cat` implementation on `TracedRArray` to typed `Base._cat_t` --- src/TracedRArray.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 2067a52f45..29e06965a7 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -764,22 +764,23 @@ end dispatch_val(x) = x dispatch_val(::Val{D}) where {D} = D -function Base._cat(dims, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N} +function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} dims = dispatch_val(dims) @assert dims isa Integer "Support for non-integer dimensions is not implemented yet." # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors - A = maybe_expand_dims(A, dims) - Bs = maybe_expand_dims.(Bs, (dims,)) + X = maybe_expand_dims.(X, (dims,)) catdims = Base.dims2cat(dims) - shape = Base.cat_size_shape(catdims, A, Bs...) - RT = Base.promote_eltype(A, Bs...) - Res = TracedRArray{RT,length(shape)}( + shape = Base.cat_size_shape(catdims, X...) + RT = Base.promote_eltype(T, X...) + + return TracedRArray{RT,length(shape)}( (), MLIR.IR.result( + # TODO maybe we should do some conversion? MLIR.Dialects.stablehlo.concatenate( - [A.mlir_data, [B.mlir_data for B in Bs]...]; + get_mlir_data.(X); result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), dimension=dims - 1, # stablehlo expects this to be zero-indexed ), @@ -787,7 +788,6 @@ function Base._cat(dims, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N} ), shape, ) - return Res end function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N} From 296b0caf74bacbf6d6651778fe7b50c53f69ac15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Fri, 4 Oct 2024 17:53:07 -0400 Subject: [PATCH 05/19] Update src/TracedRArray.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/TracedRArray.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 29e06965a7..980eb85e9f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -774,7 +774,6 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} catdims = Base.dims2cat(dims) shape = Base.cat_size_shape(catdims, X...) RT = Base.promote_eltype(T, X...) - return TracedRArray{RT,length(shape)}( (), MLIR.IR.result( From f442b35e2e118b350770fbad443c1f119b66b41f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 4 Oct 2024 23:30:04 -0400 Subject: [PATCH 06/19] Fix collection type passed to `stablehlo.concatenate` --- src/TracedRArray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 980eb85e9f..377cd07267 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -779,7 +779,7 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} MLIR.IR.result( # TODO maybe we should do some conversion? MLIR.Dialects.stablehlo.concatenate( - get_mlir_data.(X); + collect(get_mlir_data.(X)); result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), dimension=dims - 1, # stablehlo expects this to be zero-indexed ), From 5102a7f0027205a703fac8d22f1373b8d502b59c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 4 Oct 2024 23:47:37 -0400 Subject: [PATCH 07/19] Test `cat` methods --- test/basic.jl | 94 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index 4467adf8f4..38847db4b4 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -210,6 +210,100 @@ end end @testset "concatenation" begin + @testset "0-dim" begin + x = fill(true) + x_concrete = Reactant.to_rarray(x) + + # NOTE [,,,] is a call to `vect`, not `*cat` + # f = Reactant.compile((x_concrete,)) do x + # return [x, x, x] + # end + # @test f(x_concrete) ≈ ones(3) + + # vcat + f = Reactant.compile((x_concrete,)) do x + return [x; x; x] + end + @test f(x_concrete) == ones(Bool, 3) + + # hcat + f = Reactant.compile((x_concrete,)) do x + return [x x x] + end + @test f(x_concrete) == ones(Bool, 1, 3) + + # hvcat + f = Reactant.compile((x_concrete,)) do x + return [x x x; x x x] + end + @test f(x_concrete) == ones(Bool, 2, 3) + + # typed_vcat + f = Reactant.compile((x_concrete,)) do x + return Int[x; x; x] + end + @test f(x_concrete) == ones(Int, 3) + + # typed_hcat + f = Reactant.compile((x_concrete,)) do x + return Int[x; x; x] + end + @test f(x_concrete) == ones(Int, 1, 3) + + # typed_hvcat + f = Reactant.compile((x_concrete,)) do x + return Int[x x x; x x x] + end + @test f(x_concrete) == ones(Int, 2, 3) + end + + @testset "1-dim" begin + x = ones(Bool, 2) + x_concrete = Reactant.to_rarray(x) + + # NOTE [,,,] is a call to `vect`, not `*cat` + # f = Reactant.compile((x_concrete,)) do x + # return [x, x, x] + # end + # @test f(x_concrete) ≈ ones(3) + + # vcat + f = Reactant.compile((x_concrete,)) do x + return [x; x; x] + end + @test f(x_concrete) == ones(Bool, 6) + + # hcat + f = Reactant.compile((x_concrete,)) do x + return [x x x] + end + @test f(x_concrete) == ones(Bool, 2, 3) + + # hvcat + f = Reactant.compile((x_concrete,)) do x + return [x x x; x x x] + end + @test f(x_concrete) == ones(Bool, 4, 3) + + # typed_vcat + f = Reactant.compile((x_concrete,)) do x + return Int[x; x; x] + end + @test f(x_concrete) == ones(Int, 6) + + # typed_hcat + f = Reactant.compile((x_concrete,)) do x + return Int[x; x; x] + end + @test f(x_concrete) == ones(Int, 2, 3) + + # typed_hvcat + f = Reactant.compile((x_concrete,)) do x + return Int[x x x; x x x] + end + @test f(x_concrete) == ones(Int, 4, 3) + end + x = ones(2, 4, 3) x_concrete = Reactant.to_rarray(x) From 2f2c3bb821b0d17c6a484055cd7d518721e4084a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 08:44:11 -0400 Subject: [PATCH 08/19] Test result eltype on `*cat` methods --- test/basic.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index 38847db4b4..fc34dfb35d 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -225,36 +225,42 @@ end return [x; x; x] end @test f(x_concrete) == ones(Bool, 3) + @test eltype(f(x_concrete)) === Bool # hcat f = Reactant.compile((x_concrete,)) do x return [x x x] end @test f(x_concrete) == ones(Bool, 1, 3) + @test eltype(f(x_concrete)) === Bool # hvcat f = Reactant.compile((x_concrete,)) do x return [x x x; x x x] end @test f(x_concrete) == ones(Bool, 2, 3) + @test eltype(f(x_concrete)) === Bool # typed_vcat f = Reactant.compile((x_concrete,)) do x return Int[x; x; x] end @test f(x_concrete) == ones(Int, 3) + @test eltype(f(x_concrete)) === Int # typed_hcat f = Reactant.compile((x_concrete,)) do x return Int[x; x; x] end @test f(x_concrete) == ones(Int, 1, 3) + @test eltype(f(x_concrete)) === Int # typed_hvcat f = Reactant.compile((x_concrete,)) do x return Int[x x x; x x x] end @test f(x_concrete) == ones(Int, 2, 3) + @test eltype(f(x_concrete)) === Int end @testset "1-dim" begin @@ -272,36 +278,42 @@ end return [x; x; x] end @test f(x_concrete) == ones(Bool, 6) + @test eltype(f(x_concrete)) === Bool # hcat f = Reactant.compile((x_concrete,)) do x return [x x x] end @test f(x_concrete) == ones(Bool, 2, 3) + @test eltype(f(x_concrete)) === Bool # hvcat f = Reactant.compile((x_concrete,)) do x return [x x x; x x x] end @test f(x_concrete) == ones(Bool, 4, 3) + @test eltype(f(x_concrete)) === Bool # typed_vcat f = Reactant.compile((x_concrete,)) do x return Int[x; x; x] end @test f(x_concrete) == ones(Int, 6) + @test eltype(f(x_concrete)) === Int # typed_hcat f = Reactant.compile((x_concrete,)) do x return Int[x; x; x] end @test f(x_concrete) == ones(Int, 2, 3) + @test eltype(f(x_concrete)) === Int # typed_hvcat f = Reactant.compile((x_concrete,)) do x return Int[x x x; x x x] end @test f(x_concrete) == ones(Int, 4, 3) + @test eltype(f(x_concrete)) === Int end x = ones(2, 4, 3) From 0520fb77f852987d3deddecb2109366ac8981eda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 14:01:39 -0400 Subject: [PATCH 09/19] Fix conversion of integer arrays to `ConcreteRArray`s --- src/Tracing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index ae4f3b4c66..8e362fb3e7 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -380,7 +380,7 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - if mode == ArrayToConcrete && eltype(RT) <: AbstractFloat + if mode == ArrayToConcrete && eltype(RT) <: Union{AbstractFloat, Integer} return seen[prev] = ConcreteRArray(prev) end TT = traced_type(eltype(RT), (), Val(mode)) From a1efaccaee2f7ae5cecbb1e22e57766458faf7c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Sat, 5 Oct 2024 14:02:52 -0400 Subject: [PATCH 10/19] Format code Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Tracing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index 8e362fb3e7..b3233a6447 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -380,7 +380,7 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - if mode == ArrayToConcrete && eltype(RT) <: Union{AbstractFloat, Integer} + if mode == ArrayToConcrete && eltype(RT) <: Union{AbstractFloat,Integer} return seen[prev] = ConcreteRArray(prev) end TT = traced_type(eltype(RT), (), Val(mode)) From d572c76f0107affd5be9ea697f9356dc9d353ae2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 18:38:02 -0400 Subject: [PATCH 11/19] Fix `_typed_cat`, `_typed_hcat`, `typed_hvcat` dispatches --- src/TracedRArray.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 377cd07267..0726a056ff 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -764,6 +764,27 @@ end dispatch_val(x) = x dispatch_val(::Val{D}) where {D} = D +@inline function Base._typed_vcat( + ::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray} +) where {T} + return Base._cat_t(Val(1), T, X...) +end +@inline function Base._typed_hcat( + ::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray} +) where {T} + return Base._cat_t(Val(2), T, X...) +end + +# `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant +# generic implementation uses `typed_hcat` and `typed_vcat` which is alright +@inline function Base.typed_hvcat( + ::Type{T}, rows::Tuple{Vararg{Int}}, as::TracedRArray... +) where {T} + return invoke( + Base.typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as... + ) +end + function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} dims = dispatch_val(dims) @assert dims isa Integer "Support for non-integer dimensions is not implemented yet." From 4729f15ab2c3280d192489a42e0595b6ac518bb9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 19:37:01 -0400 Subject: [PATCH 12/19] Fix `hvcat` --- src/TracedRArray.jl | 26 ++++++++++++++++++++++++++ test/basic.jl | 14 ++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 0726a056ff..cf756b1d7d 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -785,6 +785,32 @@ end ) end +function Base._typed_hvncat( + T::Type, dims::NTuple{N,Int}, row_first::Bool, as::TracedRArray... +) where {N} + As = if row_first + perm = [2, 1, 3:N...] + dims = [dims[2], dims[1], dims[3:end]...] + permutedims(reshape(collect(as), dims...), perm) + else + reshape(collect(as), dims) + end + + for d in 1:N + Bs = Array{Any,N - d}(undef, size(As)[2:end]...) + + for (i, col) in + zip(eachindex(Bs), eachslice(As; dims=Tuple(2:ndims(As)), drop=true)) + # TODO row_first affects the flattening? + Bs[i] = Base._cat_t(d, T, col...) + end + + As = Bs + end + + return only(As) +end + function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} dims = dispatch_val(dims) @assert dims isa Integer "Support for non-integer dimensions is not implemented yet." diff --git a/test/basic.jl b/test/basic.jl index fc34dfb35d..54dce33fd1 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -241,6 +241,13 @@ end @test f(x_concrete) == ones(Bool, 2, 3) @test eltype(f(x_concrete)) === Bool + # hvncat + f = Reactant.compile((x_concrete,)) do x + return [x x x; x x x;;; x x x; x x x] + end + @test f(x_concrete) == ones(Bool, 2, 3, 2) + @test eltype(f(x_concrete)) === Bool + # typed_vcat f = Reactant.compile((x_concrete,)) do x return Int[x; x; x] @@ -261,6 +268,13 @@ end end @test f(x_concrete) == ones(Int, 2, 3) @test eltype(f(x_concrete)) === Int + + # typed_hvncat + f = Reactant.compile((x_concrete,)) do x + return Int[x x x; x x x;;; x x x; x x x] + end + @test f(x_concrete) == ones(Int, 2, 3, 2) + @test eltype(f(x_concrete)) === Int end @testset "1-dim" begin From d85e1c4dcf3e6f96ec84a5a6c0c4df8b77a2a821 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 19:42:12 -0400 Subject: [PATCH 13/19] Convert to target eltype before cat --- src/TracedRArray.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index cf756b1d7d..2780179098 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -821,6 +821,10 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} catdims = Base.dims2cat(dims) shape = Base.cat_size_shape(catdims, X...) RT = Base.promote_eltype(T, X...) + + # convert to the target eltype + X = map(Base.Fix1(promote_to, TracedRArray{RT,length(shape)}), X) + return TracedRArray{RT,length(shape)}( (), MLIR.IR.result( From 7d28fa0b8e0dc8083c9080716de068985eaf87e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 19:42:26 -0400 Subject: [PATCH 14/19] Fix `typed_hcat` tests --- test/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/basic.jl b/test/basic.jl index 54dce33fd1..6501e0b014 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -257,7 +257,7 @@ end # typed_hcat f = Reactant.compile((x_concrete,)) do x - return Int[x; x; x] + return Int[x x x] end @test f(x_concrete) == ones(Int, 1, 3) @test eltype(f(x_concrete)) === Int @@ -317,7 +317,7 @@ end # typed_hcat f = Reactant.compile((x_concrete,)) do x - return Int[x; x; x] + return Int[x x x] end @test f(x_concrete) == ones(Int, 2, 3) @test eltype(f(x_concrete)) === Int From 0975d431ee713a574b5c5c53815bc41472b979ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 19:44:35 -0400 Subject: [PATCH 15/19] Test `typed_hvncat` on vectors --- test/basic.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index 6501e0b014..7ba38db80e 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -308,6 +308,13 @@ end @test f(x_concrete) == ones(Bool, 4, 3) @test eltype(f(x_concrete)) === Bool + # hvncat + f = Reactant.compile((x_concrete,)) do x + return [x x x; x x x;;; x x x; x x x] + end + @test f(x_concrete) == ones(Bool, 4, 3, 2) + @test eltype(f(x_concrete)) === Bool + # typed_vcat f = Reactant.compile((x_concrete,)) do x return Int[x; x; x] @@ -328,6 +335,13 @@ end end @test f(x_concrete) == ones(Int, 4, 3) @test eltype(f(x_concrete)) === Int + + # typed_hvncat + f = Reactant.compile((x_concrete,)) do x + return Int[x x x; x x x;;; x x x; x x x] + end + @test f(x_concrete) == ones(Int, 4, 3, 2) + @test eltype(f(x_concrete)) === Int end x = ones(2, 4, 3) From 882ace092b87d8d6981b5d85cbfc9082e90bee25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 19:57:22 -0400 Subject: [PATCH 16/19] Refactor tests --- test/basic.jl | 127 +++++++++++--------------------------------------- 1 file changed, 26 insertions(+), 101 deletions(-) diff --git a/test/basic.jl b/test/basic.jl index 7ba38db80e..9bb1b3c95b 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -210,8 +210,8 @@ end end @testset "concatenation" begin - @testset "0-dim" begin - x = fill(true) + @testset "$(ndims(x))-dim" for x in [fill(true), fill(true, 2)] + x = [true, false] x_concrete = Reactant.to_rarray(x) # NOTE [,,,] is a call to `vect`, not `*cat` @@ -221,126 +221,51 @@ end # @test f(x_concrete) ≈ ones(3) # vcat - f = Reactant.compile((x_concrete,)) do x - return [x; x; x] - end - @test f(x_concrete) == ones(Bool, 3) + g(x) = [x; x; x] + f = @compile g(x_concrete) + @test f(x_concrete) == g(x) @test eltype(f(x_concrete)) === Bool # hcat - f = Reactant.compile((x_concrete,)) do x - return [x x x] - end - @test f(x_concrete) == ones(Bool, 1, 3) + g(x) = [x x x] + f = @compile g(x_concrete) + @test f(x_concrete) == g(x) @test eltype(f(x_concrete)) === Bool # hvcat - f = Reactant.compile((x_concrete,)) do x - return [x x x; x x x] - end - @test f(x_concrete) == ones(Bool, 2, 3) + g(x) = [x x x; x x x] + f = @compile g(x_concrete) + @test f(x_concrete) == g(x) @test eltype(f(x_concrete)) === Bool # hvncat - f = Reactant.compile((x_concrete,)) do x - return [x x x; x x x;;; x x x; x x x] - end - @test f(x_concrete) == ones(Bool, 2, 3, 2) + g(x) = [x x x; x x x;;; x x x; x x x] + f = @compile g(x_concrete) + @test f(x_concrete) == g(x) @test eltype(f(x_concrete)) === Bool # typed_vcat - f = Reactant.compile((x_concrete,)) do x - return Int[x; x; x] - end - @test f(x_concrete) == ones(Int, 3) + g(x) = Int[x; x; x] + f = @compile g(x_concrete) + @test f(x_concrete) == g(x) @test eltype(f(x_concrete)) === Int # typed_hcat - f = Reactant.compile((x_concrete,)) do x - return Int[x x x] - end - @test f(x_concrete) == ones(Int, 1, 3) + g(x) = Int[x x x] + f = @compile g(x_concrete) + @test f(x_concrete) == g(x) @test eltype(f(x_concrete)) === Int # typed_hvcat - f = Reactant.compile((x_concrete,)) do x - return Int[x x x; x x x] - end - @test f(x_concrete) == ones(Int, 2, 3) + g(x) = Int[x x x; x x x] + f = @compile g(x_concrete) + @test f(x_concrete) == g(x) @test eltype(f(x_concrete)) === Int # typed_hvncat - f = Reactant.compile((x_concrete,)) do x - return Int[x x x; x x x;;; x x x; x x x] - end - @test f(x_concrete) == ones(Int, 2, 3, 2) - @test eltype(f(x_concrete)) === Int - end - - @testset "1-dim" begin - x = ones(Bool, 2) - x_concrete = Reactant.to_rarray(x) - - # NOTE [,,,] is a call to `vect`, not `*cat` - # f = Reactant.compile((x_concrete,)) do x - # return [x, x, x] - # end - # @test f(x_concrete) ≈ ones(3) - - # vcat - f = Reactant.compile((x_concrete,)) do x - return [x; x; x] - end - @test f(x_concrete) == ones(Bool, 6) - @test eltype(f(x_concrete)) === Bool - - # hcat - f = Reactant.compile((x_concrete,)) do x - return [x x x] - end - @test f(x_concrete) == ones(Bool, 2, 3) - @test eltype(f(x_concrete)) === Bool - - # hvcat - f = Reactant.compile((x_concrete,)) do x - return [x x x; x x x] - end - @test f(x_concrete) == ones(Bool, 4, 3) - @test eltype(f(x_concrete)) === Bool - - # hvncat - f = Reactant.compile((x_concrete,)) do x - return [x x x; x x x;;; x x x; x x x] - end - @test f(x_concrete) == ones(Bool, 4, 3, 2) - @test eltype(f(x_concrete)) === Bool - - # typed_vcat - f = Reactant.compile((x_concrete,)) do x - return Int[x; x; x] - end - @test f(x_concrete) == ones(Int, 6) - @test eltype(f(x_concrete)) === Int - - # typed_hcat - f = Reactant.compile((x_concrete,)) do x - return Int[x x x] - end - @test f(x_concrete) == ones(Int, 2, 3) - @test eltype(f(x_concrete)) === Int - - # typed_hvcat - f = Reactant.compile((x_concrete,)) do x - return Int[x x x; x x x] - end - @test f(x_concrete) == ones(Int, 4, 3) - @test eltype(f(x_concrete)) === Int - - # typed_hvncat - f = Reactant.compile((x_concrete,)) do x - return Int[x x x; x x x;;; x x x; x x x] - end - @test f(x_concrete) == ones(Int, 4, 3, 2) + g(x) = Int[x x x; x x x;;; x x x; x x x] + f = @compile g(x_concrete) + @test f(x_concrete) == g(x) @test eltype(f(x_concrete)) === Int end From 91a756013574155f36b66e5dc4b0265ca39ac1c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 20:04:16 -0400 Subject: [PATCH 17/19] Add more test cases --- test/basic.jl | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/test/basic.jl b/test/basic.jl index 9bb1b3c95b..d665679782 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -210,8 +210,17 @@ end end @testset "concatenation" begin - @testset "$(ndims(x))-dim" for x in [fill(true), fill(true, 2)] - x = [true, false] + @testset "$(ndims(x))-dim" for x in [ + fill(true), + [true, false], + [true false], + [true true; true false], + [ + true true true true; true true true false;;; + true true false true; true true false false;;; + true false true true; true false true false + ], + ] x_concrete = Reactant.to_rarray(x) # NOTE [,,,] is a call to `vect`, not `*cat` @@ -268,21 +277,6 @@ end @test f(x_concrete) == g(x) @test eltype(f(x_concrete)) === Int end - - x = ones(2, 4, 3) - x_concrete = Reactant.to_rarray(x) - - cat1(x) = vcat(x, x, x) - cat2(x) = hcat(x, x, x) - cat3(x) = cat(x, x, x; dims=Val(3)) - - cat1_compiled = @compile cat1(x_concrete) - cat2_compiled = @compile cat2(x_concrete) - cat3_compiled = @compile cat3(x_concrete) - - @test cat1(x) ≈ cat1_compiled(x_concrete) - @test cat2(x) ≈ cat2_compiled(x_concrete) - @test cat3(x) ≈ cat3_compiled(x_concrete) end function update_on_copy(x) From 40734c085008096458c101668d6aecfa87faf45c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 20:05:02 -0400 Subject: [PATCH 18/19] Refactor tests --- test/basic.jl | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/test/basic.jl b/test/basic.jl index d665679782..5f14233e59 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -230,51 +230,51 @@ end # @test f(x_concrete) ≈ ones(3) # vcat - g(x) = [x; x; x] - f = @compile g(x_concrete) - @test f(x_concrete) == g(x) + test_vcat(x) = [x; x; x] + f = @compile test_vcat(x_concrete) + @test f(x_concrete) == test_vcat(x) @test eltype(f(x_concrete)) === Bool # hcat - g(x) = [x x x] - f = @compile g(x_concrete) - @test f(x_concrete) == g(x) + test_hcat(x) = [x x x] + f = @compile test_hcat(x_concrete) + @test f(x_concrete) == test_hcat(x) @test eltype(f(x_concrete)) === Bool # hvcat - g(x) = [x x x; x x x] - f = @compile g(x_concrete) - @test f(x_concrete) == g(x) + test_hvcat(x) = [x x x; x x x] + f = @compile test_hvcat(x_concrete) + @test f(x_concrete) == test_hvcat(x) @test eltype(f(x_concrete)) === Bool # hvncat - g(x) = [x x x; x x x;;; x x x; x x x] - f = @compile g(x_concrete) - @test f(x_concrete) == g(x) + test_hvncat(x) = [x x x; x x x;;; x x x; x x x] + f = @compile test_hvncat(x_concrete) + @test f(x_concrete) == test_hvncat(x) @test eltype(f(x_concrete)) === Bool # typed_vcat - g(x) = Int[x; x; x] - f = @compile g(x_concrete) - @test f(x_concrete) == g(x) + test_typed_vcat(x) = Int[x; x; x] + f = @compile test_typed_vcat(x_concrete) + @test f(x_concrete) == test_typed_vcat(x) @test eltype(f(x_concrete)) === Int # typed_hcat - g(x) = Int[x x x] - f = @compile g(x_concrete) - @test f(x_concrete) == g(x) + test_typed_hcat(x) = Int[x x x] + f = @compile test_typed_hcat(x_concrete) + @test f(x_concrete) == test_typed_hcat(x) @test eltype(f(x_concrete)) === Int # typed_hvcat - g(x) = Int[x x x; x x x] - f = @compile g(x_concrete) - @test f(x_concrete) == g(x) + test_typed_hvcat(x) = Int[x x x; x x x] + f = @compile test_typed_hvcat(x_concrete) + @test f(x_concrete) == test_typed_hvcat(x) @test eltype(f(x_concrete)) === Int # typed_hvncat - g(x) = Int[x x x; x x x;;; x x x; x x x] - f = @compile g(x_concrete) - @test f(x_concrete) == g(x) + test_hvncat(x) = Int[x x x; x x x;;; x x x; x x x] + f = @compile test_hvncat(x_concrete) + @test f(x_concrete) == test_hvncat(x) @test eltype(f(x_concrete)) === Int end end From 601c867cf68ebb609d467db2b70566c54be781ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 5 Oct 2024 20:13:50 -0400 Subject: [PATCH 19/19] Fix typo --- test/basic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/basic.jl b/test/basic.jl index 5f14233e59..368f22d728 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -272,9 +272,9 @@ end @test eltype(f(x_concrete)) === Int # typed_hvncat - test_hvncat(x) = Int[x x x; x x x;;; x x x; x x x] - f = @compile test_hvncat(x_concrete) - @test f(x_concrete) == test_hvncat(x) + test_typed_hvncat(x) = Int[x x x; x x x;;; x x x; x x x] + f = @compile test_typed_hvncat(x_concrete) + @test f(x_concrete) == test_typed_hvncat(x) @test eltype(f(x_concrete)) === Int end end