Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lib/mkl/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
const oneAbstractSparseMatrix{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 2}

mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
handle::matrix_handle_t
handle::Union{Nothing, matrix_handle_t}
rowPtr::oneVector{Ti}
colVal::oneVector{Ti}
nzVal::oneVector{Tv}
Expand All @@ -14,7 +14,7 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
end

mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
handle::matrix_handle_t
handle::Union{Nothing, matrix_handle_t}
colPtr::oneVector{Ti}
rowVal::oneVector{Ti}
nzVal::oneVector{Tv}
Expand All @@ -23,7 +23,7 @@ mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
end

mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
handle::matrix_handle_t
handle::Union{Nothing, matrix_handle_t}
rowInd::oneVector{Ti}
colInd::oneVector{Ti}
nzVal::oneVector{Tv}
Expand Down
84 changes: 42 additions & 42 deletions lib/mkl/wrappers_blas.jl

Large diffs are not rendered by default.

121 changes: 81 additions & 40 deletions lib/mkl/wrappers_sparse.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
function sparse_release_matrix_handle(A::oneAbstractSparseMatrix)
queue = global_queue(context(A.nzVal), device(A.nzVal))
handle_ptr = Ref{matrix_handle_t}(A.handle)
onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
if A.handle !== nothing
try
queue = global_queue(context(A.nzVal), device(A.nzVal))
handle_ptr = Ref{matrix_handle_t}(A.handle)
onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
# Only synchronize after successful release to ensure completion
synchronize(queue)
catch err
# Don't let finalizer errors crash the program
@warn "Error releasing sparse matrix handle" exception=err
end
end
end

for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int32),
Expand All @@ -13,46 +22,72 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
(:onemklZsparse_set_csr_data , :ComplexF64, :Int32),
(:onemklZsparse_set_csr_data_64, :ComplexF64, :Int64))
@eval begin
function oneSparseMatrixCSR(A::SparseMatrixCSC{$elty, $intty})

function oneSparseMatrixCSR(
rowPtr::oneVector{$intty}, colVal::oneVector{$intty},
nzVal::oneVector{$elty}, dims::NTuple{2, Int}
)
handle_ptr = Ref{matrix_handle_t}()
onemklXsparse_init_matrix_handle(handle_ptr)
m, n = dims
nnzA = length(nzVal)
queue = global_queue(context(nzVal), device(nzVal))
# Don't update handle if matrix is empty
if m != 0 && n != 0
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
finalizer(sparse_release_matrix_handle, dA)
else
dA = oneSparseMatrixCSR{$elty, $intty}(nothing, rowPtr, colVal, nzVal, (m, n), nnzA)
end
return dA
end

function oneSparseMatrixCSC(
colPtr::oneVector{$intty}, rowVal::oneVector{$intty},
nzVal::oneVector{$elty}, dims::NTuple{2, Int}
)
queue = global_queue(context(nzVal), device(nzVal))
handle_ptr = Ref{matrix_handle_t}()
onemklXsparse_init_matrix_handle(handle_ptr)
m, n = dims
nnzA = length(nzVal)
# Don't update handle if matrix is empty
if m != 0 && n != 0
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA)
finalizer(sparse_release_matrix_handle, dA)
else
dA = oneSparseMatrixCSC{$elty, $intty}(nothing, colPtr, rowVal, nzVal, (m, n), nnzA)
end
return dA
end


function oneSparseMatrixCSR(A::SparseMatrixCSC{$elty, $intty})
m, n = size(A)
At = SparseMatrixCSC(A |> transpose)
rowPtr = oneVector{$intty}(At.colptr)
colVal = oneVector{$intty}(At.rowval)
nzVal = oneVector{$elty}(At.nzval)
nnzA = length(At.nzval)
queue = global_queue(context(nzVal), device())
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m,n), nnzA)
finalizer(sparse_release_matrix_handle, dA)
return dA
return oneSparseMatrixCSR(rowPtr, colVal, nzVal, (m, n))
end

function SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
At = SparseMatrixCSC(reverse(A.dims)..., Vector(A.rowPtr), Vector(A.colVal), Vector(A.nzVal))
A_csc = SparseMatrixCSC(At |> transpose)
return A_csc
end

function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
onemklXsparse_init_matrix_handle(handle_ptr)
m, n = size(A)
colPtr = oneVector{$intty}(A.colptr)
rowVal = oneVector{$intty}(A.rowval)
nzVal = oneVector{$elty}(A.nzval)
nnzA = length(A.nzval)
queue = global_queue(context(nzVal), device())
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
finalizer(sparse_release_matrix_handle, dA)
return dA
return oneSparseMatrixCSC(colPtr, rowVal, nzVal, (m, n))
end

function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
return A_csc
end
Expand All @@ -77,15 +112,18 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
colInd = oneVector{$intty}(col)
nzVal = oneVector{$elty}(val)
nnzA = length(val)
queue = global_queue(context(nzVal), device())
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
finalizer(sparse_release_matrix_handle, dA)
queue = global_queue(context(nzVal), device(nzVal))
if m != 0 && n != 0
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
finalizer(sparse_release_matrix_handle, dA)
else
dA = oneSparseMatrixCOO{$elty, $intty}(nothing, rowInd, colInd, nzVal, (m,n), nnzA)
end
return dA
end

function SparseMatrixCSC(A::oneSparseMatrixCOO{$elty, $intty})
handle_ptr = Ref{matrix_handle_t}()
A = sparse(Vector(A.rowInd), Vector(A.colInd), Vector(A.nzVal), A.dims...)
return A
end
Expand All @@ -105,7 +143,7 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
beta::Number,
y::oneStridedVector{$elty})

queue = global_queue(context(x), device())
queue = global_queue(context(x), device(x))
$fname(sycl_queue(queue), trans, alpha, A.handle, x, beta, y)
y
end
Expand Down Expand Up @@ -140,8 +178,11 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
beta::Number,
y::oneStridedVector{$elty})

queue = global_queue(context(x), device())
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
queue = global_queue(context(x), device(x))
m, n = size(A)
if m != 0 && n != 0
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
end
y
end
end
Expand Down Expand Up @@ -173,7 +214,7 @@ for SparseMatrix in (:oneSparseMatrixCSC,)
beta = conj(beta)
end

queue = global_queue(context(x), device())
queue = global_queue(context(x), device(x))
$fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)

if trans == 'C'
Expand Down Expand Up @@ -217,7 +258,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
nrhs = size(B, 2)
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
queue = global_queue(context(C), device())
queue = global_queue(context(C), device(C))
$fname(sycl_queue(queue), 'C', transa, transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
C
end
Expand Down Expand Up @@ -254,7 +295,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
nrhs = size(B, 2)
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
queue = global_queue(context(C), device())
queue = global_queue(context(C), device(C))
$fname(sycl_queue(queue), 'C', flip_trans(transa), transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
C
end
Expand Down Expand Up @@ -289,7 +330,7 @@ for (fname, elty) in (
nrhs = size(B, 2)
ldb = max(1, stride(B, 2))
ldc = max(1, stride(C, 2))
queue = global_queue(context(C), device())
queue = global_queue(context(C), device(C))

# Use identity: conj(C_new) = conj(alpha) * S * conj(opB(B)) + conj(beta) * conj(C)
# Prepare conj(C) in-place and conj(B) into a temporary if needed
Expand Down Expand Up @@ -359,7 +400,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
beta::Number,
y::oneStridedVector{$elty})

queue = global_queue(context(y), device())
queue = global_queue(context(y), device(y))
$fname(sycl_queue(queue), uplo, alpha, A.handle, x, beta, y)
y
end
Expand All @@ -379,7 +420,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
beta::Number,
y::oneStridedVector{$elty})

queue = global_queue(context(y), device())
queue = global_queue(context(y), device(y))
$fname(sycl_queue(queue), flip_uplo(uplo), alpha, A.handle, x, beta, y)
y
end
Expand All @@ -400,7 +441,7 @@ for (fname, elty) in ((:onemklSsparse_trmv, :Float32),
beta::Number,
y::oneStridedVector{$elty})

queue = global_queue(context(y), device())
queue = global_queue(context(y), device(y))
$fname(sycl_queue(queue), uplo, trans, diag, alpha, A.handle, x, beta, y)
y
end
Expand Down Expand Up @@ -442,7 +483,7 @@ for (fname, elty) in (
"Convert to oneSparseMatrixCSR format instead."
)
)
queue = global_queue(context(y), device())
queue = global_queue(context(y), device(y))
$fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A.handle, x, beta, y)
return y
end
Expand Down Expand Up @@ -475,7 +516,7 @@ for (fname, elty) in ((:onemklSsparse_trsv, :Float32),
x::oneStridedVector{$elty},
y::oneStridedVector{$elty})

queue = global_queue(context(y), device())
queue = global_queue(context(y), device(y))
$fname(sycl_queue(queue), uplo, trans, diag, alpha, A.handle, x, y)
y
end
Expand Down Expand Up @@ -512,7 +553,7 @@ for (fname, elty) in (
"Convert to oneSparseMatrixCSR format instead."
)
)
queue = global_queue(context(y), device())
queue = global_queue(context(y), device(y))
onemklXsparse_optimize_trsv(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle)
return A
end
Expand Down Expand Up @@ -555,7 +596,7 @@ for (fname, elty) in ((:onemklSsparse_trsm, :Float32),
nrhs = size(X, 2)
ldx = max(1,stride(X,2))
ldy = max(1,stride(Y,2))
queue = global_queue(context(Y), device())
queue = global_queue(context(Y), device(Y))
$fname(sycl_queue(queue), 'C', transA, transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
Y
end
Expand Down Expand Up @@ -614,7 +655,7 @@ for (fname, elty) in (
nrhs = size(X, 2)
ldx = max(1, stride(X, 2))
ldy = max(1, stride(Y, 2))
queue = global_queue(context(Y), device())
queue = global_queue(context(Y), device(Y))
$fname(sycl_queue(queue), 'C', flip_trans(transA), transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy)
return Y
end
Expand Down
8 changes: 4 additions & 4 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ function Base.findall(bools::oneArray{Bool})
I = keytype(bools)

indices = cumsum(reshape(bools, prod(size(bools))))
oneL0.synchronize()

n = isempty(indices) ? 0 : @allowscalar indices[end]

ys = oneArray{I}(undef, n)

if n > 0
@oneapi items = length(bools) _ker!(ys, bools, indices)
kernel = @oneapi launch=false _ker!(ys, bools, indices)
group_size = launch_configuration(kernel)
kernel(ys, bools, indices; items=group_size, groups=cld(length(bools), group_size))
end
oneL0.synchronize()
unsafe_free!(indices)
# unsafe_free!(indices)

return ys
end
1 change: 1 addition & 0 deletions src/oneAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ include("utils.jl")
include("oneAPIKernels.jl")
import .oneAPIKernels: oneAPIBackend
include("accumulate.jl")
include("sorting.jl")
include("indexing.jl")
export oneAPIBackend

Expand Down
3 changes: 3 additions & 0 deletions src/sorting.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Base.sort!(x::oneArray; kwargs...) = (AK.sort!(x; kwargs...); return x)
Base.sortperm!(ix::oneArray, x::oneArray; kwargs...) = (AK.sortperm!(ix, x; kwargs...); return ix)
Base.sortperm(x::oneArray; kwargs...) = sortperm!(oneArray(1:length(x)), x; kwargs...)
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
14 changes: 14 additions & 0 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,18 @@ using oneAPI
data = oneArray(collect(1:6))
mask = oneArray(Bool[true, false, true, false, false, true])
@test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])]

# Test with array larger than 1024 to trigger multiple groups
large_size = 2048
large_mask = oneArray(rand(Bool, large_size))
large_result_gpu = Array(findall(large_mask))
large_result_cpu = findall(Array(large_mask))
@test large_result_gpu == large_result_cpu

# Test with even larger array to ensure robustness
very_large_size = 5000
very_large_mask = oneArray(fill(true, very_large_size)) # all true for predictable result
very_large_result_gpu = Array(findall(very_large_mask))
very_large_result_cpu = findall(fill(true, very_large_size))
@test very_large_result_gpu == very_large_result_cpu
end
8 changes: 8 additions & 0 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,10 @@ end
B = oneSparseMatrixCSR(A)
A2 = SparseMatrixCSC(B)
@test A == A2
C = oneSparseMatrixCSR(B.rowPtr, B.colVal, B.nzVal, size(B))
A3 = SparseMatrixCSC(C)
@test A == A3
D = oneSparseMatrixCSR(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
end
end

Expand All @@ -1101,6 +1105,10 @@ end
B = oneSparseMatrixCSC(A)
A2 = SparseMatrixCSC(B)
@test A == A2
C = oneSparseMatrixCSC(A.colptr |> oneVector, A.rowval |> oneVector, A.nzval |> oneVector, size(A))
A3 = SparseMatrixCSC(C)
@test A == A3
D = oneSparseMatrixCSC(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
end
end

Expand Down
17 changes: 17 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using Test
using oneAPI

@testset "sorting" begin
data = oneArray([3, 1, 4, 1, 5])
sort!(data)
@test Array(data) == [1, 1, 3, 4, 5]

data_rev = oneArray([3, 1, 4, 1, 5])
sort!(data_rev, rev = true)
@test Array(data_rev) == [5, 4, 3, 1, 1]
data = oneArray([3, 1, 4, 1, 5])
@test Array(sortperm(data)) == sortperm([3, 1, 4, 1, 5])

data_rev = oneArray([3, 1, 4, 1, 5])
@test Array(sortperm(data_rev, rev = true)) == sortperm([3, 1, 4, 1, 5], rev = true)
end