Skip to content

fix: regression in non-fast scalar indexing support #760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ForwardDiff"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "1.0.1"
version = "1.0.2"

[deps]
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
Expand All @@ -15,9 +15,11 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[weakdeps]
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
ForwardDiffGPUArraysCoreExt = "GPUArraysCore"
ForwardDiffStaticArraysExt = "StaticArrays"

[compat]
Expand All @@ -26,6 +28,7 @@ CommonSubexpressions = "0.3"
DiffResults = "1.1"
DiffRules = "1.4"
DiffTests = "0.1"
GPUArraysCore = "0.2"
IrrationalConstants = "0.1, 0.2"
LogExpFunctions = "0.3"
NaNMath = "1"
Expand All @@ -39,9 +42,10 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"]
test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils", "JLArrays"]
65 changes: 65 additions & 0 deletions ext/ForwardDiffGPUArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
module ForwardDiffGPUArraysCoreExt

using GPUArraysCore: AbstractGPUArray
using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials

# Callable wrapper around a `Dual` number: calling the object with an index `i`
# returns the `i`-th partial derivative of the stored dual with respect to tag
# `T` (via `partials(T, dual, i)`). Wrapping the dual in a functor lets the
# partials be gathered with a single broadcast (`fn.(1:n)`) instead of
# per-element access.
struct PartialsFn{T,D<:Dual}
    dual::D
end
# Convenience constructor that infers the dual's concrete type parameter.
PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual)

# Extract the `i`-th partial for tag `T` from the wrapped dual.
(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i)

"""
    seed!(duals::AbstractGPUArray, x, seed::Partials)

Seed every structurally relevant entry of `duals` from `x` with the same
partials `seed`. The structural indices are materialized on the host and the
update is performed as one broadcasted assignment, avoiding per-element scalar
indexing (which GPU array types typically disallow by default).
"""
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
                           seed::Partials{N,V}) where {T,V,N}
    indices = collect(ForwardDiff.structural_eachindex(duals, x))
    duals[indices] .= Dual{T,V,N}.(view(x, indices), Ref(seed))
    return duals
end

"""
    seed!(duals::AbstractGPUArray, x, seeds::NTuple{N,Partials})

Seed the first `N` structurally relevant entries of `duals`, pairing the k-th
entry with `seeds[k]`. Indices are collected on the host so the update is a
single broadcasted assignment rather than N scalar writes.
"""
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x,
                           seeds::NTuple{N,Partials{N,V}}) where {T,V,N}
    leading = Iterators.take(ForwardDiff.structural_eachindex(duals, x), N)
    indices = collect(leading)
    # `length(indices)` may be < N when the array has fewer structural entries.
    seedvals = getindex.(Ref(seeds), 1:length(indices))
    duals[indices] .= Dual{T,V,N}.(view(x, indices), seedvals)
    return duals
end

"""
    seed!(duals::AbstractGPUArray, x, index, seed::Partials)

Seed all structurally relevant entries of `duals` starting at structural
position `index` with the same partials `seed`, using one broadcasted
assignment over host-collected indices.
"""
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
                           seed::Partials{N,V}) where {T,V,N}
    remaining = Iterators.drop(ForwardDiff.structural_eachindex(duals, x), index - 1)
    indices = collect(remaining)
    duals[indices] .= Dual{T,V,N}.(view(x, indices), Ref(seed))
    return duals
end

"""
    seed!(duals::AbstractGPUArray, x, index, seeds::NTuple{N,Partials}, chunksize)

Seed at most `chunksize` structurally relevant entries of `duals` starting at
structural position `index`, pairing the k-th seeded entry with `seeds[k]`.
Performed as one broadcasted assignment over host-collected indices.
"""
function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index,
                           seeds::NTuple{N,Partials{N,V}}, chunksize) where {T,V,N}
    # Skip the entries before `index`, then take the current chunk's window.
    window = Iterators.take(
        Iterators.drop(ForwardDiff.structural_eachindex(duals, x), index - 1), chunksize
    )
    indices = collect(window)
    seedvals = getindex.(Ref(seeds), 1:length(indices))
    duals[indices] .= Dual{T,V,N}.(view(x, indices), seedvals)
    return duals
end

# gradient
"""
    extract_gradient!(::Type{T}, result::AbstractGPUArray, dual::Dual)

Write the partials of `dual` (tag `T`) into the structurally relevant slots of
`result` with a single broadcasted assignment, avoiding scalar indexing.
"""
function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray,
                                       dual::Dual) where {T}
    getpartial = PartialsFn{T}(dual)
    slots = Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual))
    indices = collect(slots)
    result[indices] .= getpartial.(1:length(indices))
    return result
end

"""
    extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual, index, chunksize)

Write partials `1:chunksize` of `dual` (tag `T`) into the structurally
relevant slots of `result` starting at structural position `index`, using one
broadcasted assignment over host-collected indices.
"""
function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual,
                                             index, chunksize) where {T}
    getpartial = PartialsFn{T}(dual)
    window = Iterators.take(
        Iterators.drop(ForwardDiff.structural_eachindex(result), index - 1), chunksize
    )
    indices = collect(window)
    result[indices] .= getpartial.(1:length(indices))
    return result
end

end
2 changes: 0 additions & 2 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,6 @@ Base.copy(d::Dual) = d
Base.eps(d::Dual) = eps(value(d))
Base.eps(::Type{D}) where {D<:Dual} = eps(valtype(D))

# The `base` keyword was added in Julia 1.8:
# https://github.com/JuliaLang/julia/pull/42428
Base.precision(d::Dual; base::Integer=2) = precision(value(d); base=base)
function Base.precision(::Type{D}; base::Integer=2) where {D<:Dual}
precision(valtype(D); base=base)
Expand Down
1 change: 0 additions & 1 deletion test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
@test precision(typeof(FDNUM)) === precision(V)
@test precision(NESTED_FDNUM) === precision(PRIMAL)
@test precision(typeof(NESTED_FDNUM)) === precision(V)

@test precision(FDNUM; base=10) === precision(PRIMAL; base=10)
@test precision(typeof(FDNUM); base=10) === precision(V; base=10)
@test precision(NESTED_FDNUM; base=10) === precision(PRIMAL; base=10)
Expand Down
22 changes: 22 additions & 0 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using ForwardDiff
using ForwardDiff: Dual, Tag
using StaticArrays
using DiffTests
using JLArrays

include(joinpath(dirname(@__FILE__), "utils.jl"))

Expand Down Expand Up @@ -255,4 +256,25 @@ end
end
end

@testset "GPUArraysCore" begin
    # Gradient of a simple scalar-valued function, computed on a plain Array
    # and on a JLArray (a CPU-backed AbstractGPUArray used for testing).
    fn(x) = sum(x .^ 2 ./ 2)

    x = [1.0, 2.0, 3.0]
    x_jl = JLArray(x)

    grad = ForwardDiff.gradient(fn, x)
    grad_jl = ForwardDiff.gradient(fn, x_jl)

    @test grad_jl isa JLArray
    # Compare against the CPU gradient; `≈` guards against rounding noise.
    @test Array(grad_jl) ≈ grad

    # Exercise the chunked code path via an explicit GradientConfig.
    cfg = ForwardDiff.GradientConfig(
        fn, x_jl, ForwardDiff.Chunk{2}(), ForwardDiff.Tag(fn, eltype(x))
    )
    grad_jl = ForwardDiff.gradient(fn, x_jl, cfg)

    @test grad_jl isa JLArray
    @test Array(grad_jl) ≈ grad
end

end # module
14 changes: 14 additions & 0 deletions test/JacobianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using ForwardDiff: Dual, Tag, JacobianConfig
using StaticArrays
using DiffTests
using LinearAlgebra
using JLArrays

include(joinpath(dirname(@__FILE__), "utils.jl"))

Expand Down Expand Up @@ -279,4 +280,17 @@ end
end
end

@testset "GPUArraysCore" begin
    # Jacobian of an element-wise map, computed on a plain Array and on a
    # JLArray; the GPU-array result must match the CPU reference.
    f(x) = x .^ 2 ./ 2

    x_cpu = [1.0, 2.0, 3.0]
    x_gpu = JLArray(x_cpu)

    jac_cpu = ForwardDiff.jacobian(f, x_cpu)
    jac_gpu = ForwardDiff.jacobian(f, x_gpu)

    @test jac_gpu isa JLArray
    @test Array(jac_gpu) ≈ jac_cpu
end

end # module
Loading