From 002d767c0c47c0d86751fc74e60e1c281abc3aba Mon Sep 17 00:00:00 2001 From: Julian Trommer Date: Tue, 27 May 2025 17:27:04 +0200 Subject: [PATCH 1/6] First version of NNlib.scatter code & tests --- ext/ReactantNNlibExt/Implementations.jl | 119 ++++++++++++ test/nn/nnlib.jl | 229 ++++++++++++++++++++++++ 2 files changed, 348 insertions(+) diff --git a/ext/ReactantNNlibExt/Implementations.jl b/ext/ReactantNNlibExt/Implementations.jl index 79cd58723a..1a170da540 100644 --- a/ext/ReactantNNlibExt/Implementations.jl +++ b/ext/ReactantNNlibExt/Implementations.jl @@ -472,3 +472,122 @@ function _nnlib_gather_impl(src::AnyTracedRArray, idxs::AbstractArray, n_dims::I slice_sizes=Int64[size(src)[1:n_dims]..., ones(Int64, ndims(src) - n_dims)...], ) end + +# Scatter +for OP in (+, -, max, min, *, /) + @eval function NNlib.scatter( + op::typeof($OP), + src::AnyTracedRArray{T}, + idx::AbstractArray; + init=nothing, + dstsize=nothing, + ) where {T} + dims = ndims(src) - ndims(idx) + dstsz = if isnothing(dstsize) + (size(src)[1:dims]..., NNlib.maximum_dims(idx)...) + else + dstsize + end + xinit = isnothing(init) ? NNlib.scatter_empty(op, T) : init + dst = Ops.fill(xinit, dstsz) + + NNlib.scatter!(op, dst, src, idx) + return dst + end + + @eval function NNlib.scatter!( + op::typeof($OP), + dst::AnyTracedRArray{T}, + src::AnyTracedRArray{T}, + idx::AbstractArray, + ) where {T} + NNlib.scatter_dims(dst, src, idx) + res = _nnlib_scatter_impl(op, dst, src, transpose(_stack_indices(idx))) + res = Ops.reshape(res, size(dst)...) + set_mlir_data!(dst, get_mlir_data(res)) + return dst + end + + @eval function NNlib.scatter!( + op::typeof($OP), + dst::AnyTracedRArray{T}, + src::AnyTracedRArray{T}, + idxs::AbstractArray{<:Number}, + ) where {T} + NNlib.scatter_dims(dst, src, idxs) + idx = ndims(idxs) == 1 ? reshape(idxs, size(idxs)..., 1) : idxs + res = _nnlib_scatter_impl(op, dst, src, idx) + res = Ops.reshape(res, size(dst)...) + set_mlir_data!(dst, get_mlir_data(res)) + return dst + end + + @eval function _nnlib_scatter_impl( + op::typeof($OP), + dst::AnyTracedRArray{T}, + src::AnyTracedRArray{T}, + idx::AbstractArray, + ) where {T} + inputs = if ndims(dst) == 1 + Ops.reshape(dst, length(dst), 1) + else + Ops.transpose(dst, Int64[ndims(dst):-1:1...]) + end + scatter_indices = TracedUtils.promote_to(TracedRArray{Int,ndims(idx)}, idx) + updates = if ndims(src) == 1 + Ops.reshape(src, length(src), 1) + else + Ops.transpose(src, Int64[ndims(src):-1:1...]) + end + + sample_inputs = [ + TracedUtils.promote_to(TracedRNumber{T}, 0), + TracedUtils.promote_to(TracedRNumber{T}, 0), + ] + + func = + TracedUtils.make_mlir_fn( + op, + (sample_inputs), + (), + "scatter_reduce_fn", + false; + args_in_result=:result, + return_dialect=:stablehlo, + ).f + update_computation = MLIR.IR.Region() + MLIR.API.mlirRegionTakeBody(update_computation, MLIR.IR.region(func, 1)) + MLIR.IR.rmfromparent!(func) + + update_window_dims = Int64[2] + inserted_window_dims = Int64[1] + input_batching_dims = Int64[] + scatter_indices_batching_dims = Int64[] + scatter_dims_to_operand_dims = Int64[1] + index_vector_dim = Int64(2) + + scatter_res = Ops.scatter( + [inputs], + scatter_indices, + [updates]; + update_computation=update_computation, + update_window_dims=update_window_dims, + inserted_window_dims=inserted_window_dims, + input_batching_dims=input_batching_dims, + scatter_indices_batching_dims=scatter_indices_batching_dims, + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, + index_vector_dim=index_vector_dim, + )[1] + return Ops.transpose(scatter_res, Int64[ndims(scatter_res):-1:1...]) + end +end + +function NNlib.maximum_dims(dims::AnyTracedRArray{<:Integer}) + return (maximum(dims),) +end +function NNlib.maximum_dims(dims::AnyTracedRArray{NTuple{N,T}}) where {N,T} + return ntuple(i -> maximum(x -> x[i], dims), N) +end +function NNlib.maximum_dims(dims::AnyTracedRArray{CartesianIndex{N}}) where {N} + return ntuple(i -> maximum(x -> x[i], dims), N) +end diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 3fb23e559e..2ade10dbc1 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -1,4 +1,5 @@ using NNlib, Reactant, Enzyme +using Statistics @testset "Activation Functions" begin sumabs2(f, x) = sum(abs2, f.(x)) @@ -381,6 +382,234 @@ end end end +function test_scatter(dsts, srcs, idxs, res; dims) + @testset "scatter Float32 $op" for op in (+, -, max, min, *, /, mean) + for idx in values(idxs), dim in dims + dst = copy(dsts[dim]) + target_y = res[(op, dim, true)] + src = srcs[(dim, true)] + if op == / + src = src .* 2.0f0 + end + + y1 = @jit( + NNlib.scatter!(op, Reactant.to_rarray(dst), Reactant.to_rarray(src), idx) + ) + @test y1 ≈ target_y + @test y1 isa ConcreteRArray{Float32,ndims(dst)} + @test size(y1) == size(dsts[dim]) + dst = copy(dsts[dim]) + y2 = @jit( + NNlib.scatter!( + op, + Reactant.to_rarray(dst), + Reactant.to_rarray(src), + Reactant.to_rarray(idx), + ) + ) + @test y2 ≈ target_y + @test y2 isa ConcreteRArray{Float32,ndims(dst)} + @test size(y2) == size(dsts[dim]) + + target_y = res[(op, dim, false)] + src = srcs[(dim, false)] + if op == / + src = src .* 2.0f0 + end + + y3 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), idx)) + @test y3 ≈ target_y + @test y3 isa ConcreteRArray{Float32,ndims(dst)} + @test size(y3) == size(dsts[dim]) + # y4 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), Reactant.to_rarray(idx))) + # @test y4 ≈ target_y + # @test y4 isa ConcreteRArray{Float32,ndims(dst)} + # @test size(y4) == size(dsts[dim]) + end + end +end + +# Adapted from https://github.com/FluxML/NNlib.jl/blob/1468582c4db5f18149cc8fff6fb4633c5debe5c5/test/testsuite/scatter.jl#L108 +@testset "NNlib scatter" begin + @testset "scatter 1d src, 1d idx => 1d output" begin + #! format: off + dsts = Dict( + 0 => Float32[3, 4, 5, 6, 7] + ) + + srcs = Dict( + (0, true) => ones(Float32, 5), + (0, false) => collect(Float32, 1:5), + ) + + idxs = Dict( + :int => [4, 2, 1, 5, 3], + :tup => [(4,), (2,), (1,), (5,), (3,)], + :car => CartesianIndex.([(4,), (2,), (1,), (5,), (3,)]), + ) + + res = Dict( + (+, 0, true) => Float32[4, 5, 6, 7, 8], + (+, 0, false) => Float32[3, 2, 5, 1, 4], + + (-, 0, true) => Float32[2, 3, 4, 5, 6], + (-, 0, false) => Float32[-3, -2, -5, -1, -4], + + (max, 0, true) => Float32[3, 4, 5, 6, 7], + (max, 0, false) => Float32[3, 2, 5, 1, 4], + + (min, 0, true) => Float32[1, 1, 1, 1, 1], + (min, 0, false) => Float32[3, 2, 5, 1, 4], + + (*, 0, true) => Float32[3, 4, 5, 6, 7], + (*, 0, false) => Float32[3, 2, 5, 1, 4], + + (/, 0, true) => Float32[1.5, 2.0, 2.5, 3.0, 3.5], + (/, 0, false) => Float32[1//6, 1//4, 1//10, 1//2, 1//8], + + (mean, 0, true) => Float32[4, 5, 6, 7, 8], + (mean, 0, false) => Float32[3, 2, 5, 1, 4], + ) + #! format: on + test_scatter(dsts, srcs, idxs, res; dims=[0]) + end + + @testset "scatter 2d src, 1d idx => 2d output" begin + #! format: off + dsts = Dict( + 0 => Float32[3 3 4 4 5 + 5 5 6 6 7] + ) + + srcs = Dict( + (0, true) => ones(Float32, 2, 5), + (0, false) => ones(Float32, 2) * collect(1:5)', + ) + + idxs = Dict( + :int => [4, 2, 1, 5, 3], + :tup => [(4,), (2,), (1,), (5,), (3,)], + :car => CartesianIndex.([(4,), (2,), (1,), (5,), (3,)]), + ) + + res = Dict( + (+, 0, true) => Float32[4 4 5 5 6; + 6 6 7 7 8], + (+, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + + (-, 0, true) => Float32[2 2 3 3 4; + 4 4 5 5 6], + (-, 0, false) => Float32[-3 -2 -5 -1 -4; + -3 -2 -5 -1 -4], + + (max, 0, true) => Float32[3 3 4 4 5; + 5 5 6 6 7], + (max, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + + (min, 0, true) => Float32[1 1 1 1 1; + 1 1 1 1 1], + (min, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + + (*, 0, true) => Float32[3 3 4 4 5; + 5 5 6 6 7], + (*, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + + (/, 0, true) => Float32[1.5 1.5 2.0 2.0 2.5; + 2.5 2.5 3.0 3.0 3.5], + (/, 0, false) => Float32[1//6 1//4 1//10 1//2 1//8; + 1//6 1//4 1//10 1//2 1//8], + + (mean, 0, true) => Float32[4 4 5 5 6; + 6 6 7 7 8], + (mean, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + ) + #! format: on + test_scatter(dsts, srcs, idxs, res; dims=[0]) + end + + @testset "scatter 2d+3d src, 2d idx => 1d+2d output" begin + #! format: off + dsts = Dict( + 0 => Float32[3, 4, 5, 6, 7], + 1 => Float32[3 3 4 4 5; + 5 5 6 6 7], + ) + + srcs = Dict( + (0, true) => ones(Float32, 3, 4), + (0, false) => ones(Float32, 3) * collect(1:4)', + (1, true) => ones(Float32, 2, 3, 4), + (1, false) => Float32[1, 2] .* reshape(ones(Float32, 3) * collect(1:4)', 1,3,4), + ) + + idxs = Dict( + :int => [1 2 3 4; + 4 2 1 3; + 3 5 5 3], + :tup => [(1,) (2,) (3,) (4,); + (4,) (2,) (1,) (3,); + (3,) (5,) (5,) (3,)], + :car => CartesianIndex.( + [(1,) (2,) (3,) (4,); + (4,) (2,) (1,) (3,); + (3,) (5,) (5,) (3,)]), + ) + + res = Dict( + (+, 0, true) => Float32[5, 6, 9, 8, 9], + (+, 1, true) => Float32[5 5 8 6 7; + 7 7 10 8 9], + (+, 0, false) => Float32[4, 4, 12, 5, 5], + (+, 1, false) => Float32[4 4 12 5 5; + 8 8 24 10 10], + (-, 0, true) => Float32[1, 2, 1, 4, 5], + (-, 1, true) => Float32[1 1 0 2 3; + 3 3 2 4 5], + (-, 0, false) => Float32[-4, -4, -12, -5, -5], + (-, 1, false) => Float32[-4 -4 -12 -5 -5; + -8 -8 -24 -10 -10], + (max, 0, true) => Float32[3, 4, 5, 6, 7], + (max, 1, true) => Float32[3 3 4 4 5; + 5 5 6 6 7], + (max, 0, false) => Float32[3, 2, 4, 4, 3], + (max, 1, false) => Float32[3 2 4 4 3; + 6 4 8 8 6], + (min, 0, true) => Float32[1, 1, 1, 1, 1], + (min, 1, true) => Float32[1 1 1 1 1; + 1 1 1 1 1], + (min, 0, false) => Float32[1, 2, 1, 1, 2], + (min, 1, false) => Float32[1 2 1 1 2; + 2 4 2 2 4], + (*, 0, true) => Float32[3, 4, 5, 6, 7], + (*, 1, true) => Float32[3 3 4 4 5; + 5 5 6 6 7], + (*, 0, false) => Float32[3, 4, 48, 4, 6], + (*, 1, false) => Float32[3 4 48 4 6; + 12 16 768 16 24], + (/, 0, true) => Float32[0.75, 1., 0.3125, 1.5, 1.75], + (/, 1, true) => Float32[0.75 0.75 0.25 1. 1.25; + 1.25 1.25 0.375 1.5 1.75], + (/, 0, false) => Float32[1//3, 1//4, 1//48, 1//4, 1//6], + (/, 1, false) => Float32[1//3 1//4 1//48 1//4 1//6; + 1//12 1//16 1//768 1//16 1//24], + (mean, 0, true) => Float32[4., 5., 6., 7., 8.], + (mean, 1, true) => Float32[4. 4. 5. 5. 6.; + 6. 6. 7. 7. 8.], + (mean, 0, false) => Float32[2, 2, 3, 2.5, 2.5], + (mean, 1, false) => Float32[2. 2. 3. 2.5 2.5; + 4. 4. 6. 5. 5.], + ) + #! format: on + + test_scatter(dsts, srcs, idxs, res; dims=[0]) + end +end + @testset "∇conv(D = $ndim)" for ndim in 1:3 x_spatial_dim = 4 batch_size = 2 From 437b88bff8af3cbd8b4ed09d3f3e2bf2ba6e1ded Mon Sep 17 00:00:00 2001 From: Julian Trommer Date: Tue, 3 Jun 2025 16:43:01 +0200 Subject: [PATCH 2/6] Support for higher scatter dims + refactoring --- ext/ReactantNNlibExt/Implementations.jl | 203 ++++++++++----------- test/nn/nnlib.jl | 226 +++++++++++++----------- 2 files changed, 228 insertions(+), 201 deletions(-) diff --git a/ext/ReactantNNlibExt/Implementations.jl b/ext/ReactantNNlibExt/Implementations.jl index 64bed1a43a..9d576a8cc2 100644 --- a/ext/ReactantNNlibExt/Implementations.jl +++ b/ext/ReactantNNlibExt/Implementations.jl @@ -476,112 +476,117 @@ function _nnlib_gather_impl(src::AnyTracedRArray, idxs::AbstractArray, n_dims::I end # Scatter -for OP in (+, -, max, min, *, /) - @eval function NNlib.scatter( - op::typeof($OP), - src::AnyTracedRArray{T}, - idx::AbstractArray; - init=nothing, - dstsize=nothing, - ) where {T} - dims = ndims(src) - ndims(idx) - dstsz = if isnothing(dstsize) - (size(src)[1:dims]..., NNlib.maximum_dims(idx)...) - else - dstsize - end - xinit = isnothing(init) ? NNlib.scatter_empty(op, T) : init - dst = Ops.fill(xinit, dstsz) - - NNlib.scatter!(op, dst, src, idx) - return dst +# The mean function currently produces an ambiguity due to +# https://github.com/FluxML/NNlib.jl/blob/1468582c4db5f18149cc8fff6fb4633c5debe5c5/src/scatter.jl#L85 +# This could be resolved by an explicit dispatch which would require Statistics as dependency +function NNlib.scatter( + op::OP, src::AnyTracedRArray{T}, idx::AbstractArray; init=nothing, dstsize=nothing +) where {OP,T} + dims = ndims(src) - ndims(idx) + dstsz = if isnothing(dstsize) + (size(src)[1:dims]..., NNlib.maximum_dims(idx)...) + else + dstsize end - - @eval function NNlib.scatter!( - op::typeof($OP), - dst::AnyTracedRArray{T}, - src::AnyTracedRArray{T}, - idx::AbstractArray, - ) where {T} - NNlib.scatter_dims(dst, src, idx) - res = _nnlib_scatter_impl(op, dst, src, transpose(_stack_indices(idx))) - res = Ops.reshape(res, size(dst)...) - set_mlir_data!(dst, get_mlir_data(res)) - return dst + if any(d -> d isa TracedRNumber, dstsz) + throw( + ArgumentError( + "dstsize must be specified when idx is a TracedRArray or contains a TracedRNumber.", + ), + ) end + xinit = isnothing(init) ? NNlib.scatter_empty(op, T) : init + dst = Ops.fill(xinit, dstsz) - @eval function NNlib.scatter!( - op::typeof($OP), - dst::AnyTracedRArray{T}, - src::AnyTracedRArray{T}, - idxs::AbstractArray{<:Number}, - ) where {T} - NNlib.scatter_dims(dst, src, idxs) - idx = ndims(idxs) == 1 ? reshape(idxs, size(idxs)..., 1) : idxs - res = _nnlib_scatter_impl(op, dst, src, idx) - res = Ops.reshape(res, size(dst)...) - set_mlir_data!(dst, get_mlir_data(res)) - return dst - end + NNlib.scatter!(op, dst, src, idx) + return dst +end - @eval function _nnlib_scatter_impl( - op::typeof($OP), - dst::AnyTracedRArray{T}, - src::AnyTracedRArray{T}, - idx::AbstractArray, - ) where {T} - inputs = if ndims(dst) == 1 - Ops.reshape(dst, length(dst), 1) - else - Ops.transpose(dst, Int64[ndims(dst):-1:1...]) - end - scatter_indices = TracedUtils.promote_to(TracedRArray{Int,ndims(idx)}, idx) - updates = if ndims(src) == 1 - Ops.reshape(src, length(src), 1) +function NNlib.scatter!( + op::OP, dst::AnyTracedRArray, src::AnyTracedRArray, idx::AbstractArray +) where {OP} + dims = NNlib.scatter_dims(dst, src, idx) + idx = reshape(_stack_indices(idx), prod(size(idx)), 1) + res = _nnlib_scatter_impl(op, dst, src, idx, dims) + res = Ops.reshape(res, size(dst)...) + set_mlir_data!(dst, get_mlir_data(res)) + return dst +end + +function NNlib.scatter!( + op::OP, dst::AnyTracedRArray, src::AnyTracedRArray, idx::AbstractArray{<:Number} +) where {OP} + dims = NNlib.scatter_dims(dst, src, idx) + idx = reshape(idx, prod(size(idx)), 1) + res = _nnlib_scatter_impl(op, dst, src, idx, dims) + res = Ops.reshape(res, size(dst)...) + set_mlir_data!(dst, get_mlir_data(res)) + return dst +end + +function _nnlib_scatter_impl( + op::OP, + dst::AnyTracedRArray{T}, + src::AnyTracedRArray{T}, + idx::AbstractArray, + n_dims::Int, +) where {OP,T} + inputs = if ndims(dst) == 1 + Ops.reshape(dst, length(dst), 1) + else + Ops.transpose(dst, Int64[ndims(dst):-1:1...]) + end + scatter_indices = TracedUtils.promote_to(TracedRArray{Int,ndims(idx)}, idx) + updates = if ndims(src) == 1 + Ops.reshape(src, length(src), 1) + else + if n_dims == 0 + Ops.reshape(src, size(idx, 1), 1) else - Ops.transpose(src, Int64[ndims(src):-1:1...]) + src1 = Ops.reshape(src, size(src)[1:n_dims]..., prod(size(src)[(n_dims + 1):end])) + Ops.transpose(src1, Int64[ndims(src1):-1:1...]) end - - sample_inputs = [ - TracedUtils.promote_to(TracedRNumber{T}, 0), - TracedUtils.promote_to(TracedRNumber{T}, 0), - ] - - func = - TracedUtils.make_mlir_fn( - op, - (sample_inputs), - (), - "scatter_reduce_fn", - false; - args_in_result=:result, - return_dialect=:stablehlo, - ).f - update_computation = MLIR.IR.Region() - MLIR.API.mlirRegionTakeBody(update_computation, MLIR.IR.region(func, 1)) - MLIR.IR.rmfromparent!(func) - - update_window_dims = Int64[2] - inserted_window_dims = Int64[1] - input_batching_dims = Int64[] - scatter_indices_batching_dims = Int64[] - scatter_dims_to_operand_dims = Int64[1] - index_vector_dim = Int64(2) - - scatter_res = Ops.scatter( - [inputs], - scatter_indices, - [updates]; - update_computation=update_computation, - update_window_dims=update_window_dims, - inserted_window_dims=inserted_window_dims, - input_batching_dims=input_batching_dims, - scatter_indices_batching_dims=scatter_indices_batching_dims, - scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, - index_vector_dim=index_vector_dim, - )[1] - return Ops.transpose(scatter_res, Int64[ndims(scatter_res):-1:1...]) end + + sample_inputs = [ + TracedUtils.promote_to(TracedRNumber{T}, 0), + TracedUtils.promote_to(TracedRNumber{T}, 0), + ] + + func = + TracedUtils.make_mlir_fn( + op, + (sample_inputs), + (), + "scatter_reduce_fn", + false; + args_in_result=:result, + return_dialect=:stablehlo, + ).f + update_computation = MLIR.IR.Region() + MLIR.API.mlirRegionTakeBody(update_computation, MLIR.IR.region(func, 1)) + MLIR.IR.rmfromparent!(func) + + update_window_dims = Int64[2] + inserted_window_dims = Int64[1] + input_batching_dims = Int64[] + scatter_indices_batching_dims = Int64[] + scatter_dims_to_operand_dims = Int64[1] + index_vector_dim = Int64(2) + + scatter_res = Ops.scatter( + [inputs], + scatter_indices, + [updates]; + update_computation=update_computation, + update_window_dims=update_window_dims, + inserted_window_dims=inserted_window_dims, + input_batching_dims=input_batching_dims, + scatter_indices_batching_dims=scatter_indices_batching_dims, + scatter_dims_to_operand_dims=scatter_dims_to_operand_dims, + index_vector_dim=index_vector_dim, + )[1] + return Ops.transpose(scatter_res, Int64[ndims(scatter_res):-1:1...]) end function NNlib.maximum_dims(dims::AnyTracedRArray{<:Integer}) diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index c686634d1f..06d4398b29 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -382,56 +382,78 @@ end end end -function test_scatter(dsts, srcs, idxs, res; dims) - @testset "scatter Float32 $op" for op in (+, -, max, min, *, /, mean) - for idx in values(idxs), dim in dims - dst = copy(dsts[dim]) - target_y = res[(op, dim, true)] - src = srcs[(dim, true)] - if op == / - src = src .* 2.0f0 - end - - y1 = @jit( - NNlib.scatter!(op, Reactant.to_rarray(dst), Reactant.to_rarray(src), idx) - ) - @test y1 ≈ target_y - @test y1 isa ConcreteRArray{Float32,ndims(dst)} - @test size(y1) == size(dsts[dim]) - dst = copy(dsts[dim]) - y2 = @jit( - NNlib.scatter!( - op, - Reactant.to_rarray(dst), - Reactant.to_rarray(src), - Reactant.to_rarray(idx), +# Adapted from https://github.com/FluxML/NNlib.jl/blob/1468582c4db5f18149cc8fff6fb4633c5debe5c5/test/testsuite/scatter.jl#L108 +# mean is omitted as operation to avoid the ambiguity error +@testset "NNlib scatter" begin + function test_scatter(dsts, srcs, idxs, res; dims) + @testset "scatter Float32 $op" for op in (+, -, max, min, *, /) + for idx in values(idxs), dim in dims + dst = copy(dsts[dim]) + target_y = res[(op, dim, true)] + src = srcs[(dim, true)] + if op == / + src = src .* 2.0f0 + end + + y1 = @jit( + NNlib.scatter!( + op, Reactant.to_rarray(dst), Reactant.to_rarray(src), idx + ) ) - ) - @test y2 ≈ target_y - @test y2 isa ConcreteRArray{Float32,ndims(dst)} - @test size(y2) == size(dsts[dim]) - - target_y = res[(op, dim, false)] - src = srcs[(dim, false)] - if op == / - src = src .* 2.0f0 + @test y1 ≈ target_y + @test y1 isa ConcreteRArray{Float32,ndims(dst)} + @test size(y1) == size(dsts[dim]) + dst = copy(dsts[dim]) + y2 = @jit( + NNlib.scatter!( + op, + Reactant.to_rarray(dst), + Reactant.to_rarray(src), + Reactant.to_rarray(idx), + ) + ) + @test y2 ≈ target_y + @test y2 isa ConcreteRArray{Float32,ndims(dst)} + @test size(y2) == size(dsts[dim]) + + target_y = res[(op, dim, false)] + src = srcs[(dim, false)] + if op == / + src = src .* 2.0f0 + end + + y3 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), idx)) + @test y3 ≈ target_y + @test y3 isa ConcreteRArray{Float32,ndims(dst)} + @test size(y3) == size(dsts[dim]) + y4 = @jit( + NNlib.scatter( + op, + Reactant.to_rarray(src), + Reactant.to_rarray(idx); + dstsize=size(dsts[dim]), + ) + ) + @test y4 ≈ target_y + @test y4 isa ConcreteRArray{Float32,ndims(dst)} + @test size(y4) == size(dsts[dim]) + + ridx = Reactant.to_rarray(idx) + if ridx isa Reactant.AbstractConcreteArray + @test_throws ArgumentError @jit( + NNlib.scatter(op, Reactant.to_rarray(src), ridx) + ) + else + y5 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), ridx)) + @test y5 ≈ target_y + @test y5 isa ConcreteRArray{Float32,ndims(dst)} + @test size(y5) == size(dsts[dim]) + end end - - y3 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), idx)) - @test y3 ≈ target_y - @test y3 isa ConcreteRArray{Float32,ndims(dst)} - @test size(y3) == size(dsts[dim]) - # y4 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), Reactant.to_rarray(idx))) - # @test y4 ≈ target_y - # @test y4 isa ConcreteRArray{Float32,ndims(dst)} - # @test size(y4) == size(dsts[dim]) end end -end -# Adapted from https://github.com/FluxML/NNlib.jl/blob/1468582c4db5f18149cc8fff6fb4633c5debe5c5/test/testsuite/scatter.jl#L108 -@testset "NNlib scatter" begin - @testset "scatter 1d src, 1d idx => 1d output" begin + @testset "scatter 1d src, 1d index => 1d output" begin #! format: off dsts = Dict( 0 => Float32[3, 4, 5, 6, 7] @@ -474,65 +496,65 @@ end test_scatter(dsts, srcs, idxs, res; dims=[0]) end - @testset "scatter 2d src, 1d idx => 2d output" begin - #! format: off - dsts = Dict( - 0 => Float32[3 3 4 4 5 - 5 5 6 6 7] - ) - - srcs = Dict( - (0, true) => ones(Float32, 2, 5), - (0, false) => ones(Float32, 2) * collect(1:5)', - ) - - idxs = Dict( - :int => [4, 2, 1, 5, 3], - :tup => [(4,), (2,), (1,), (5,), (3,)], - :car => CartesianIndex.([(4,), (2,), (1,), (5,), (3,)]), - ) - - res = Dict( - (+, 0, true) => Float32[4 4 5 5 6; - 6 6 7 7 8], - (+, 0, false) => Float32[3 2 5 1 4; - 3 2 5 1 4], - - (-, 0, true) => Float32[2 2 3 3 4; - 4 4 5 5 6], - (-, 0, false) => Float32[-3 -2 -5 -1 -4; - -3 -2 -5 -1 -4], + @testset "scatter 2d src, 1d index => 2d output" begin + #! format: off + dsts = Dict( + 0 => Float32[3 3 4 4 5 + 5 5 6 6 7] + ) - (max, 0, true) => Float32[3 3 4 4 5; - 5 5 6 6 7], - (max, 0, false) => Float32[3 2 5 1 4; - 3 2 5 1 4], + srcs = Dict( + (0, true) => ones(Float32, 2, 5), + (0, false) => ones(Float32, 2) * collect(1:5)', + ) - (min, 0, true) => Float32[1 1 1 1 1; - 1 1 1 1 1], - (min, 0, false) => Float32[3 2 5 1 4; - 3 2 5 1 4], + idxs = Dict( + :int => [4, 2, 1, 5, 3], + :tup => [(4,), (2,), (1,), (5,), (3,)], + :car => CartesianIndex.([(4,), (2,), (1,), (5,), (3,)]), + ) - (*, 0, true) => Float32[3 3 4 4 5; - 5 5 6 6 7], - (*, 0, false) => Float32[3 2 5 1 4; - 3 2 5 1 4], - - (/, 0, true) => Float32[1.5 1.5 2.0 2.0 2.5; - 2.5 2.5 3.0 3.0 3.5], - (/, 0, false) => Float32[1//6 1//4 1//10 1//2 1//8; - 1//6 1//4 1//10 1//2 1//8], - - (mean, 0, true) => Float32[4 4 5 5 6; - 6 6 7 7 8], - (mean, 0, false) => Float32[3 2 5 1 4; - 3 2 5 1 4], - ) - #! format: on + res = Dict( + (+, 0, true) => Float32[4 4 5 5 6; + 6 6 7 7 8], + (+, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + + (-, 0, true) => Float32[2 2 3 3 4; + 4 4 5 5 6], + (-, 0, false) => Float32[-3 -2 -5 -1 -4; + -3 -2 -5 -1 -4], + + (max, 0, true) => Float32[3 3 4 4 5; + 5 5 6 6 7], + (max, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + + (min, 0, true) => Float32[1 1 1 1 1; + 1 1 1 1 1], + (min, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + + (*, 0, true) => Float32[3 3 4 4 5; + 5 5 6 6 7], + (*, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + + (/, 0, true) => Float32[1.5 1.5 2.0 2.0 2.5; + 2.5 2.5 3.0 3.0 3.5], + (/, 0, false) => Float32[1//6 1//4 1//10 1//2 1//8; + 1//6 1//4 1//10 1//2 1//8], + + (mean, 0, true) => Float32[4 4 5 5 6; + 6 6 7 7 8], + (mean, 0, false) => Float32[3 2 5 1 4; + 3 2 5 1 4], + ) + #! format: on test_scatter(dsts, srcs, idxs, res; dims=[0]) end - @testset "scatter 2d+3d src, 2d idx => 1d+2d output" begin + @testset "scatter 2d+3d src, 2d index => 1d+2d output" begin #! format: off dsts = Dict( 0 => Float32[3, 4, 5, 6, 7], @@ -594,9 +616,9 @@ end (/, 0, true) => Float32[0.75, 1., 0.3125, 1.5, 1.75], (/, 1, true) => Float32[0.75 0.75 0.25 1. 1.25; 1.25 1.25 0.375 1.5 1.75], - (/, 0, false) => Float32[1//3, 1//4, 1//48, 1//4, 1//6], - (/, 1, false) => Float32[1//3 1//4 1//48 1//4 1//6; - 1//12 1//16 1//768 1//16 1//24], + (/, 0, false) => Float32[1//12, 1//16, 1//768, 1//16, 1//24], + (/, 1, false) => Float32[1//12 1//16 1//768 1//16 1//24; + 1//48 1//64 1//12288 1//64 1//96], (mean, 0, true) => Float32[4., 5., 6., 7., 8.], (mean, 1, true) => Float32[4. 4. 5. 5. 6.; 6. 6. 7. 7. 8.], @@ -606,7 +628,7 @@ end ) #! format: on - test_scatter(dsts, srcs, idxs, res; dims=[0]) + test_scatter(dsts, srcs, idxs, res; dims=[0, 1]) end end From 196cce64f206f7cbcc8c6d3290ec7adf5b031177 Mon Sep 17 00:00:00 2001 From: Julian Trommer Date: Thu, 5 Jun 2025 10:55:06 +0200 Subject: [PATCH 3/6] Added support for mean in NNlib.scatter --- Project.toml | 2 +- ext/ReactantNNlibExt/Implementations.jl | 12 ++++++++++++ ext/ReactantNNlibExt/ReactantNNlibExt.jl | 1 + test/nn/nnlib.jl | 2 +- 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index a50340aac8..8d85230020 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ ReactantArrayInterfaceExt = "ArrayInterface" ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"] ReactantKernelAbstractionsExt = "KernelAbstractions" ReactantMPIExt = "MPI" -ReactantNNlibExt = "NNlib" +ReactantNNlibExt = ["NNlib", "Statistics"] ReactantOffsetArraysExt = "OffsetArrays" ReactantOneHotArraysExt = "OneHotArrays" ReactantPythonCallExt = "PythonCall" diff --git a/ext/ReactantNNlibExt/Implementations.jl b/ext/ReactantNNlibExt/Implementations.jl index 9d576a8cc2..dfc3da860c 100644 --- a/ext/ReactantNNlibExt/Implementations.jl +++ b/ext/ReactantNNlibExt/Implementations.jl @@ -524,6 +524,18 @@ function NNlib.scatter!( return dst end +for AT in (AbstractArray, AbstractArray{<:Number}) + @eval function NNlib.scatter!( + ::typeof(mean), dst::AnyTracedRArray, src::AnyTracedRArray, idx::$AT + ) + Ns = NNlib.scatter!(+, zero(dst), one.(src), idx) + dst_ = NNlib.scatter!(+, zero(dst), src, idx) + res = dst .+ NNlib.safe_div.(dst_, Ns) + set_mlir_data!(dst, get_mlir_data(res)) + return dst + end +end + function _nnlib_scatter_impl( op::OP, dst::AnyTracedRArray{T}, diff --git a/ext/ReactantNNlibExt/ReactantNNlibExt.jl b/ext/ReactantNNlibExt/ReactantNNlibExt.jl index 438c0b723e..b11d43ddc9 100644 --- a/ext/ReactantNNlibExt/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt/ReactantNNlibExt.jl @@ -10,6 +10,7 @@ using Reactant.TracedUtils: using ReactantCore: @trace using LinearAlgebra: LinearAlgebra, triu +using Statistics: mean include("Overlay.jl") include("Ops.jl") diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 06d4398b29..64a1ae5a34 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -386,7 +386,7 @@ end # mean is omitted as operation to avoid the ambiguity error @testset "NNlib scatter" begin function test_scatter(dsts, srcs, idxs, res; dims) - @testset "scatter Float32 $op" for op in (+, -, max, min, *, /) + @testset "scatter Float32 $op" for op in (+, -, max, min, *, /, mean) for idx in values(idxs), dim in dims dst = copy(dsts[dim]) target_y = res[(op, dim, true)] From 731d350efa0dc04e9f24bb0b02ad85232812b576 Mon Sep 17 00:00:00 2001 From: JulianTrommer Date: Thu, 7 Aug 2025 10:04:04 +0200 Subject: [PATCH 4/6] Added tests for scatter gradient --- test/nn/nnlib.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 64a1ae5a34..ffc0b9b463 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -630,6 +630,49 @@ end test_scatter(dsts, srcs, idxs, res; dims=[0, 1]) end + + @testset "scatter gradient" begin + + dst = Float32[3 3 4 4 5 + 5 5 6 6 7] + dst_ca = Reactant.to_rarray(dst) + + src = ones(Float32, 2, 5) + src_ca = Reactant.to_rarray(src) + + idx = [4, 2, 1, 5, 3] + idx_ca = Reactant.to_rarray(idx) + + function test_scatter(dsts, srcs, idxs) + return sum(NNlib.scatter!(-, dsts, srcs, idxs)) + end + + function test_gradient(objective_function, dsts, srcs, idxs) + derivs, val = Enzyme.gradient( + Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), + Const(objective_function), + dsts, + srcs, + idxs, + ) + return derivs, val + end + + test_gradient_compiled = @compile test_gradient(test_scatter, dst_ca, src_ca, idx_ca) + + grads_enz, loss_enz = Enzyme.gradient( + Enzyme.ReverseWithPrimal, + Const(test_scatter), + dst, + src, + idx + ) + grads_ca, loss_ca = test_gradient_compiled(test_scatter, dst_ca, src_ca, idx_ca) + + @test grads_enz[1] ≈ Array(grads_ca[1]) + @test grads_enz[2] ≈ Array(grads_ca[2]) + @test loss_enz ≈ loss_ca + end end @testset "∇conv(D = $ndim)" for ndim in 1:3 From 127fd70a058400505d37ebdaa92b4a842fed7822 Mon Sep 17 00:00:00 2001 From: JulianTrommer Date: Mon, 11 Aug 2025 13:06:04 +0200 Subject: [PATCH 5/6] Fixed wrong operation in test --- test/nn/nnlib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index ffc0b9b463..c6dbb7671c 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -644,7 +644,7 @@ end idx_ca = Reactant.to_rarray(idx) function test_scatter(dsts, srcs, idxs) - return sum(NNlib.scatter!(-, dsts, srcs, idxs)) + return sum(NNlib.scatter!(+, dsts, srcs, idxs)) end function test_gradient(objective_function, dsts, srcs, idxs) From 83a411b41f073a72d116ff4615f4021b23eabe69 Mon Sep 17 00:00:00 2001 From: JulianTrommer Date: Tue, 12 Aug 2025 10:15:34 +0200 Subject: [PATCH 6/6] Fixed formatting issue --- test/nn/nnlib.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index fd06564b37..f5b6427eca 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -631,9 +631,10 @@ end end @testset "scatter gradient" begin - - dst = Float32[3 3 4 4 5 - 5 5 6 6 7] + dst = Float32[ + 3 3 4 4 5 + 5 5 6 6 7 + ] dst_ca = Reactant.to_rarray(dst) src = ones(Float32, 2, 5) @@ -657,14 +658,12 @@ end return derivs, val end - test_gradient_compiled = @compile test_gradient(test_scatter, dst_ca, src_ca, idx_ca) + test_gradient_compiled = @compile test_gradient( + test_scatter, dst_ca, src_ca, idx_ca + ) grads_enz, loss_enz = Enzyme.gradient( - Enzyme.ReverseWithPrimal, - Const(test_scatter), - dst, - src, - idx + Enzyme.ReverseWithPrimal, Const(test_scatter), dst, src, idx ) grads_ca, loss_ca = test_gradient_compiled(test_scatter, dst_ca, src_ca, idx_ca)