Skip to content

Commit f6c4023

Browse files
committed
Remove J argument for the immutable oop autodiff function.
1 parent 6dd73e4 commit f6c4023

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,18 @@ end
9494
end
9595

9696
function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
97-
dx = jac_cache.dx
98-
vecx = vec(x)
99-
sparsity = jac_cache.sparsity
10097

101-
J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec(dx) .* vecx' : zeros(eltype(x),size(sparsity))) : zero(jac_prototype)
98+
if jac_prototype isa Nothing ? ArrayInterface.ismutable(x) : ArrayInterface.ismutable(jac_prototype)
99+
# Whenever J is mutable, we mutate it to avoid allocations
100+
dx = jac_cache.dx
101+
vecx = vec(x)
102+
sparsity = jac_cache.sparsity
102103

103-
if ArrayInterface.ismutable(J) # Whenever J is mutable, we mutate it to avoid allocations
104+
J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec(dx) .* vecx' :
105+
zeros(eltype(x),size(sparsity))) : zero(jac_prototype)
104106
forwarddiff_color_jacobian(J, f, x, jac_cache)
105107
else
106-
forwarddiff_color_jacobian_immutable(J, f, x, jac_cache)
108+
return forwarddiff_color_jacobian_immutable(f, x, jac_cache, jac_prototype)
107109
end
108110
end
109111

@@ -138,17 +140,20 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f,x::AbstractArr
138140
pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i]
139141
rows_index_c = rows_index[pick_inds]
140142
cols_index_c = cols_index[pick_inds]
141-
if J isa SparseMatrixCSC
143+
if J isa SparseMatrixCSC || j > 1
144+
# Use sparse matrix to add to J column by column except . . .
142145
Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
143146
else
147+
# To overwrite pre-allocated matrix J, using sparse will cause an error
148+
# so we use this step to overwrite J
144149
len_rows = length(pick_inds)
145150
unused_rows = setdiff(1:nrows,rows_index_c)
146151
perm_rows = sortperm(vcat(rows_index_c,unused_rows))
147152
cols_index_c = vcat(cols_index_c,zeros(Int,nrows-len_rows))[perm_rows]
148153
Ji = [j==cols_index_c[i] ? dx[i] : false for i in 1:nrows, j in 1:ncols]
149154
end
150155
if j == 1 && i == 1
151-
J .= Ji # overwrite pre-allocated matrix
156+
J .= Ji # overwrite pre-allocated matrix J
152157
else
153158
J .+= Ji
154159
end
@@ -172,7 +177,7 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f,x::AbstractArr
172177
end
173178

174179
# When J is immutable, this version of forwarddiff_color_jacobian will avoid mutating J
175-
function forwarddiff_color_jacobian_immutable(J::AbstractArray{<:Number},f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache)
180+
function forwarddiff_color_jacobian_immutable(f,x::AbstractArray{<:Number},jac_cache::ForwardColorJacCache,jac_prototype=nothing)
176181
t = jac_cache.t
177182
dx = jac_cache.dx
178183
p = jac_cache.p
@@ -184,6 +189,7 @@ function forwarddiff_color_jacobian_immutable(J::AbstractArray{<:Number},f,x::Ab
184189

185190
vecx = vec(x)
186191

192+
J = jac_prototype isa Nothing ? (sparsity isa Nothing ? false .* vec(dx) .* vecx' : zeros(eltype(x),size(sparsity))) : zero(jac_prototype)
187193
nrows,ncols = size(J)
188194

189195
if !(sparsity isa Nothing)

0 commit comments

Comments
 (0)