Skip to content
This repository was archived by the owner on Aug 22, 2025. It is now read-only.

Commit 1e1c7c1

Browse files
Merge pull request #170 from JuliaDiff/set_tag
Allow users to set the tag for the configs
2 parents cb76902 + aa9fc34 commit 1e1c7c1

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@ end
1111
getsize(::Val{N}) where N = N
1212
getsize(N::Integer) = N
1313
void_setindex!(args...) = (setindex!(args...); return)
14+
gettag(::Type{ForwardDiff.Dual{T}}) where {T} = T
1415

1516
const default_chunk_size = ForwardDiff.pickchunksize
17+
const SMALLTAG = ForwardDiff.Tag(missing,Float64)
1618

1719
function ForwardColorJacCache(f::F,x,_chunksize = nothing;
1820
dx = nothing,
21+
tag = nothing,
1922
colorvec=1:length(x),
2023
sparsity::Union{AbstractArray,Nothing}=nothing) where {F}
2124

@@ -25,15 +28,21 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
2528
chunksize = _chunksize
2629
end
2730

31+
if tag === nothing
32+
T = typeof(ForwardDiff.Tag(f,eltype(vec(x))))
33+
else
34+
T = tag
35+
end
36+
2837
if x isa Array
2938
p = generate_chunked_partials(x,colorvec,chunksize)
30-
t = similar(x,Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x)))),eltype(x),length(first(first(p)))})
39+
t = similar(x,Dual{T})
3140
for i in eachindex(t)
32-
t[i] = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x)))),eltype(x),length(first(first(p)))}(x[i],ForwardDiff.Partials(first(p)[i]))
41+
t[i] = Dual{T,eltype(x),length(first(first(p)))}(x[i],ForwardDiff.Partials(first(p)[i]))
3342
end
3443
else
3544
p = adapt.(parameterless_type(x),generate_chunked_partials(x,colorvec,chunksize))
36-
_t = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(x),first(p))
45+
_t = Dual{T,eltype(x),getsize(chunksize)}.(vec(x),ForwardDiff.Partials.(first(p)))
3746
t = ArrayInterface.restructure(x,_t)
3847
end
3948

@@ -44,7 +53,7 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
4453
else
4554
tup = ArrayInterface.allowed_getindex(ArrayInterface.allowed_getindex(p,1),1) .* false
4655
_pi = adapt(parameterless_type(dx),[tup for i in 1:length(dx)])
47-
fx = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(dx),_pi),size(dx)...)
56+
fx = reshape(Dual{T,eltype(dx),length(tup)}.(vec(dx),ForwardDiff.Partials.(_pi)),size(dx)...)
4857
_dx = dx
4958
end
5059

@@ -162,7 +171,7 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f::F,x::Abstract
162171

163172
for i in eachindex(p)
164173
partial_i = p[i]
165-
t = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i),size(t))
174+
t = reshape(eltype(t).(vecx, ForwardDiff.Partials.(partial_i)),size(t))
166175
fx = f(t)
167176
if !(sparsity isa Nothing)
168177
for j in 1:chunksize
@@ -230,7 +239,7 @@ function forwarddiff_color_jacobian_immutable(f,x::AbstractArray{<:Number},jac_c
230239

231240
for i in eachindex(p)
232241
partial_i = p[i]
233-
t = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i),size(t))
242+
t = reshape(eltype(t).(vecx, ForwardDiff.Partials.(partial_i)),size(t))
234243
fx = f(t)
235244
if !(sparsity isa Nothing)
236245
for j in 1:chunksize
@@ -311,10 +320,10 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
311320

312321
if vect isa Array
313322
@inbounds @simd ivdep for j in eachindex(vect)
314-
vect[j] = Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}(vecx[j], partial_i[j])
323+
vect[j] = eltype(t)(vecx[j], ForwardDiff.Partials(partial_i[j]))
315324
end
316325
else
317-
vect .= Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i)
326+
vect .= eltype(t).(vecx, ForwardDiff.Partials.(partial_i))
318327
end
319328

320329
f(fx,t)

0 commit comments

Comments
 (0)