diff --git a/Project.toml b/Project.toml index 9eb2914c..abbf186c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ForwardDiff" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "1.0.1" +version = "1.0.2" [deps] CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950" @@ -15,9 +15,11 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [weakdeps] +GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] +ForwardDiffGPUArraysCoreExt = "GPUArraysCore" ForwardDiffStaticArraysExt = "StaticArrays" [compat] @@ -26,6 +28,7 @@ CommonSubexpressions = "0.3" DiffResults = "1.1" DiffRules = "1.4" DiffTests = "0.1" +GPUArraysCore = "0.2" IrrationalConstants = "0.1, 0.2" LogExpFunctions = "0.3" NaNMath = "1" @@ -39,9 +42,10 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" +JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"] +test = ["Calculus", "DiffTests", "IrrationalConstants", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils", "JLArrays"] diff --git a/ext/ForwardDiffGPUArraysCoreExt.jl b/ext/ForwardDiffGPUArraysCoreExt.jl new file mode 100644 index 00000000..a881dc4d --- /dev/null +++ b/ext/ForwardDiffGPUArraysCoreExt.jl @@ -0,0 +1,65 @@ +module ForwardDiffGPUArraysCoreExt + +using GPUArraysCore: AbstractGPUArray +using ForwardDiff: ForwardDiff, Dual, Partials, npartials, partials + +struct PartialsFn{T,D<:Dual} + dual::D +end +PartialsFn{T}(dual::Dual) where {T} = PartialsFn{T,typeof(dual)}(dual) + +(f::PartialsFn{T})(i) where {T} = partials(T, f.dual, i) + +function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, + seed::Partials{N,V}) where {T,V,N} + idxs = collect(ForwardDiff.structural_eachindex(duals, x)) + duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed)) + return duals +end + +function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, + seeds::NTuple{N,Partials{N,V}}) where {T,V,N} + idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(duals, x), N)) + duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs))) + return duals +end + +function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index, + seed::Partials{N,V}) where {T,V,N} + offset = index - 1 + idxs = collect(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset)) + duals[idxs] .= Dual{T,V,N}.(view(x, idxs), Ref(seed)) + return duals +end + +function ForwardDiff.seed!(duals::AbstractGPUArray{Dual{T,V,N}}, x, index, + seeds::NTuple{N,Partials{N,V}}, chunksize) where {T,V,N} + offset = index - 1 + idxs = collect( + Iterators.take(Iterators.drop(ForwardDiff.structural_eachindex(duals, x), offset), chunksize) + ) + duals[idxs] .= Dual{T,V,N}.(view(x, idxs), getindex.(Ref(seeds), 1:length(idxs))) + return duals +end + +# gradient +function ForwardDiff.extract_gradient!(::Type{T}, result::AbstractGPUArray, + dual::Dual) where {T} + fn = PartialsFn{T}(dual) + idxs = collect(Iterators.take(ForwardDiff.structural_eachindex(result), npartials(dual))) + result[idxs] .= fn.(1:length(idxs)) + return result +end + +function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::AbstractGPUArray, dual, + index, chunksize) where {T} + fn = PartialsFn{T}(dual) + offset = index - 1 + idxs = collect( + Iterators.take(Iterators.drop(ForwardDiff.structural_eachindex(result), offset), chunksize) + ) + result[idxs] .= fn.(1:length(idxs)) + return result +end + +end diff --git a/src/dual.jl b/src/dual.jl index 179c48d3..e5375601 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -298,8 +298,6 @@ Base.copy(d::Dual) = d Base.eps(d::Dual) = eps(value(d)) Base.eps(::Type{D}) where {D<:Dual} = eps(valtype(D)) -# The `base` keyword was added in Julia 1.8: -# https://github.com/JuliaLang/julia/pull/42428 Base.precision(d::Dual; base::Integer=2) = precision(value(d); base=base) function Base.precision(::Type{D}; base::Integer=2) where {D<:Dual} precision(valtype(D); base=base) diff --git a/test/DualTest.jl b/test/DualTest.jl index bd50c37c..5ac9a3e1 100644 --- a/test/DualTest.jl +++ b/test/DualTest.jl @@ -118,7 +118,6 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false @test precision(typeof(FDNUM)) === precision(V) @test precision(NESTED_FDNUM) === precision(PRIMAL) @test precision(typeof(NESTED_FDNUM)) === precision(V) - @test precision(FDNUM; base=10) === precision(PRIMAL; base=10) @test precision(typeof(FDNUM); base=10) === precision(V; base=10) @test precision(NESTED_FDNUM; base=10) === precision(PRIMAL; base=10) diff --git a/test/GradientTest.jl b/test/GradientTest.jl index 4f46c167..5c2c0938 100644 --- a/test/GradientTest.jl +++ b/test/GradientTest.jl @@ -9,6 +9,7 @@ using ForwardDiff using ForwardDiff: Dual, Tag using StaticArrays using DiffTests +using JLArrays include(joinpath(dirname(@__FILE__), "utils.jl")) @@ -255,4 +256,25 @@ end end end +@testset "GPUArraysCore" begin + fn(x) = sum(x .^ 2 ./ 2) + + x = [1.0, 2.0, 3.0] + x_jl = JLArray(x) + + grad = ForwardDiff.gradient(fn, x) + grad_jl = ForwardDiff.gradient(fn, x_jl) + + @test grad_jl isa JLArray + @test Array(grad_jl) ≈ grad + + cfg = ForwardDiff.GradientConfig( + fn, x_jl, ForwardDiff.Chunk{2}(), ForwardDiff.Tag(fn, eltype(x)) + ) + grad_jl = ForwardDiff.gradient(fn, x_jl, cfg) + + @test grad_jl isa JLArray + @test Array(grad_jl) ≈ grad +end + end # module diff --git a/test/JacobianTest.jl b/test/JacobianTest.jl index 1e52f7fa..865503b5 100644 --- a/test/JacobianTest.jl +++ b/test/JacobianTest.jl @@ -8,6 +8,7 @@ using ForwardDiff: Dual, Tag, JacobianConfig using StaticArrays using DiffTests using LinearAlgebra +using JLArrays include(joinpath(dirname(@__FILE__), "utils.jl")) @@ -279,4 +280,17 @@ end end end +@testset "GPUArraysCore" begin + f(x) = x .^ 2 ./ 2 + + x = [1.0, 2.0, 3.0] + x_jl = JLArray(x) + + jac = ForwardDiff.jacobian(f, x) + jac_jl = ForwardDiff.jacobian(f, x_jl) + + @test jac_jl isa JLArray + @test Array(jac_jl) ≈ jac +end + end # module