|
6 | 6 |
|
7 | 7 | module JLArrays |
8 | 8 |
|
9 | | -export JLArray, JLVector, JLMatrix, jl, JLBackend |
| 9 | +export JLArray, JLVector, JLMatrix, jl, JLBackend, JLSparseVector, JLSparseMatrixCSC, JLSparseMatrixCSR |
10 | 10 |
|
11 | 11 | using GPUArrays |
12 | 12 |
|
13 | 13 | using Adapt |
| 14 | +using SparseArrays, LinearAlgebra |
| 15 | + |
| 16 | +import GPUArrays: _dense_array_type |
14 | 17 |
|
15 | 18 | import KernelAbstractions |
16 | 19 | import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config |
@@ -115,7 +118,102 @@ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N} |
115 | 118 | end |
116 | 119 | end |
117 | 120 |
|
"""
    JLSparseVector{Tv, Ti}

Sparse vector whose stored indices (`iPtr`) and stored values (`nzVal`) are held in
`JLArray`s. `len` is the logical length of the vector and `nnz` caches
`length(nzVal)`, stored as a `Ti`.
"""
mutable struct JLSparseVector{Tv, Ti} <: GPUArrays.AbstractGPUSparseVector{Tv, Ti}
    iPtr::JLArray{Ti, 1}    # indices of the stored entries
    nzVal::JLArray{Tv, 1}   # values of the stored entries
    len::Int                # logical length of the vector
    nnz::Ti                 # number of stored entries

    function JLSparseVector{Tv, Ti}(iPtr::JLArray{<:Integer, 1}, nzVal::JLArray{Tv, 1},
                                     len::Integer) where {Tv, Ti <: Integer}
        # `new` converts `iPtr` to `JLArray{Ti, 1}` and the entry count to `Ti`.
        return new{Tv, Ti}(iPtr, nzVal, len, length(nzVal))
    end
end

# Positional constructor that infers `Tv` and `Ti` from the element types of the
# arguments, matching the positional constructors that `JLSparseMatrixCSC` and
# `JLSparseMatrixCSR` provide.
function JLSparseVector(iPtr::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1},
                        len::Integer) where {Tv, Ti <: Integer}
    return JLSparseVector{Tv, Ti}(iPtr, nzVal, len)
end
# Convert to a `SparseVector` by turning both storage arrays into `Array`s.
function SparseArrays.SparseVector(x::JLSparseVector)
    return SparseVector(length(x), Array(x.iPtr), Array(x.nzVal))
end

# SparseArrays accessors simply expose the corresponding fields.
function SparseArrays.nnz(x::JLSparseVector)
    return x.nnz
end

function SparseArrays.nonzeroinds(x::JLSparseVector)
    return x.iPtr
end

function SparseArrays.nonzeros(x::JLSparseVector)
    return x.nzVal
end
| 136 | + |
"""
    JLSparseMatrixCSC{Tv, Ti}

Sparse matrix stored as three `JLArray`s (`colPtr`, `rowVal`, `nzVal`) plus its
dimensions. `nnz` caches `length(nzVal)`, stored as a `Ti`.
"""
mutable struct JLSparseMatrixCSC{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSC{Tv, Ti}
    colPtr::JLArray{Ti, 1}   # per-column offsets into `rowVal` / `nzVal`
    rowVal::JLArray{Ti, 1}   # row index of each stored entry
    nzVal::JLArray{Tv, 1}    # value of each stored entry
    dims::NTuple{2,Int}      # (rows, columns)
    nnz::Ti                  # number of stored entries

    function JLSparseMatrixCSC{Tv, Ti}(
            colPtr::JLArray{<:Integer, 1},
            rowVal::JLArray{<:Integer, 1},
            nzVal::JLArray{Tv, 1},
            dims::NTuple{2,<:Integer},
        ) where {Tv, Ti <: Integer}
        # `new` converts the index arrays to `JLArray{Ti, 1}` and the count to `Ti`.
        return new{Tv, Ti}(colPtr, rowVal, nzVal, dims, length(nzVal))
    end
end
# Positional constructor that infers `Tv` and `Ti` from the array element types.
JLSparseMatrixCSC(colPtr::JLArray{Ti, 1}, rowVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1},
                  dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} =
    JLSparseMatrixCSC{Tv, Ti}(colPtr, rowVal, nzVal, dims)
# Convert the three storage arrays to `Array`s and assemble a `SparseMatrixCSC`
# of the same size.
function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSC)
    nrows, ncols = size(x)
    return SparseMatrixCSC(nrows, ncols, Array(x.colPtr), Array(x.rowVal), Array(x.nzVal))
end

# idempotency
function JLSparseMatrixCSC(A::JLSparseMatrixCSC)
    return A
end
| 155 | + |
# Scalar lookup of entry (i, j).
# Throws `BoundsError` when (i, j) lies outside `size(A)`; returns `zero(Tv)` when the
# position is valid but no entry is stored there.
function Base.getindex(A::JLSparseMatrixCSC{Tv, Ti}, i::Integer, j::Integer) where {Tv, Ti}
    # Validate first: the `@inbounds` reads of `colPtr` below are only safe for a valid
    # column, and an out-of-range row must not be reported as an unstored zero.
    @boundscheck checkbounds(A, i, j)
    # Stored entries of column `j` occupy slots colPtr[j] : colPtr[j+1]-1.
    lo = Int(@inbounds A.colPtr[j])
    hi = Int(@inbounds A.colPtr[j+1]-1)
    (lo > hi) && return zero(Tv)
    # Binary search for row `i` within the column's slice of `rowVal`
    # (assumes row indices within a column are sorted), then map the position in the
    # view back to an index into the full arrays.
    pos = searchsortedfirst(view(A.rowVal, lo:hi), i) + lo - 1
    return ((pos > hi) || (A.rowVal[pos] != i)) ? zero(Tv) : A.nzVal[pos]
end
| 163 | + |
"""
    JLSparseMatrixCSR{Tv, Ti}

Sparse matrix stored as three `JLArray`s (`rowPtr`, `colVal`, `nzVal`) plus its
dimensions. `nnz` caches `length(nzVal)`, stored as a `Ti`.
"""
mutable struct JLSparseMatrixCSR{Tv, Ti} <: GPUArrays.AbstractGPUSparseMatrixCSR{Tv, Ti}
    rowPtr::JLArray{Ti, 1}   # per-row offsets into `colVal` / `nzVal`
    colVal::JLArray{Ti, 1}   # column index of each stored entry
    nzVal::JLArray{Tv, 1}    # value of each stored entry
    dims::NTuple{2,Int}      # (rows, columns)
    nnz::Ti                  # number of stored entries

    function JLSparseMatrixCSR{Tv, Ti}(
            rowPtr::JLArray{<:Integer, 1},
            colVal::JLArray{<:Integer, 1},
            nzVal::JLArray{Tv, 1},
            dims::NTuple{2,<:Integer},
        ) where {Tv, Ti<:Integer}
        # `new` converts the index arrays to `JLArray{Ti, 1}` and the count to `Ti`.
        return new{Tv, Ti}(rowPtr, colVal, nzVal, dims, length(nzVal))
    end
end
# Positional constructor that infers `Tv` and `Ti` from the array element types.
JLSparseMatrixCSR(rowPtr::JLArray{Ti, 1}, colVal::JLArray{Ti, 1}, nzVal::JLArray{Tv, 1},
                  dims::NTuple{2,<:Integer}) where {Tv, Ti <: Integer} =
    JLSparseMatrixCSR{Tv, Ti}(rowPtr, colVal, nzVal, dims)
# Convert a CSR matrix to a `SparseMatrixCSC`.
function SparseArrays.SparseMatrixCSC(x::JLSparseMatrixCSR)
    nrows, ncols = size(x)
    # Interpreting the CSR arrays as CSC storage describes the transposed
    # (ncols × nrows) matrix; transposing that result recovers `x` itself.
    transposed = SparseMatrixCSC(ncols, nrows, Array(x.rowPtr), Array(x.colVal), Array(x.nzVal))
    return SparseMatrixCSC(transpose(transposed))
end

# idempotency
function JLSparseMatrixCSR(A::JLSparseMatrixCSR)
    return A
end
| 185 | + |
# Scalar lookup of entry (i0, i1).
# Throws `BoundsError` when (i0, i1) lies outside `size(A)`; returns `zero(Tv)` when the
# position is valid but no entry is stored there.
function Base.getindex(A::JLSparseMatrixCSR{Tv, Ti}, i0::Integer, i1::Integer) where {Tv, Ti}
    # Validate first so that an invalid row never indexes `rowPtr` and an
    # out-of-range column is not reported as an unstored zero.
    @boundscheck checkbounds(A, i0, i1)
    # Stored entries of row `i0` occupy slots rowPtr[i0] : rowPtr[i0+1]-1.
    c1 = Int(A.rowPtr[i0])
    c2 = Int(A.rowPtr[i0+1]-1)
    (c1 > c2) && return zero(Tv)
    # Binary search for column `i1` restricted to this row's slots
    # (assumes column indices within a row are sorted).
    c1 = searchsortedfirst(A.colVal, i1, c1, c2, Base.Order.Forward)
    (c1 > c2 || A.colVal[c1] != i1) && return zero(Tv)
    # Read the value straight from the field, as the CSC method does.
    return A.nzVal[c1]
end
| 194 | + |
GPUArrays.storage(a::JLArray) = a.data

# Dense array / dense vector types associated with a `JLArray` value or type.
GPUArrays._dense_array_type(::JLArray{T, N}) where {T, N} = JLArray{T, N}
GPUArrays._dense_array_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, N}
GPUArrays._dense_vector_type(::JLArray{T, N}) where {T, N} = JLArray{T, 1}
GPUArrays._dense_vector_type(::Type{JLArray{T, N}}) where {T, N} = JLArray{T, 1}

# Every sparse container maps to its own unparameterized type as the sparse array type
# and to `JLArray` as the dense array type, for both values and types.
for SparseT in (:JLSparseMatrixCSC, :JLSparseMatrixCSR, :JLSparseVector)
    @eval begin
        GPUArrays._sparse_array_type(::$SparseT) = $SparseT
        GPUArrays._sparse_array_type(::Type{<:$SparseT}) = $SparseT
        GPUArrays._dense_array_type(::$SparseT) = JLArray
        GPUArrays._dense_array_type(::Type{<:$SparseT}) = JLArray
    end
end

# Counterpart matrix format for each of the two matrix types.
GPUArrays._csc_type(::JLSparseMatrixCSR) = JLSparseMatrixCSC
GPUArrays._csr_type(::JLSparseMatrixCSC) = JLSparseMatrixCSR
119 | 217 |
|
120 | 218 | # conversion of untyped data to a typed Array |
121 | 219 | function typed_data(x::JLArray{T}) where {T} |
@@ -217,6 +315,47 @@ JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs) |
# Dimensionality given, element type left open: take the element type from the input.
function (::Type{JLArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N}
    return JLArray{S,N}(x)
end

# No type parameters given: take both element type and dimensionality from the input.
function JLArray(A::AbstractArray{T,N}) where {T,N}
    return JLArray{T,N}(A)
end
219 | 317 |
|
# Build a `JLSparseVector` from a `SparseVector` by copying its index and value
# arrays into freshly allocated `JLVector`s.
function JLSparseVector(xs::SparseVector{Tv, Ti}) where {Ti, Tv}
    # Allocate a `JLVector{T}` of matching length and fill it from `host`.
    function upload(::Type{T}, host) where {T}
        dev = JLVector{T}(undef, length(host))
        copyto!(dev, convert(Vector{T}, host))
        return dev
    end

    return JLSparseVector{Tv, Ti}(upload(Ti, xs.nzind), upload(Tv, xs.nzval), length(xs))
end
# The logical extent of a sparse vector is its stored `len` field.
function Base.size(x::JLSparseVector)
    return (x.len,)
end

function Base.length(x::JLSparseVector)
    return x.len
end
| 327 | + |
# Build a `JLSparseMatrixCSC` from a `SparseMatrixCSC` by copying its three storage
# arrays into freshly allocated `JLVector`s.
function JLSparseMatrixCSC(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
    # Allocate a `JLVector{T}` of matching length and fill it from `host`.
    function upload(::Type{T}, host) where {T}
        dev = JLVector{T}(undef, length(host))
        copyto!(dev, convert(Vector{T}, host))
        return dev
    end

    return JLSparseMatrixCSC{Tv, Ti}(
        upload(Ti, xs.colptr),
        upload(Ti, xs.rowval),
        upload(Tv, xs.nzval),
        (xs.m, xs.n),
    )
end
# Size is the stored `dims` tuple; length is the product of its two entries.
function Base.size(x::JLSparseMatrixCSC)
    return x.dims
end

function Base.length(x::JLSparseMatrixCSC)
    return x.dims[1] * x.dims[2]
end
| 339 | + |
# Build a `JLSparseMatrixCSR` from a `SparseMatrixCSC`.
function JLSparseMatrixCSR(xs::SparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
    # Allocate a `JLVector{T}` of matching length and fill it from `host`.
    function upload(::Type{T}, host) where {T}
        dev = JLVector{T}(undef, length(host))
        copyto!(dev, convert(Vector{T}, host))
        return dev
    end

    # The CSC storage of the materialized transpose is used as the CSR storage of
    # `xs`; the dimensions recorded are those of `xs` itself.
    xs_t = SparseMatrixCSC(transpose(xs))
    return JLSparseMatrixCSR{Tv, Ti}(
        upload(Ti, xs_t.colptr),
        upload(Ti, xs_t.rowval),
        upload(Tv, xs_t.nzval),
        (xs.m, xs.n),
    )
end
# Format conversions go through an intermediate `SparseMatrixCSC`.
function JLSparseMatrixCSR(xs::JLSparseMatrixCSC{Tv, Ti}) where {Ti, Tv}
    intermediate = SparseMatrixCSC(xs)
    return JLSparseMatrixCSR(intermediate)
end

function JLSparseMatrixCSC(xs::JLSparseMatrixCSR{Tv, Ti}) where {Ti, Tv}
    intermediate = SparseMatrixCSC(xs)
    return JLSparseMatrixCSC(intermediate)
end
# Size is the stored `dims` tuple; length is the product of its two entries.
function Base.size(x::JLSparseMatrixCSR)
    return x.dims
end

function Base.length(x::JLSparseMatrixCSR)
    return x.dims[1] * x.dims[2]
end
| 358 | + |
# idempotency: an array that already has the requested type is returned as-is
function JLArray{T,N}(xs::JLArray{T,N}) where {T,N}
    return xs
end
222 | 361 |
|
@@ -358,9 +497,17 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br |
358 | 497 | R |
359 | 498 | end |
360 | 499 |
|
# Adapting a sparse container adapts each of its storage arrays and wraps the results,
# together with the size information and stored-entry count, in the matching
# GPUSparseDevice* type.
function Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
    colptr = adapt(to, x.colPtr)
    rowval = adapt(to, x.rowVal)
    nzval = adapt(to, x.nzVal)
    return GPUSparseDeviceMatrixCSC{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(
        colptr, rowval, nzval, x.dims, x.nnz
    )
end

function Adapt.adapt_structure(to::Adaptor, x::JLSparseMatrixCSR{Tv,Ti}) where {Tv,Ti}
    rowptr = adapt(to, x.rowPtr)
    colval = adapt(to, x.colVal)
    nzval = adapt(to, x.nzVal)
    return GPUSparseDeviceMatrixCSR{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(
        rowptr, colval, nzval, x.dims, x.nnz
    )
end

function Adapt.adapt_structure(to::Adaptor, x::JLSparseVector{Tv,Ti}) where {Tv,Ti}
    inds = adapt(to, x.iPtr)
    nzval = adapt(to, x.nzVal)
    return GPUSparseDeviceVector{Tv,Ti,JLDeviceArray{Ti, 1}, JLDeviceArray{Tv, 1}}(
        inds, nzval, x.len, x.nnz
    )
end
| 506 | + |
361 | 507 | ## KernelAbstractions interface |
362 | 508 |
|
# Dense and sparse JL containers all report the same backend.
KernelAbstractions.get_backend(::JLArray) = JLBackend()
KernelAbstractions.get_backend(::Union{JLSparseMatrixCSC, JLSparseMatrixCSR, JLSparseVector}) = JLBackend()
364 | 511 |
|
365 | 512 | function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic |
366 | 513 | return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace) |
|
0 commit comments