Skip to content

Commit 60e8c93

Browse files
add GPU fixes
1 parent 462f9c5 commit 60e8c93

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,11 +318,13 @@ Jacobian-vector or Hessian-vector function, whereas `mul!(res,J,v)` utilizes
318318
the appropriate in-place versions. To update the location of differentiation
319319
in the operator, simply mutate the vector `u`: `J.u .= ...`.
320320

321-
# Note about sparse differentiation of BandedMatrices and BlockBandedMatrices
321+
# Note about sparse differentiation of GPUArrays, BandedMatrices, and BlockBandedMatrices
322322

323323
These two matrix types need the dependencies ArrayInterfaceBandedMatrices.jl and
324324
ArrayInterfaceBlockBandedMatrices.jl to basically work with any functionality
325325
(anywhere). For now, the right thing to do is to add these libraries and do
326326
`import` on them if you are using BandedMatrices.jl or BlockBandedMatrices.jl
327327
for sparsity patterns. In the future, those two packages should just depend on
328-
ArrayInterface.jl and remove this issue entirely from the user space.
328+
ArrayInterface.jl and remove this issue entirely from the user space.
329+
330+
Additionally, GPUs need ArrayInterfaceGPUArrays for proper determination of the indexing.

test/gpu/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
[deps]
22
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3+
ArrayInterfaceGPUArrays = "6ba088a2-8465-4c0a-af30-387133b534db"

test/test_gpu_ad.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using SparseDiffTools, CUDA, Test, LinearAlgebra
22
using ArrayInterfaceCore: allowed_getindex, allowed_setindex!
33
using SparseArrays
4+
using ArrayInterfaceGPUArrays
45

56
function f(dx,x)
67
dx[2:end-1] = x[1:end-2] - 2x[2:end-1] + x[3:end]

0 commit comments

Comments
 (0)