Skip to content

Commit 18d1a42

Browse files
Merge pull request #63 from JuliaDiffEq/vecjac
non-vector support in color differentiation
2 parents 64c50f5 + 5dfd5c3 commit 18d1a42

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ function ForwardColorJacCache(f,x,_chunksize = nothing;
3131
end
3232

3333
p = adapt.(typeof(x),generate_chunked_partials(x,colorvec,chunksize))
34-
t = Dual{typeof(f)}.(x,first(p))
34+
t = reshape(Dual{typeof(f)}.(vec(x),first(p)),size(x)...)
3535

3636
if dx === nothing
3737
fx = similar(t)
3838
_dx = similar(x)
3939
else
40-
fx = Dual{typeof(f)}.(dx,first(p))
40+
fx = reshape(Dual{typeof(f)}.(vec(dx),first(p)),size(x)...)
4141
_dx = dx
4242
end
4343

@@ -109,17 +109,22 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
109109
cols_index = nothing
110110
end
111111

112+
vecx = vec(x)
113+
vect = vec(t)
114+
vecfx= vec(fx)
115+
vecdx= vec(dx)
116+
112117
ncols=size(J,2)
113118

114119
for i in eachindex(p)
115120
partial_i = p[i]
116-
t .= Dual{typeof(f)}.(x, partial_i)
121+
vect .= Dual{typeof(f)}.(vecx, partial_i)
117122
f(fx,t)
118123
if !(sparsity isa Nothing)
119124
for j in 1:chunksize
120125
dx .= partials.(fx, j)
121126
if ArrayInterface.fast_scalar_indexing(dx)
122-
DiffEqDiffTools._colorediteration!(J,sparsity,rows_index,cols_index,dx,colorvec,color_i,ncols)
127+
DiffEqDiffTools._colorediteration!(J,sparsity,rows_index,cols_index,vecdx,colorvec,color_i,ncols)
123128
else
124129
#=
125130
J.nzval[rows_index] .+= (colorvec[cols_index] .== color_i) .* dx[rows_index]
@@ -128,9 +133,9 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
128133
+= means requires a zero'd out start
129134
=#
130135
if J isa SparseMatrixCSC
131-
@. setindex!((J.nzval,),getindex((J.nzval,),rows_index) + (getindex((colorvec,),cols_index) == color_i) * getindex((dx,),rows_index),rows_index)
136+
@. setindex!((J.nzval,),getindex((J.nzval,),rows_index) + (getindex((colorvec,),cols_index) == color_i) * getindex((vecdx,),rows_index),rows_index)
132137
else
133-
@. setindex!((J,),getindex((J,),rows_index, cols_index) + (getindex((colorvec,),cols_index) == color_i) * getindex((dx,),rows_index),rows_index, cols_index)
138+
@. setindex!((J,),getindex((J,),rows_index, cols_index) + (getindex((colorvec,),cols_index) == color_i) * getindex((vecdx,),rows_index),rows_index, cols_index)
134139
end
135140
end
136141
color_i += 1
@@ -140,7 +145,7 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
140145
for j in 1:chunksize
141146
col_index = (i-1)*chunksize + j
142147
(col_index > maximum(colorvec)) && return
143-
J[:, col_index] .= partials.(fx, j)
148+
J[:, col_index] .= partials.(vecfx, j)
144149
end
145150
end
146151
end

0 commit comments

Comments
 (0)