Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit c126bf8

Browse files
committed
Fix aliasing of x and v in tests, as well as unreverted edits to x in in-place
hesvecs, and update tests accordingly
1 parent 1f8274f commit c126bf8

File tree

4 files changed

+109
-73
lines changed

4 files changed

+109
-73
lines changed

ext/SparseDiffToolsZygote.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ function SparseDiffTools.numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache
2525
g(cache1, x)
2626
@. x -= 2ϵ * v
2727
g(cache2, x)
28+
@. x += ϵ * v
2829
@. dy = (cache1 - cache2) / (2ϵ)
2930
end
3031

@@ -37,6 +38,7 @@ function SparseDiffTools.numback_hesvec(f, x, v)
3738
gxp = g(x)
3839
x -= 2ϵ * v
3940
gxm = g(x)
41+
@. x += ϵ * v
4042
(gxp - gxm) / (2ϵ)
4143
end
4244

src/differentiation/jaches_products.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ function num_hesvec!(dy,
7878
g(cache2, x)
7979
@. x -= 2ϵ * v
8080
g(cache3, x)
81+
@. x += ϵ * v
8182
@. dy = (cache2 - cache3) / (2ϵ)
8283
end
8384

@@ -110,6 +111,7 @@ function numauto_hesvec!(dy,
110111
g(cache1, x)
111112
@. x -= 2ϵ * v
112113
g(cache2, x)
114+
@. x += ϵ * v
113115
@. dy = (cache1 - cache2) / (2ϵ)
114116
end
115117

@@ -158,6 +160,7 @@ function num_hesvecgrad!(dy, g, x, v, cache2 = similar(v), cache3 = similar(v))
158160
g(cache2, x)
159161
@. x -= 2ϵ * v
160162
g(cache3, x)
163+
@. x += ϵ * v
161164
@. dy = (cache2 - cache3) / (2ϵ)
162165
end
163166

test/test_jaches_products.jl

Lines changed: 66 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,23 @@ using LinearAlgebra, Test
44
using Random
55
Random.seed!(123)
66
N = 300
7-
const A = rand(N, N)
8-
9-
_f(y, x) = mul!(y, A, x.^2)
10-
_f(x) = A * (x.^2)
117

128
x = rand(N)
139
v = rand(N)
10+
11+
# Save original values of x and v to make sure they are not ever mutated
12+
x0 = copy(x)
13+
v0 = copy(v)
14+
1415
a, b = rand(2)
1516
dy = similar(x)
17+
18+
# Define functions for testing
19+
20+
A = rand(N, N)
21+
_f(y, x) = mul!(y, A, x.^2)
22+
_f(x) = A * (x.^2)
23+
1624
_g(x) = sum(abs2, x.^2)
1725
function _h(x)
1826
FiniteDiff.finite_difference_gradient(_g, x)
@@ -21,7 +29,7 @@ function _h(dy, x)
2129
FiniteDiff.finite_difference_gradient!(dy, _g, x)
2230
end
2331

24-
# Define state-dependent functions for operator tests
32+
# Make functions state-dependent for operator tests
2533

2634
include("update_coeffs_testutils.jl")
2735
f = WrapFunc(_f, 1.0, 1.0)
@@ -47,152 +55,147 @@ cache2 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(SparseDiffTools.DeivVecTag(), e
4755
similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-2
4856
@test num_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
4957

50-
@test numauto_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
58+
@test numauto_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
5159
@test numauto_hesvec!(dy, g, x, v, ForwardDiff.GradientConfig(g, x), similar(v),
52-
similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-8
53-
@test numauto_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
60+
similar(v))ForwardDiff.hessian(g, x) * v
61+
@test numauto_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v
5462

55-
@test autonum_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
63+
@test autonum_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
5664
@test autonum_hesvec!(dy, g, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v rtol=1e-2
57-
@test autonum_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
65+
@test autonum_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v
5866

59-
@test numback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
60-
@test numback_hesvec!(dy, g, x, v, similar(v), similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-8
61-
@test numback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
67+
@test numback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-6
68+
@test numback_hesvec!(dy, g, x, v, similar(v), similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-6
69+
@test numback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-6
6270

6371
cache3 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
6472
}.(x, ForwardDiff.Partials.(tuple.(v)))
6573
cache4 = ForwardDiff.Dual{typeof(ForwardDiff.Tag(Nothing, eltype(x))), eltype(x), 1
6674
}.(x, ForwardDiff.Partials.(tuple.(v)))
67-
@test autoback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
68-
@test autoback_hesvec!(dy, g, x, v, cache3, cache4)ForwardDiff.hessian(g, x) * v rtol=1e-8
69-
@test autoback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-8
75+
@test autoback_hesvec!(dy, g, x, v)ForwardDiff.hessian(g, x) * v
76+
@test autoback_hesvec!(dy, g, x, v, cache3, cache4)ForwardDiff.hessian(g, x) * v
77+
@test autoback_hesvec(g, x, v)ForwardDiff.hessian(g, x) * v
7078

7179
@test num_hesvecgrad!(dy, h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
7280
@test num_hesvecgrad!(dy, h, x, v, similar(v), similar(v))ForwardDiff.hessian(g, x) * v rtol=1e-2
7381
@test num_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
7482

75-
@test auto_hesvecgrad!(dy, h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
76-
@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v rtol=1e-2
77-
@test auto_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v rtol=1e-2
83+
@test auto_hesvecgrad!(dy, h, x, v)ForwardDiff.hessian(g, x) * v
84+
@test auto_hesvecgrad!(dy, h, x, v, cache1, cache2)ForwardDiff.hessian(g, x) * v
85+
@test auto_hesvecgrad(h, x, v)ForwardDiff.hessian(g, x) * v
7886

7987
@info "JacVec"
8088

81-
L = JacVec(f, x, 1.0, 1.0)
89+
L = JacVec(f, copy(x), 1.0, 1.0)
8290
update_coefficients!(f, x, 1.0, 1.0)
8391
@test L * x auto_jacvec(f, x, x)
8492
@test L * v auto_jacvec(f, x, v)
8593
@test mul!(dy, L, v) auto_jacvec(f, x, v)
8694
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
8795
update_coefficients!(L, v, 3.0, 4.0)
8896
update_coefficients!(f, v, 3.0, 4.0)
89-
@test mul!(dy, L, v) auto_jacvec(f, v, v)
90-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
97+
@test mul!(dy, L, x) auto_jacvec(f, v, x)
98+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) a*auto_jacvec(f,v,x) + b*_dy
9199
update_coefficients!(f, v, 5.0, 6.0)
92-
@test L(dy, v, 5.0, 6.0) auto_jacvec(f,x,v)
100+
@test L(dy, v, 5.0, 6.0) auto_jacvec(f,v,v)
93101

94-
L = JacVec(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
102+
L = JacVec(f, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
95103
update_coefficients!(f, x, 1.0, 1.0)
96104
@test L * x num_jacvec(f, x, x)
97105
@test L * v num_jacvec(f, x, v)
98106
@test mul!(dy, L, v)num_jacvec(f, x, v) rtol=1e-6
99107
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
100108
update_coefficients!(L, v, 3.0, 4.0)
101109
update_coefficients!(f, v, 3.0, 4.0)
102-
@test mul!(dy, L, v)num_jacvec(f, v, v) rtol=1e-6
103-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_jacvec(f,x,v) + b*_dy rtol=1e-6
110+
@test mul!(dy, L, x)num_jacvec(f, v, x) rtol=1e-6
111+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) a*num_jacvec(f,v,x) + b*_dy rtol=1e-6
104112
update_coefficients!(f, v, 5.0, 6.0)
105-
@test L(dy, v, 5.0, 6.0) num_jacvec(f,x,v) rtol=1e-6
113+
@test L(dy, v, 5.0, 6.0) num_jacvec(f,v,v) rtol=1e-6
106114

107115
out = similar(v)
108116
@test_nowarn gmres!(out, L, v)
109117

110118
@info "HesVec"
111119

112-
x = rand(N)
113-
v = rand(N)
114-
L = HesVec(g, x, 1.0, 1.0, autodiff = AutoFiniteDiff())
120+
L = HesVec(g, copy(x), 1.0, 1.0, autodiff = AutoFiniteDiff())
115121
update_coefficients!(g, x, 1.0, 1.0)
116122
@test L * x num_hesvec(g, x, x) rtol=1e-2
117-
num_hesvec(g, x, x)
118123
@test L * v num_hesvec(g, x, v) rtol=1e-2
119124
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
120125
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
121126
update_coefficients!(L, v, 3.0, 4.0)
122127
update_coefficients!(g, v, 3.0, 4.0)
123-
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
124-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*num_hesvec(g,x,v) + b*_dy rtol=1e-2
128+
@test mul!(dy, L, x)num_hesvec(g, v, x) rtol=1e-2
129+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b) a*num_hesvec(g,v,x) + b*_dy rtol=1e-2
125130
update_coefficients!(g, v, 5.0, 6.0)
126-
@test L(dy, v, 5.0, 6.0) num_hesvec(g,x,v) rtol=1e-2
131+
@test L(dy, v, 5.0, 6.0) num_hesvec(g,v,v) rtol=1e-2
127132

128-
L = HesVec(g, x, 1.0, 1.0)
129-
update_coefficients!(g, x, 1.0, 1.0)
130-
numauto_hesvec(g, x, x)
131-
num_hesvec(g, x, x)
133+
L = HesVec(g, copy(x), 1.0, 1.0)
132134
@test L * x numauto_hesvec(g, x, x)
133135
@test L * v numauto_hesvec(g, x, v)
134-
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
135-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
136+
@test mul!(dy, L, v)numauto_hesvec(g, x, v)
137+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,0)a*numauto_hesvec(g,x,v)+0*_dy
136138
update_coefficients!(L, v, 3.0, 4.0)
137139
update_coefficients!(g, v, 3.0, 4.0)
138-
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
139-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
140+
@test mul!(dy, L, x)numauto_hesvec(g, v, x)
141+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*numauto_hesvec(g,v,x)+b*_dy
140142
update_coefficients!(g, v, 5.0, 6.0)
141-
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g,x,v) rtol=1e-2
143+
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g,v,v)
142144

143145
out = similar(v)
144146
gmres!(out, L, v)
145147

146-
x = rand(N)
147-
v = rand(N)
148-
149-
L = HesVec(g, x, 1.0, 1.0; autodiff = AutoZygote())
148+
L = HesVec(g, copy(x), 1.0, 1.0; autodiff = AutoZygote())
150149
update_coefficients!(g, x, 1.0, 1.0)
151150
@test L * x autoback_hesvec(g, x, x)
152151
@test L * v autoback_hesvec(g, x, v)
153-
@test mul!(dy, L, v)autoback_hesvec(g, x, v) rtol=1e-8
154-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
152+
@test mul!(dy, L, v)autoback_hesvec(g, x, v)
153+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy
155154
update_coefficients!(L, v, 3.0, 4.0)
156155
update_coefficients!(g, v, 3.0, 4.0)
157-
@test mul!(dy, L, v)autoback_hesvec(g, v, v) rtol=1e-8
158-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
156+
@test mul!(dy, L, x)autoback_hesvec(g, v, x)
157+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*autoback_hesvec(g,v,x)+b*_dy
159158
update_coefficients!(g, v, 5.0, 6.0)
160-
@test L(dy, v, 5.0, 6.0) autoback_hesvec(g,x,v) rtol=1e-2
159+
@test L(dy, v, 5.0, 6.0) autoback_hesvec(g,v,v)
161160

162161
out = similar(v)
163162
gmres!(out, L, v)
164163

165164
@info "HesVecGrad"
166165

167-
x = rand(N)
168-
v = rand(N)
169-
L = HesVecGrad(h, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
166+
L = HesVecGrad(h, copy(x), 1.0, 1.0; autodiff = AutoFiniteDiff())
170167
update_coefficients!(h, x, 1.0, 1.0)
171168
update_coefficients!(g, x, 1.0, 1.0)
172169
@test L * x num_hesvec(g, x, x) rtol=1e-2
173170
@test L * v num_hesvec(g, x, v) rtol=1e-2
174171
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
175172
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
176173
for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end
177-
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
178-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
174+
@test mul!(dy, L, x)num_hesvec(g, v, x) rtol=1e-2
175+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*num_hesvec(g,v,x)+b*_dy rtol=1e-2
179176
update_coefficients!(g, v, 5.0, 6.0)
180-
@test L(dy, v, 5.0, 6.0) num_hesvec(g,x,v) rtol=1e-2
177+
@test L(dy, v, 5.0, 6.0) num_hesvec(g,v,v) rtol=1e-2
181178

182-
L = HesVecGrad(h, x, 1.0, 1.0)
179+
L = HesVecGrad(h, copy(x), 1.0, 1.0)
183180
update_coefficients!(g, x, 1.0, 1.0)
184181
update_coefficients!(h, x, 1.0, 1.0)
185182
@test L * x autonum_hesvec(g, x, x)
186183
@test L * v numauto_hesvec(g, x, v)
187184
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8
188-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
185+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy
189186
for op in (L, g, h) update_coefficients!(op, v, 3.0, 4.0) end
190-
@test mul!(dy, L, v)numauto_hesvec(g, v, v) rtol=1e-8
191-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*numauto_hesvec(g,x,v)+b*_dy rtol=1e-8
187+
@test mul!(dy, L, x)numauto_hesvec(g, v, x)
188+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,x,a,b)a*numauto_hesvec(g,v,x)+b*_dy
192189
update_coefficients!(g, v, 5.0, 6.0)
193190
update_coefficients!(h, v, 5.0, 6.0)
194-
@test L(dy, v, 5.0, 6.0) num_hesvec(g,x,v) rtol=1e-2
191+
@test L(dy, v, 5.0, 6.0) numauto_hesvec(g,v,v)
195192

196193
out = similar(v)
197194
gmres!(out, L, v)
195+
196+
# Test that x and v were not mutated
197+
# x's rtol can't be too large since it is mutated and then restored in some algorithms
198+
@test x ≈ x0
199+
@test v ≈ v0
200+
198201
#

test/test_vecjac_products.jl

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,61 @@ using LinearAlgebra, Test
44
using Random
55
Random.seed!(123)
66
N = 300
7-
const A = rand(N, N)
87

8+
# Use Float32 since Zygote defaults to Float32
99
x = rand(Float32, N)
1010
v = rand(Float32, N)
1111

12+
# Save original values of x and v to make sure they are not ever mutated
13+
x0 = copy(x)
14+
v0 = copy(v)
15+
16+
a, b = rand(2)
17+
dy = similar(x)
18+
19+
A = rand(Float32, N, N)
1220
_f(du,u) = mul!(du, A, u)
1321
_f(u) = A * u
1422

1523
# Define state-dependent functions for operator tests
1624
include("update_coeffs_testutils.jl")
17-
f = WrapFunc(_f, 1.0, 1.0)
25+
f = WrapFunc(_f, 1.0f0, 1.0f0)
26+
27+
# Compute Jacobian via Zygote
1828

1929
@info "VecJac"
2030

21-
L = VecJac(f, x, 1.0, 1.0)
31+
L = VecJac(f, copy(x), 1.0f0, 1.0f0; autodiff = AutoZygote())
32+
update_coefficients!(f, x, 1.0, 1.0)
33+
actual_jac = Zygote.jacobian(f, x)[1]
34+
@test L * x ≈ actual_jac' * x
35+
@test L * v ≈ actual_jac' * v
36+
@test mul!(dy, L, v) ≈ actual_jac' * v
2237
update_coefficients!(L, v, 3.0, 4.0)
2338
update_coefficients!(f, v, 3.0, 4.0)
24-
actual_vjp = Zygote.jacobian(f, x)[1]' * v
25-
@test L * v actual_vjp
39+
actual_jac = Zygote.jacobian(f, v)[1]
40+
@test mul!(dy, L, x) ≈ actual_jac' * x
41+
_dy=copy(dy); @test mul!(dy,L,x,a,b) ≈ a*actual_jac'*x + b*_dy
2642
update_coefficients!(f, v, 5.0, 6.0)
27-
actual_vjp2 = Zygote.jacobian(f, x)[1]' * v
28-
@test L(copy(v), v, 5.0, 6.0) actual_vjp2
43+
actual_jac = Zygote.jacobian(f, v)[1]
44+
@test L(dy, v, 5.0, 6.0) ≈ actual_jac' * v
2945

30-
L = VecJac(f, x, 1.0, 1.0; autodiff = AutoFiniteDiff())
46+
L = VecJac(f, copy(x), 1.0f0, 1.0f0; autodiff = AutoFiniteDiff())
47+
update_coefficients!(f, x, 1.0, 1.0)
48+
actual_jac = Zygote.jacobian(f, x)[1]
49+
@test L * x actual_jac' * x
50+
@test L * v actual_jac' * v
51+
@test mul!(dy, L, v) actual_jac' * v
3152
update_coefficients!(L, v, 3.0, 4.0)
3253
update_coefficients!(f, v, 3.0, 4.0)
33-
@test L * v actual_vjp
54+
actual_jac = Zygote.jacobian(f, v)[1]
55+
@test mul!(dy, L, x) actual_jac' * x
56+
_dy=copy(dy); @test mul!(dy,L,x,a,b) a*actual_jac'*x + b*_dy
3457
update_coefficients!(f, v, 5.0, 6.0)
35-
@test L(copy(v), v, 5.0, 6.0) actual_vjp2
58+
actual_jac = Zygote.jacobian(f, v)[1]
59+
@test L(dy, v, 5.0, 6.0) ≈ actual_jac' * v
60+
61+
# Test that x and v were not mutated
62+
@test x ≈ x0
63+
@test v ≈ v0
3664
#

0 commit comments

Comments
 (0)