Skip to content
42 changes: 42 additions & 0 deletions test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,48 @@ end

test_scatter(dsts, srcs, idxs, res; dims=[0, 1])
end

@testset "scatter gradient" begin
dst = Float32[
3 3 4 4 5
5 5 6 6 7
]
dst_ca = Reactant.to_rarray(dst)

src = ones(Float32, 2, 5)
src_ca = Reactant.to_rarray(src)

idx = [4, 2, 1, 5, 3]
idx_ca = Reactant.to_rarray(idx)

function test_scatter(dsts, srcs, idxs)
return sum(NNlib.scatter!(+, dsts, srcs, idxs))
end

function test_gradient(objective_function, dsts, srcs, idxs)
derivs, val = Enzyme.gradient(
Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
Const(objective_function),
dsts,
srcs,
idxs,
)
return derivs, val
end

test_gradient_compiled = @compile test_gradient(
test_scatter, dst_ca, src_ca, idx_ca
)

grads_enz, loss_enz = Enzyme.gradient(
Enzyme.ReverseWithPrimal, Const(test_scatter), dst, src, idx
)
grads_ca, loss_ca = test_gradient_compiled(test_scatter, dst_ca, src_ca, idx_ca)

@test grads_enz[1] ≈ Array(grads_ca[1])
@test grads_enz[2] ≈ Array(grads_ca[2])
@test loss_enz ≈ loss_ca
end
end

@testset "∇conv(D = $ndim)" for ndim in 1:3
Expand Down
Loading