Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit c337dd5

Browse files
Merge pull request #232 from vpuri3/update_coeffs
have update_coeffs(L::ADVecProd,) recursively update L.f
2 parents 4e4fc7b + 57f4a55 commit c337dd5

File tree

6 files changed

+203
-90
lines changed

6 files changed

+203
-90
lines changed

ext/SparseDiffToolsZygote.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache
2525
g(cache1, x)
2626
@. x -= 2ϵ * v
2727
g(cache2, x)
28+
@. x += ϵ * v
2829
@. dy = (cache1 - cache2) / (2ϵ)
2930
end
3031

src/differentiation/jaches_products.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ function num_hesvec!(dy,
7878
g(cache2, x)
7979
@. x -= 2ϵ * v
8080
g(cache3, x)
81+
@. x += ϵ * v
8182
@. dy = (cache2 - cache3) / (2ϵ)
8283
end
8384

@@ -110,6 +111,7 @@ function numauto_hesvec!(dy,
110111
g(cache1, x)
111112
@. x -= 2ϵ * v
112113
g(cache2, x)
114+
@. x += ϵ * v
113115
@. dy = (cache1 - cache2) / (2ϵ)
114116
end
115117

@@ -158,6 +160,7 @@ function num_hesvecgrad!(dy, g, x, v, cache2 = similar(v), cache3 = similar(v))
158160
g(cache2, x)
159161
@. x -= 2ϵ * v
160162
g(cache3, x)
163+
@. x += ϵ * v
161164
@. dy = (cache2 - cache3) / (2ϵ)
162165
end
163166

@@ -207,10 +210,12 @@ struct FwdModeAutoDiffVecProd{F,U,C,V,V!} <: AbstractAutoDiffVecProd
207210
end
208211

209212
function update_coefficients(L::FwdModeAutoDiffVecProd, u, p, t)
210-
FwdModeAutoDiffVecProd(L.f, u, L.vecprod, L.vecprod!, L.cache)
213+
f = update_coefficients(L.f, u, p, t)
214+
FwdModeAutoDiffVecProd(f, u, L.cache, L.vecprod, L.vecprod!)
211215
end
212216

213217
function update_coefficients!(L::FwdModeAutoDiffVecProd, u, p, t)
218+
update_coefficients!(L.f, u, p, t)
214219
copy!(L.u, u)
215220
L
216221
end

src/differentiation/vecjac_products.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,26 +65,28 @@ struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd
6565
end
6666

6767
function update_coefficients(L::RevModeAutoDiffVecProd, u, p, t)
68-
RevModeAutoDiffVecProd(L.f, u, L.vecprod, L.vecprod!, L.cache)
68+
f = update_coefficients(L.f, u, p, t)
69+
RevModeAutoDiffVecProd(f, u, L.vecprod, L.vecprod!, L.cache)
6970
end
7071

7172
function update_coefficients!(L::RevModeAutoDiffVecProd, u, p, t)
73+
update_coefficients!(L.f, u, p, t)
7274
copy!(L.u, u)
7375
L
7476
end
7577

7678
# Interpret the call as df/du' * u
7779
function (L::RevModeAutoDiffVecProd)(v, p, t)
78-
L.vecprod(_u -> L.f(_u, p, t), L.u, v)
80+
L.vecprod(L.f, L.u, v)
7981
end
8082

8183
# prefer non in-place method
8284
function (L::RevModeAutoDiffVecProd{ad,iip,true})(dv, v, p, t) where{ad,iip}
83-
L.vecprod!(dv, _u -> L.f(_u, p, t), L.u, v, L.cache...)
85+
L.vecprod!(dv, L.f, L.u, v, L.cache...)
8486
end
8587

8688
function (L::RevModeAutoDiffVecProd{ad,true,false})(dv, v, p, t) where{ad}
87-
L.vecprod!(dv, (_du, _u) -> L.f(_du, _u, p, t), L.u, v, L.cache...)
89+
L.vecprod!(dv, L.f, L.u, v, L.cache...)
8890
end
8991

9092
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
@@ -100,11 +102,11 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFi
100102

101103
cache = (similar(u), similar(u),)
102104

103-
outofplace = static_hasmethod(f, typeof((u, p, t)))
104-
isinplace = static_hasmethod(f, typeof((u, u, p, t)))
105+
outofplace = static_hasmethod(f, typeof((u,)))
106+
isinplace = static_hasmethod(f, typeof((u, u,)))
105107

106108
if !(isinplace) & !(outofplace)
107-
error("$f must have signature f(u, p, t), or f(du, u, p, t)")
109+
error("$f must have signature f(u), or f(du, u)")
108110
end
109111

110112
L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = autodiff,

test/test_jaches_products.jl

Lines changed: 115 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,40 @@ using LinearAlgebra, Test
44
using Random
55
Random.seed!(123)
66
N = 300
7-
const A = rand(N, N)
8-
f(y, x) = mul!(y, A, x)
9-
f(x) = A * x
7+
108
x = rand(N)
119
v = rand(N)
10+
11+
# Save original values of x and v to make sure they are not ever mutated
12+
x0 = copy(x)
13+
v0 = copy(v)
14+
1215
a, b = rand(2)
1316
dy = similar(x)
14-
g(x) = sum(abs2, x)
15-
function h(x)
16-
FiniteDiff.finite_difference_gradient(g, x)
17+
18+
# Define functions for testing
19+
20+
A = rand(N, N)
21+
_f(y, x) = mul!(y, A, x.^2)
22+
_f(x) = A * (x.^2)
23+
24+
_g(x) = sum(abs2, x.^2)
25+
function _h(x)
26+
FiniteDiff.finite_difference_gradient(_g, x)
1727
end
18-
function h(dy, x)
19-
FiniteDiff.finite_difference_gradient!(dy, g, x)
28+
function _h(dy, x)
29+
FiniteDiff.finite_difference_gradient!(dy, _g, x)
2030
end
2131

32+
# Make functions state-dependent for operator tests
33+
34+
include("update_coeffs_testutils.jl")
35+
f = WrapFunc(_f, 1.0, 1.0)
36+
g = WrapFunc(_g, 1.0, 1.0)
37+
h = WrapFunc(_h, 1.0, 1.0)
38+
39+
###
40+
2241
cache1 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))),
2342
eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
2443
cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), eltype(x))), eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(v)))
@@ -36,122 +55,147 @@ cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), e
3655
similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-2
3756
@test num_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
3857

39-
@test numauto_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
58+
@test numauto_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
4059
@test numauto_hesvec!(dy, g, x, v, ForwardDiff.GradientConfig(g, x), similar(v),
41-
similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-8
42-
@test numauto_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
60+
similar(v))ForwardDiff.hessian(g, x) * v
61+
@test numauto_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v
4362

44-
@test autonum_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
63+
@test autonum_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
4564
@test autonum_hesvec!(dy, g, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v rtol=1e-2
46-
@test autonum_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
65+
@test autonum_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v
4766

48-
@test numback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
49-
@test numback_hesvec!(dy, g, x, v, similar(v), similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-8
50-
@test numback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
67+
@test numback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
68+
@test numback_hesvec!(dy, g, x, v, similar(v), similar(v))ForwardDiff.hessian(g, x) * v
69+
@test numback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v
5170

5271
cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
5372
}.(x, ForwardDiff.Partials.(tuple.(v)))
5473
cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
5574
}.(x, ForwardDiff.Partials.(tuple.(v)))
56-
@test autoback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
57-
@test autoback_hesvec!(dy, g, x, v, cache3, cache4)ForwardDiff.hessian(g, x) * v rtol=1e-8
58-
@test autoback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
75+
@test autoback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
76+
@test autoback_hesvec!(dy, g, x, v, cache3, cache4)ForwardDiff.hessian(g, x) * v
77+
@test autoback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v
5978

6079
@test num_hesvecgrad!(dy, h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
6180
@test num_hesvecgrad!(dy, h, x, v, similar(v), similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-2
6281
@test num_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
6382

64-
@test auto_hesvecgrad!(dy, h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
65-
@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v rtol=1e-2
66-
@test auto_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
83+
@test auto_hesvecgrad!(dy, h, x, v)ForwardDiff.hessian(g, x) * v
84+
@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v
85+
@test auto_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v
6786

6887
@info "JacVec"
6988

70-
L = JacVec(f, x)
89+
L = JacVec(f, copy(x), 1.0, 1.0)
90+
update_coefficients!(f, x, 1.0, 1.0)
7191
@test L * x auto_jacvec(f, x, x)
7292
@test L * v auto_jacvec(f, x, v)
7393
@test mul!(dy, L, v) auto_jacvec(f, x, v)
7494
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
75-
update_coefficients!(L, v, nothing, 0.0)
76-
@test mul!(dy, L, v) auto_jacvec(f, v, v)
77-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
78-
79-
L = JacVec(f, x, autodiff = AutoFiniteDiff())
95+
update_coefficients!(L, v, 3.0, 4.0)
96+
update_coefficients!(f, v, 3.0, 4.0)
97+
@test mul!(dy, L, x) auto_jacvec(f, v, x)
98+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) a*auto_jacvec(f,v,x) + b*_dy
99+
update_coefficients!(f, v, 5.0, 6.0)
100+
@test L(dy, v, 5.0, 6.0) auto_jacvec(f,v,v)
101+
102+
L = JacVec(f, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
103+
update_coefficients!(f, x, 1.0, 1.0)
80104
@test L * x num_jacvec(f, x, x)
81105
@test L * v num_jacvec(f, x, v)
82106
@test mul!(dy, L, v)num_jacvec(f, x, v) rtol=1e-6
83107
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
84-
update_coefficients!(L, v, nothing, 0.0)
85-
@test mul!(dy, L, v)num_jacvec(f, v, v) rtol=1e-6
86-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
108+
update_coefficients!(L, v, 3.0, 4.0)
109+
update_coefficients!(f, v, 3.0, 4.0)
110+
@test mul!(dy, L, x)num_jacvec(f, v, x) rtol=1e-6
111+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) a*num_jacvec(f,v,x) + b*_dy rtol=1e-6
112+
update_coefficients!(f, v, 5.0, 6.0)
113+
@test L(dy, v, 5.0, 6.0) num_jacvec(f,v,v) rtol=1e-6
87114

88115
out = similar(v)
89-
gmres!(out, L, v)
116+
@test_nowarn gmres!(out, L, v)
90117

91118
@info "HesVec"
92119

93-
x = rand(N)
94-
v = rand(N)
95-
L = HesVec(g, x, autodiff = AutoFiniteDiff())
96-
@test L * x num_hesvec(g, x, x)
97-
@test L * v num_hesvec(g, x, v)
120+
L = HesVec(g, copy(x), 1.0, 1.0, autodiff = AutoFiniteDiff())
121+
update_coefficients!(g, x, 1.0, 1.0)
122+
@test L * x num_hesvec(g, x, x) rtol=1e-2
123+
@test L * v num_hesvec(g, x, v) rtol=1e-2
98124
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
99125
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
100-
update_coefficients!(L, v, nothing, 0.0)
101-
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
102-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
103-
104-
L = HesVec(g, x)
126+
update_coefficients!(L, v, 3.0, 4.0)
127+
update_coefficients!(g, v, 3.0, 4.0)
128+
@test mul!(dy, L, x)num_hesvec(g, v, x) rtol=1e-2
129+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) a*num_hesvec(g,v,x) + b*_dy rtol=1e-2
130+
update_coefficients!(g, v, 5.0, 6.0)
131+
@test L(dy, v, 5.0, 6.0) num_hesvec(g,v,v) rtol=1e-2
132+
133+
L = HesVec(g, copy(x), 1.0, 1.0)
105134
@test L * x numauto_hesvec(g, x, x)
106135
@test L * v numauto_hesvec(g, x, v)
107-
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
108-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
109-
update_coefficients!(L, v, nothing, 0.0)
110-
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
111-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
136+
@test mul!(dy, L, v)numauto_hesvec(g, x, v)
137+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,0)a*numauto_hesvec(g,x,v)+0*_dy
138+
update_coefficients!(L, v, 3.0, 4.0)
139+
update_coefficients!(g, v, 3.0, 4.0)
140+
@test mul!(dy, L, x)numauto_hesvec(g, v, x)
141+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*numauto_hesvec(g,v,x)+b*_dy
142+
update_coefficients!(g, v, 5.0, 6.0)
143+
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g,v,v)
112144

113145
out = similar(v)
114146
gmres!(out, L, v)
115147

116-
using Zygote
117-
118-
x = rand(N)
119-
v = rand(N)
120-
121-
L = HesVec(g, x, autodiff = AutoZygote())
148+
L = HesVec(g, copy(x), 1.0, 1.0; autodiff = AutoZygote())
149+
update_coefficients!(g, x, 1.0, 1.0)
122150
@test L * x autoback_hesvec(g, x, x)
123151
@test L * v autoback_hesvec(g, x, v)
124-
@test mul!(dy, L, v)autoback_hesvec(g, x, v) rtol=1e-8
125-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
126-
update_coefficients!(L, v, nothing, 0.0)
127-
@test mul!(dy, L, v)autoback_hesvec(g, v, v) rtol=1e-8
128-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
152+
@test mul!(dy, L, v)autoback_hesvec(g, x, v)
153+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy
154+
update_coefficients!(L, v, 3.0, 4.0)
155+
update_coefficients!(g, v, 3.0, 4.0)
156+
@test mul!(dy, L, x)autoback_hesvec(g, v, x)
157+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*autoback_hesvec(g,v,x)+b*_dy
158+
update_coefficients!(g, v, 5.0, 6.0)
159+
@test L(dy, v, 5.0, 6.0) autoback_hesvec(g,v,v)
129160

130161
out = similar(v)
131162
gmres!(out, L, v)
132163

133164
@info "HesVecGrad"
134165

135-
x = rand(N)
136-
v = rand(N)
137-
L = HesVecGrad(h, x, autodiff = AutoFiniteDiff())
166+
L = HesVecGrad(h, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
167+
update_coefficients!(h, x, 1.0, 1.0)
168+
update_coefficients!(g, x, 1.0, 1.0)
138169
@test L * x num_hesvec(g, x, x) rtol=1e-2
139170
@test L * v num_hesvec(g, x, v) rtol=1e-2
140171
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
141172
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
142-
update_coefficients!(L, v, nothing, 0.0)
143-
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
144-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
145-
146-
L = HesVecGrad(h, x)
173+
for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end
174+
@test mul!(dy, L, x)num_hesvec(g, v, x) rtol=1e-2
175+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*num_hesvec(g,v,x)+b*_dy rtol=1e-2
176+
update_coefficients!(g, v, 5.0, 6.0)
177+
@test L(dy, v, 5.0, 6.0) num_hesvec(g,v,v) rtol=1e-2
178+
179+
L = HesVecGrad(h, copy(x), 1.0, 1.0)
180+
update_coefficients!(g, x, 1.0, 1.0)
181+
update_coefficients!(h, x, 1.0, 1.0)
147182
@test L * x autonum_hesvec(g, x, x)
148183
@test L * v numauto_hesvec(g, x, v)
149-
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
150-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
151-
update_coefficients!(L, v, nothing, 0.0)
152-
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
153-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
184+
@test mul!(dy, L, v)numauto_hesvec(g, x, v)
185+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy
186+
for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end
187+
@test mul!(dy, L, x)numauto_hesvec(g, v, x)
188+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*numauto_hesvec(g,v,x)+b*_dy
189+
update_coefficients!(g, v, 5.0, 6.0)
190+
update_coefficients!(h, v, 5.0, 6.0)
191+
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g,v,v)
154192

155193
out = similar(v)
156194
gmres!(out, L, v)
195+
196+
# Test that x and v were not mutated
197+
# x's rtol can't be too large since it is mutated and then restored in some algorithms
198+
@test x x0
199+
@test v v0
200+
157201
#

0 commit comments

Comments
 (0)