@@ -31,13 +31,13 @@ function ForwardColorJacCache(f,x,_chunksize = nothing;
31
31
end
32
32
33
33
p = adapt .(typeof (x),generate_chunked_partials (x,colorvec,chunksize))
34
- t = Dual {typeof(f)} .(x ,first (p))
34
+ t = reshape ( Dual {typeof(f)} .(vec (x) ,first (p)), size (x) ... )
35
35
36
36
if dx === nothing
37
37
fx = similar (t)
38
38
_dx = similar (x)
39
39
else
40
- fx = Dual {typeof(f)} .(dx ,first (p))
40
+ fx = reshape ( Dual {typeof(f)} .(vec (dx) ,first (p)), size (x) ... )
41
41
_dx = dx
42
42
end
43
43
@@ -109,17 +109,22 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
109
109
cols_index = nothing
110
110
end
111
111
112
+ vecx = vec (x)
113
+ vect = vec (t)
114
+ vecfx= vec (fx)
115
+ vecdx= vec (dx)
116
+
112
117
ncols= size (J,2 )
113
118
114
119
for i in eachindex (p)
115
120
partial_i = p[i]
116
- t .= Dual {typeof(f)} .(x , partial_i)
121
+ vect .= Dual {typeof(f)} .(vecx , partial_i)
117
122
f (fx,t)
118
123
if ! (sparsity isa Nothing)
119
124
for j in 1 : chunksize
120
125
dx .= partials .(fx, j)
121
126
if ArrayInterface. fast_scalar_indexing (dx)
122
- DiffEqDiffTools. _colorediteration! (J,sparsity,rows_index,cols_index,dx ,colorvec,color_i,ncols)
127
+ DiffEqDiffTools. _colorediteration! (J,sparsity,rows_index,cols_index,vecdx ,colorvec,color_i,ncols)
123
128
else
124
129
#=
125
130
J.nzval[rows_index] .+= (colorvec[cols_index] .== color_i) .* dx[rows_index]
@@ -128,9 +133,9 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
128
133
+= means requires a zero'd out start
129
134
=#
130
135
if J isa SparseMatrixCSC
131
- @. setindex! ((J. nzval,),getindex ((J. nzval,),rows_index) + (getindex ((colorvec,),cols_index) == color_i) * getindex ((dx ,),rows_index),rows_index)
136
+ @. setindex! ((J. nzval,),getindex ((J. nzval,),rows_index) + (getindex ((colorvec,),cols_index) == color_i) * getindex ((vecdx ,),rows_index),rows_index)
132
137
else
133
- @. setindex! ((J,),getindex ((J,),rows_index, cols_index) + (getindex ((colorvec,),cols_index) == color_i) * getindex ((dx ,),rows_index),rows_index, cols_index)
138
+ @. setindex! ((J,),getindex ((J,),rows_index, cols_index) + (getindex ((colorvec,),cols_index) == color_i) * getindex ((vecdx ,),rows_index),rows_index, cols_index)
134
139
end
135
140
end
136
141
color_i += 1
@@ -140,7 +145,7 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
140
145
for j in 1 : chunksize
141
146
col_index = (i- 1 )* chunksize + j
142
147
(col_index > maximum (colorvec)) && return
143
- J[:, col_index] .= partials .(fx , j)
148
+ J[:, col_index] .= partials .(vecfx , j)
144
149
end
145
150
end
146
151
end
0 commit comments