Skip to content

Commit 748a33d

Browse files
Merge pull request #128 from chenwilliam77/master
Add method to mirror ForwardDiff's syntax for writing Jacobian of an oop function to an allocated matrix
2 parents 9637980 + 9b24650 commit 748a33d

File tree

2 files changed

+120
-3
lines changed

2 files changed

+120
-3
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,116 @@ end
6868
jac_prototype = nothing,
6969
chunksize = nothing,
7070
dx = sparsity === nothing && jac_prototype === nothing ? nothing : copy(x)) #if dx is nothing, we will estimate dx at the cost of a function call
71-
if sparsity === nothing && jac_prototype === nothing || !ArrayInterface.ismutable(x)
71+
72+
if sparsity === nothing && jac_prototype === nothing
7273
cfg = chunksize === nothing ? ForwardDiff.JacobianConfig(f, x) : ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(getsize(chunksize)))
7374
return ForwardDiff.jacobian(f, x, cfg)
7475
end
7576
if dx isa Nothing
7677
dx = f(x)
7778
end
78-
forwarddiff_color_jacobian(f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity),jac_prototype)
79+
return forwarddiff_color_jacobian(f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity),jac_prototype)
80+
end
81+
82+
@inline function forwarddiff_color_jacobian(J::AbstractArray{<:Number}, f,
83+
x::AbstractArray{<:Number};
84+
colorvec = 1:length(x),
85+
sparsity = nothing,
86+
jac_prototype = nothing,
87+
chunksize = nothing,
88+
dx = similar(x, size(J, 1))) #dx kwarg can be used to avoid re-allocating dx every time
89+
if sparsity === nothing && jac_prototype === nothing
90+
cfg = chunksize === nothing ? ForwardDiff.JacobianConfig(f, x) : ForwardDiff.JacobianConfig(f, x, ForwardDiff.Chunk(getsize(chunksize)))
91+
return ForwardDiff.jacobian(f, x, cfg)
92+
end
93+
return forwarddiff_color_jacobian(J,f,x,ForwardColorJacCache(f,x,chunksize,dx=dx,colorvec=colorvec,sparsity=sparsity))
7994
end
8095

8196
function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
97+
98+
if jac_prototype isa Nothing ? ArrayInterface.ismutable(x) : ArrayInterface.ismutable(jac_prototype)
99+
# Whenever J is mutable, we mutate it to avoid allocations
100+
dx = jac_cache.dx
101+
vecx = vec(x)
102+
sparsity = jac_cache.sparsity
103+
104+
J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec(dx) .* vecx' :
105+
zeros(eltype(x),size(sparsity))) : zero(jac_prototype)
106+
return forwarddiff_color_jacobian(J, f, x, jac_cache)
107+
else
108+
return forwarddiff_color_jacobian_immutable(f, x, jac_cache, jac_prototype)
109+
end
110+
end
111+
112+
# When J is mutable, this version of forwarddiff_color_jacobian will mutate J to avoid allocations
113+
function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache)
114+
t = jac_cache.t
115+
dx = jac_cache.dx
116+
p = jac_cache.p
117+
colorvec = jac_cache.colorvec
118+
sparsity = jac_cache.sparsity
119+
chunksize = jac_cache.chunksize
120+
color_i = 1
121+
maxcolor = maximum(colorvec)
122+
123+
vecx = vec(x)
124+
125+
nrows,ncols = size(J)
126+
127+
if !(sparsity isa Nothing)
128+
rows_index, cols_index = ArrayInterface.findstructralnz(sparsity)
129+
rows_index = [rows_index[i] for i in 1:length(rows_index)]
130+
cols_index = [cols_index[i] for i in 1:length(cols_index)]
131+
end
132+
133+
for i in eachindex(p)
134+
partial_i = p[i]
135+
t = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i),size(t))
136+
fx = f(t)
137+
if !(sparsity isa Nothing)
138+
for j in 1:chunksize
139+
dx = vec(partials.(fx, j))
140+
pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i]
141+
rows_index_c = rows_index[pick_inds]
142+
cols_index_c = cols_index[pick_inds]
143+
if J isa SparseMatrixCSC || j > 1
144+
# Use sparse matrix to add to J column by column except . . .
145+
Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
146+
else
147+
# To overwrite pre-allocated matrix J, using sparse will cause an error
148+
# so we use this step to overwrite J
149+
len_rows = length(pick_inds)
150+
unused_rows = setdiff(1:nrows,rows_index_c)
151+
perm_rows = sortperm(vcat(rows_index_c,unused_rows))
152+
cols_index_c = vcat(cols_index_c,zeros(Int,nrows-len_rows))[perm_rows]
153+
Ji = [j==cols_index_c[i] ? dx[i] : false for i in 1:nrows, j in 1:ncols]
154+
end
155+
if j == 1 && i == 1
156+
J .= Ji # overwrite pre-allocated matrix J
157+
else
158+
J .+= Ji
159+
end
160+
color_i += 1
161+
(color_i > maxcolor) && return J
162+
end
163+
else
164+
for j in 1:chunksize
165+
col_index = (i-1)*chunksize + j
166+
(col_index > ncols) && return J
167+
Ji = mapreduce(i -> i==col_index ? partials.(vec(fx), j) : adapt(parameterless_type(J),zeros(eltype(J),nrows)), hcat, 1:ncols)
168+
if j == 1 && i == 1
169+
J .= (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) # overwrite pre-allocated matrix
170+
else
171+
J .+= (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
172+
end
173+
end
174+
end
175+
end
176+
return J
177+
end
178+
179+
# When J is immutable, this version of forwarddiff_color_jacobian will avoid mutating J
180+
function forwarddiff_color_jacobian_immutable(f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
82181
t = jac_cache.t
83182
dx = jac_cache.dx
84183
p = jac_cache.p
@@ -131,7 +230,7 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
131230
end
132231
end
133232
end
134-
J
233+
return J
135234
end
136235

137236
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},

test/test_ad.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ForwardDiff: Dual, jacobian
33
using SparseArrays, Test
44
using LinearAlgebra
55
using BlockBandedMatrices
6+
using BandedMatrices
67
using StaticArrays
78

89
fcalls = 0
@@ -115,6 +116,23 @@ _J1 = forwarddiff_color_jacobian(oopf, x, colorvec = repeat(1:3,10), sparsity =
115116
@test _J1 J
116117
@test fcalls == 1
117118

119+
#oop with in-place Jacobian
120+
fcalls = 0
121+
_oop_jacout = sparse(1.01 .* J) # want to be nonzero to check that the pre-allocated matrix is overwritten properly
122+
forwarddiff_color_jacobian(_oop_jacout, oopf, x; colorvec = repeat(1:3,10), sparsity = _J, jac_prototype = _J)
123+
@test _oop_jacout J
124+
@test typeof(_oop_jacout) == typeof(_J)
125+
@test fcalls == 1
126+
127+
# BandedMatrix
128+
_oop_jacout = BandedMatrix(-1 => diag(J, -1) .* 1.01, 0 => diag(J, 0) .* 1.01,
129+
1 => diag(J, 1) .* 1.01) # check w/BandedMatrix instead of sparse
130+
fcalls = 0
131+
forwarddiff_color_jacobian(_oop_jacout, oopf, x; colorvec = repeat(1:3,10), sparsity = _J)
132+
@test _oop_jacout J
133+
@test isa(_oop_jacout, BandedMatrix)
134+
@test fcalls == 1
135+
118136
@info "4th passed"
119137

120138
fcalls = 0

0 commit comments

Comments
 (0)