Skip to content

Commit 4ce04f3

Browse files
Merge pull request #103 from JuliaDiff/gpu
Fix standard GPU AD
2 parents afaeea7 + 12e1811 commit 4ce04f3

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

src/SparseDiffTools.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ include("coloring/matrix2graph.jl")
4646
include("differentiation/compute_jacobian_ad.jl")
4747
include("differentiation/jaches_products.jl")
4848

49+
Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
50+
parameterless_type(x) = parameterless_type(typeof(x))
51+
parameterless_type(x::Type) = __parameterless_type(x)
52+
4953
function __init__()
5054
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
5155
export numback_hesvec, numback_hesvec!, autoback_hesvec, autoback_hesvec!

src/differentiation/compute_jacobian_ad.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ end
1818

1919
getsize(::Val{N}) where N = N
2020
getsize(N::Integer) = N
21+
void_setindex!(args...) = (setindex!(args...); return)
2122

2223
function ForwardColorJacCache(f,x,_chunksize = nothing;
2324
dx = nothing,
@@ -30,15 +31,15 @@ function ForwardColorJacCache(f,x,_chunksize = nothing;
3031
chunksize = _chunksize
3132
end
3233

33-
p = adapt.(typeof(x),generate_chunked_partials(x,colorvec,chunksize))
34+
p = adapt.(parameterless_type(x),generate_chunked_partials(x,colorvec,chunksize))
3435
_t = Dual{ForwardDiff.Tag(f,eltype(vec(x)))}.(vec(x),first(p))
3536
t = ArrayInterface.restructure(x,_t)
3637
if dx isa Nothing
3738
fx = similar(t)
3839
_dx = similar(x)
3940
else
40-
tup = first(first(p)) .* false
41-
_pi = adapt.(typeof(dx),[tup for i in 1:length(dx)])
41+
tup = ArrayInterface.allowed_getindex(ArrayInterface.allowed_getindex(p,1),1) .* false
42+
_pi = adapt(parameterless_type(dx),[tup for i in 1:length(dx)])
4243
fx = reshape(Dual{ForwardDiff.Tag(f,eltype(vec(x)))}.(vec(dx),_pi),size(dx)...)
4344
_dx = dx
4445
end
@@ -121,7 +122,7 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
121122
for j in 1:chunksize
122123
col_index = (i-1)*chunksize + j
123124
(col_index > ncols) && return J
124-
Ji = mapreduce(i -> i==col_index ? partials.(vec(fx), j) : zeros(nrows), hcat, 1:ncols)
125+
Ji = mapreduce(i -> i==col_index ? partials.(vec(fx), j) : adapt(parameterless_type(J),zeros(eltype(J),nrows)), hcat, 1:ncols)
125126
J = J + (size(Ji)!=size(J) ? reshape(Ji,size(J)) : Ji) #branch when size(dx) == (1,) => size(Ji) == (1,) while size(J) == (1,1)
126127
end
127128
end

0 commit comments

Comments
 (0)