Skip to content

Commit aa14fbe

Browse files
committed
Fix gradient and Jacobian for functions with Dual output
1 parent 8820300 commit aa14fbe

File tree

7 files changed

+95
-7
lines changed

7 files changed

+95
-7
lines changed

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDi
4747
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
4848
return quote
4949
$(Expr(:meta, :inline))
50-
V = StaticArrays.similar_type(S, valtype($y))
50+
V = StaticArrays.similar_type(S, valtype(T, $y))
5151
return V($result)
5252
end
5353
end
@@ -76,7 +76,7 @@ end
7676
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
7777
return quote
7878
$(Expr(:meta, :inline))
79-
V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N))
79+
V = StaticArrays.similar_type(S, valtype(T, eltype($ydual)), Size($M, $N))
8080
return V($result)
8181
end
8282
end
@@ -87,7 +87,7 @@ end
8787
end
8888

8989
function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where T
90-
result = similar(ydual, valtype(eltype(ydual)), length(ydual), length(x))
90+
result = similar(ydual, valtype(T, eltype(ydual)), length(ydual), length(x))
9191
return extract_jacobian!(T, result, ydual, length(x))
9292
end
9393

src/dual.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,17 @@ end
128128
@inline valtype(::Dual{T,V,N}) where {T,V,N} = V
129129
@inline valtype(::Type{Dual{T,V,N}}) where {T,V,N} = V
130130

131+
@inline valtype(::Type{T}, ::V) where {T,V} = valtype(T, V)
132+
@inline valtype(::Type, ::Type{V}) where {V} = V
133+
@inline valtype(::Type{T}, ::Type{Dual{T,V,N}}) where {T,V,N} = V
134+
@inline function valtype(::Type{T}, ::Type{Dual{S,V,N}}) where {T,S,V,N}
135+
if S T
136+
Dual{S,V,N}
137+
else
138+
throw(DualMismatchError(T,S))
139+
end
140+
end
141+
131142
@inline tagtype(::V) where {V} = Nothing
132143
@inline tagtype(::Type{V}) where {V} = Nothing
133144
@inline tagtype(::Dual{T,V,N}) where {T,V,N} = T

src/gradient.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ const GRAD_ERROR = DimensionMismatch("gradient(f, x) expects that f(x) is a real
9797
function vector_mode_gradient(f::F, x, cfg::GradientConfig{T}) where {T, F}
9898
ydual = vector_mode_dual_eval!(f, cfg, x)
9999
ydual isa Real || throw(GRAD_ERROR)
100-
result = similar(x, valtype(ydual))
100+
result = similar(x, valtype(T, ydual))
101101
return extract_gradient!(T, result, ydual)
102102
end
103103

@@ -156,7 +156,7 @@ function chunk_mode_gradient_expr(result_definition::Expr)
156156
end
157157

158158
@eval function chunk_mode_gradient(f::F, x, cfg::GradientConfig{T,V,N}) where {F,T,V,N}
159-
$(chunk_mode_gradient_expr(:(result = similar(x, valtype(ydual)))))
159+
$(chunk_mode_gradient_expr(:(result = similar(x, valtype(T, ydual)))))
160160
end
161161

162162
@eval function chunk_mode_gradient!(result, f::F, x, cfg::GradientConfig{T,V,N}) where {F,T,V,N}

src/jacobian.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function vector_mode_jacobian(f::F, x, cfg::JacobianConfig{T}) where {F,T}
128128
N = chunksize(cfg)
129129
ydual = vector_mode_dual_eval!(f, cfg, x)
130130
ydual isa AbstractArray || throw(JACOBIAN_ERROR)
131-
result = similar(ydual, valtype(eltype(ydual)), length(ydual), N)
131+
result = similar(ydual, valtype(T, eltype(ydual)), length(ydual), N)
132132
extract_jacobian!(T, result, ydual, N)
133133
extract_value!(T, result, ydual)
134134
return result
@@ -217,7 +217,7 @@ end
217217
seed!(xdual, x)
218218
end,
219219
:(ydual = f(xdual)),
220-
:(result = similar(ydual, valtype(eltype(ydual)), length(ydual), xlen)),
220+
:(result = similar(ydual, valtype(T, eltype(ydual)), length(ydual), xlen)),
221221
:()))
222222
end
223223

test/DualTest.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,21 @@ ForwardDiff.:≺(::Type{OuterTestTag}, ::Type{TestTag}) = false
101101
@test ForwardDiff.valtype(NESTED_FDNUM) == Dual{TestTag,V,M}
102102
@test ForwardDiff.valtype(typeof(NESTED_FDNUM)) == Dual{TestTag,V,M}
103103

104+
@test ForwardDiff.valtype(TestTag, FDNUM) == V
105+
@test ForwardDiff.valtype(TestTag, typeof(FDNUM)) == V
106+
@test ForwardDiff.valtype(TestTag, NESTED_FDNUM) == Dual{TestTag,V,M}
107+
@test ForwardDiff.valtype(TestTag, typeof(NESTED_FDNUM)) == Dual{TestTag,V,M}
108+
109+
@test ForwardDiff.valtype(OuterTestTag, FDNUM) == Dual{TestTag,V,N}
110+
@test ForwardDiff.valtype(OuterTestTag, typeof(FDNUM)) == Dual{TestTag,V,N}
111+
@test ForwardDiff.valtype(OuterTestTag, NESTED_FDNUM) == Dual{TestTag,Dual{TestTag,V,M},N}
112+
@test ForwardDiff.valtype(OuterTestTag, typeof(NESTED_FDNUM)) == Dual{TestTag,Dual{TestTag,V,M},N}
113+
114+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, Dual{OuterTestTag}(PRIMAL, PARTIALS))
115+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, typeof(Dual{OuterTestTag}(PRIMAL, PARTIALS)))
116+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, Dual{OuterTestTag}(Dual{TestTag}(PRIMAL, M_PARTIALS), NESTED_PARTIALS))
117+
@test_throws ForwardDiff.DualMismatchError(TestTag, OuterTestTag) ForwardDiff.valtype(TestTag, typeof(Dual{OuterTestTag}(Dual{TestTag}(PRIMAL, M_PARTIALS), NESTED_PARTIALS)))
118+
104119
#####################
105120
# Generic Functions #
106121
#####################

test/GradientTest.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ using DiffTests
1212

1313
include(joinpath(dirname(@__FILE__), "utils.jl"))
1414

15+
struct TestTag end
16+
struct OuterTestTag end
17+
ForwardDiff.:(::Type{TestTag}, ::Type{OuterTestTag}) = true
18+
ForwardDiff.:(::Type{OuterTestTag}, ::Type{<:Tag}) = true
19+
1520
##################
1621
# hardcoded test #
1722
##################
@@ -255,4 +260,30 @@ end
255260
end
256261
end
257262

263+
# issue #769
264+
@testset "functions with `Dual` output" begin
265+
x = [Dual{OuterTestTag}(Dual{TestTag}(1.3, 2.1), Dual{TestTag}(0.3, -2.4))]
266+
f(x) = sum(ForwardDiff.value, x)
267+
der = ForwardDiff.derivative(ForwardDiff.value, only(x))
268+
269+
# Vector mode
270+
grad = ForwardDiff.gradient(f, x)
271+
@test grad isa Vector{typeof(der)}
272+
@test grad == [der]
273+
grad = ForwardDiff.gradient(f, SVector{1}(x))
274+
@test grad isa SVector{1,typeof(der)}
275+
@test grad == SVector{1}(der)
276+
277+
# Chunk mode
278+
y = repeat(x, 3)
279+
cfg = ForwardDiff.GradientConfig(f, y, ForwardDiff.Chunk{2}())
280+
grad = ForwardDiff.gradient(f, y, cfg)
281+
@test grad isa Vector{typeof(der)}
282+
@test grad == [der, der, der]
283+
cfg = ForwardDiff.GradientConfig(f, SVector{3}(y), ForwardDiff.Chunk{2}())
284+
grad = ForwardDiff.gradient(f, SVector{3}(y), cfg)
285+
@test grad isa SVector{3,typeof(der)}
286+
@test grad == SVector{3}(der, der, der)
287+
end
288+
258289
end # module

test/JacobianTest.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ using LinearAlgebra
1111

1212
include(joinpath(dirname(@__FILE__), "utils.jl"))
1313

14+
struct TestTag end
15+
struct OuterTestTag end
16+
ForwardDiff.:(::Type{TestTag}, ::Type{OuterTestTag}) = true
17+
ForwardDiff.:(::Type{OuterTestTag}, ::Type{<:Tag}) = true
18+
1419
##################
1520
# hardcoded test #
1621
##################
@@ -308,4 +313,30 @@ end
308313
end
309314
end
310315

316+
# issue #769
317+
@testset "functions with `Dual` output" begin
318+
x = [Dual{OuterTestTag}(Dual{TestTag}(1.3, 2.1), Dual{TestTag}(0.3, -2.4))]
319+
f(x) = map(ForwardDiff.value, x)
320+
der = ForwardDiff.derivative(ForwardDiff.value, only(x))
321+
322+
# Vector mode
323+
jac = ForwardDiff.jacobian(f, x)
324+
@test jac isa Matrix{typeof(der)}
325+
@test jac == [der;;]
326+
jac = ForwardDiff.jacobian(f, SVector{1}(x))
327+
@test jac isa SMatrix{1,1,typeof(der)}
328+
@test jac == SMatrix{1,1}(der)
329+
330+
# Chunk mode
331+
y = repeat(x, 3)
332+
cfg = ForwardDiff.JacobianConfig(f, y, ForwardDiff.Chunk{2}())
333+
jac = ForwardDiff.jacobian(f, y, cfg)
334+
@test jac isa Matrix{typeof(der)}
335+
@test jac == Diagonal([der, der, der])
336+
cfg = ForwardDiff.JacobianConfig(f, SVector{3}(y), ForwardDiff.Chunk{2}())
337+
jac = ForwardDiff.jacobian(f, SVector{3}(y), cfg)
338+
@test jac isa SMatrix{3,3,typeof(der)}
339+
@test jac == Diagonal([der, der, der])
340+
end
341+
311342
end # module

0 commit comments

Comments
 (0)