Skip to content

Commit 4cec56e

Browse files
Merge branch 'sparsity'
2 parents 6341f93 + 506c98b commit 4cec56e

File tree

4 files changed

+37
-10
lines changed

4 files changed

+37
-10
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ julia = "1"
2222
[extras]
2323
DiffEqDiffTools = "01453d9d-ee7c-5054-8395-0335cb756afa"
2424
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
25+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2526
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2627

2728
[targets]
28-
test = ["Test", "DiffEqDiffTools", "IterativeSolvers"]
29+
test = ["Test", "DiffEqDiffTools", "IterativeSolvers", "Random"]

src/differentiation/compute_jacobian_ad.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
struct ForwardColorJacCache{T,T2,T3,T4,T5}
1+
struct ForwardColorJacCache{T,T2,T3,T4,T5,T6}
22
t::T
33
fx::T2
44
dx::T3
55
p::T4
66
color::T5
7+
sparsity::T6
78
end
89

910
function default_chunk_size(maxcolor)
@@ -19,7 +20,8 @@ getsize(N::Integer) = N
1920

2021
function ForwardColorJacCache(f,x,_chunksize = nothing;
2122
dx = nothing,
22-
color=1:length(x))
23+
color=1:length(x),
24+
sparsity::Union{SparseMatrixCSC,Nothing}=nothing)
2325

2426
if _chunksize === nothing
2527
chunksize = default_chunk_size(maximum(color))
@@ -38,7 +40,7 @@ function ForwardColorJacCache(f,x,_chunksize = nothing;
3840
end
3941

4042
p = generate_chunked_partials(x,color,chunksize)
41-
ForwardColorJacCache(t,fx,_dx,p,color)
43+
ForwardColorJacCache(t,fx,_dx,p,color,sparsity)
4244
end
4345

4446
generate_chunked_partials(x,color,N::Integer) = generate_chunked_partials(x,color,Val(N))
@@ -78,8 +80,9 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
7880
f,
7981
x::AbstractArray{<:Number};
8082
dx = nothing,
81-
color = eachindex(x))
82-
forwarddiff_color_jacobian!(J,f,x,ForwardColorJacCache(f,x,dx=dx,color=color))
83+
color = eachindex(x),
84+
sparsity = J isa SparseMatrixCSC ? J : nothing)
85+
forwarddiff_color_jacobian!(J,f,x,ForwardColorJacCache(f,x,dx=dx,color=color,sparsity=sparsity))
8386
end
8487

8588
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
@@ -92,15 +95,16 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
9295
dx = jac_cache.dx
9396
p = jac_cache.p
9497
color = jac_cache.color
98+
sparsity = jac_cache.sparsity
9599
color_i = 1
96100
chunksize = length(first(first(jac_cache.p)))
97101

98102
for i in 1:length(p)
99103
partial_i = p[i]
100104
t .= Dual{typeof(f)}.(x, partial_i)
101105
f(fx,t)
102-
if J isa SparseMatrixCSC
103-
rows_index, cols_index, val = findnz(J)
106+
if sparsity isa SparseMatrixCSC
107+
rows_index, cols_index, val = findnz(sparsity)
104108
for j in 1:chunksize
105109
dx .= partials.(fx, j)
106110
for k in 1:length(cols_index)

test/test_ad.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,22 @@ forwarddiff_color_jacobian!(_J1, f, x, color = repeat(1:3,10))
3636
@test fcalls == 1
3737

3838
fcalls = 0
39-
jac_cache = ForwardColorJacCache(f,x,color = repeat(1:3,10))
39+
jac_cache = ForwardColorJacCache(f,x,color = repeat(1:3,10), sparsity = _J1)
4040
forwarddiff_color_jacobian!(_J1, f, x, jac_cache)
4141
@test _J1 J
4242
@test fcalls == 1
43+
44+
fcalls = 0
45+
_J1 = similar(_J)
46+
_denseJ1 = collect(_J1)
47+
forwarddiff_color_jacobian!(_denseJ1, f, x, color = repeat(1:3,10), sparsity = _J1)
48+
@test _denseJ1 J
49+
@test fcalls == 1
50+
51+
fcalls = 0
52+
_J1 = similar(_J)
53+
_denseJ1 = collect(_J1)
54+
jac_cache = ForwardColorJacCache(f,x,color = repeat(1:3,10), sparsity = _J1)
55+
forwarddiff_color_jacobian!(_denseJ1, f, x, jac_cache)
56+
@test _denseJ1 J
57+
@test fcalls == 1

test/test_integration.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,14 @@ J = DiffEqDiffTools.finite_difference_jacobian(f, rand(30))
3737

3838
#Jacobian computed with coloring vectors
3939
fcalls = 0
40-
_J = 200 .* true_jac
40+
_J = similar(true_jac)
4141
DiffEqDiffTools.finite_difference_jacobian!(_J, f, rand(30), color = colors)
4242
@test fcalls == 4
4343
@test _J J
44+
45+
fcalls = 0
46+
_J = similar(true_jac)
47+
_denseJ = collect(_J)
48+
DiffEqDiffTools.finite_difference_jacobian!(_denseJ, f, rand(30), color = colors, sparsity=_J)
49+
@test fcalls == 4
50+
@test _denseJ J

0 commit comments

Comments
 (0)