18
18
19
19
getsize (:: Val{N} ) where N = N
20
20
getsize (N:: Integer ) = N
21
+ void_setindex! (args... ) = (setindex! (args... ); return )
21
22
22
23
function ForwardColorJacCache (f,x,_chunksize = nothing ;
23
24
dx = nothing ,
@@ -30,15 +31,15 @@ function ForwardColorJacCache(f,x,_chunksize = nothing;
30
31
chunksize = _chunksize
31
32
end
32
33
33
- p = adapt .(typeof (x),generate_chunked_partials (x,colorvec,chunksize))
34
+ p = adapt .(parameterless_type (x),generate_chunked_partials (x,colorvec,chunksize))
34
35
_t = Dual {ForwardDiff.Tag(f,eltype(vec(x)))} .(vec (x),first (p))
35
36
t = ArrayInterface. restructure (x,_t)
36
37
if dx isa Nothing
37
38
fx = similar (t)
38
39
_dx = similar (x)
39
40
else
40
- tup = first ( first (p) ) .* false
41
- _pi = adapt .( typeof (dx),[tup for i in 1 : length (dx)])
41
+ tup = ArrayInterface . allowed_getindex (ArrayInterface . allowed_getindex (p, 1 ), 1 ) .* false
42
+ _pi = adapt ( parameterless_type (dx),[tup for i in 1 : length (dx)])
42
43
fx = reshape (Dual {ForwardDiff.Tag(f,eltype(vec(x)))} .(vec (dx),_pi),size (dx)... )
43
44
_dx = dx
44
45
end
@@ -121,7 +122,7 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
121
122
for j in 1 : chunksize
122
123
col_index = (i- 1 )* chunksize + j
123
124
(col_index > ncols) && return J
124
- Ji = mapreduce (i -> i== col_index ? partials .(vec (fx), j) : zeros (nrows), hcat, 1 : ncols)
125
+ Ji = mapreduce (i -> i== col_index ? partials .(vec (fx), j) : adapt ( parameterless_type (J), zeros (eltype (J), nrows) ), hcat, 1 : ncols)
125
126
J = J + (size (Ji)!= size (J) ? reshape (Ji,size (J)) : Ji) # branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
126
127
end
127
128
end
0 commit comments