diff --git a/lib/mkl/array.jl b/lib/mkl/array.jl index 0db254e8..3f117bc3 100644 --- a/lib/mkl/array.jl +++ b/lib/mkl/array.jl @@ -5,7 +5,7 @@ const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1} const oneAbstractSparseMatrix{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 2} mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti} - handle::matrix_handle_t + handle::Union{Nothing, matrix_handle_t} rowPtr::oneVector{Ti} colVal::oneVector{Ti} nzVal::oneVector{Tv} @@ -14,7 +14,7 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti} end mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti} - handle::matrix_handle_t + handle::Union{Nothing, matrix_handle_t} colPtr::oneVector{Ti} rowVal::oneVector{Ti} nzVal::oneVector{Tv} @@ -23,7 +23,7 @@ mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti} end mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti} - handle::matrix_handle_t + handle::Union{Nothing, matrix_handle_t} rowInd::oneVector{Ti} colInd::oneVector{Ti} nzVal::oneVector{Tv} diff --git a/lib/mkl/wrappers_blas.jl b/lib/mkl/wrappers_blas.jl index 4e038372..1071040e 100644 --- a/lib/mkl/wrappers_blas.jl +++ b/lib/mkl/wrappers_blas.jl @@ -153,7 +153,7 @@ for (fname, elty) in ((:onemklSsymm, :Float32), lda = max(1,stride(A,2)) ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) $fname(sycl_queue(queue), side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc) C @@ -193,7 +193,7 @@ for (fname, elty) in ((:onemklSsyrk, :Float32), k = size(A, trans == 'N' ? 2 : 1) lda = max(1,stride(A,2)) ldc = max(1,stride(C,2)) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) $fname(sycl_queue(queue), uplo, trans, n, k, alpha, A, lda, beta, C, ldc) C end @@ -234,7 +234,7 @@ for (fname, elty) in ((:onemklDsyr2k,:Float64), lda = max(1,stride(A,2)) ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) $fname(sycl_queue(queue), uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc) C end @@ -268,7 +268,7 @@ for (fname, elty) in ((:onemklZherk, :ComplexF64), k = size(A, trans == 'N' ? 2 : 1) lda = max(1,stride(A,2)) ldc = max(1,stride(C,2)) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) $fname(sycl_queue(queue), uplo, trans, n, k, alpha, A, lda, beta, C, ldc) C end @@ -305,7 +305,7 @@ for (fname, elty) in ((:onemklZher2k,:ComplexF64), lda = max(1,stride(A,2)) ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) $fname(sycl_queue(queue), uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc) C end @@ -336,7 +336,7 @@ for (fname, elty) in ((:onemklSgemv, :Float32), x::oneStridedArray{$elty}, beta::Number, y::oneStridedArray{$elty}) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) # handle trans m,n = size(a) # check dimensions @@ -380,7 +380,7 @@ for (fname, elty) in ((:onemklChemv,:ComplexF32), lda = max(1,stride(A,2)) incx = stride(x,1) incy = stride(y,1) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, n, alpha, A, lda, x, incx, beta, y, incy) y end @@ -414,7 +414,7 @@ for (fname, elty) in ((:onemklChbmv,:ComplexF32), lda = max(1,stride(A,2)) incx = stride(x,1) incy = stride(y,1) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, n, k, alpha, A, lda, x, incx, beta, y, incy) y end @@ -443,7 +443,7 @@ for (fname, elty) in ((:onemklCher,:ComplexF32), length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions")) incx = stride(x,1) lda = max(1,stride(A,2)) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, n, alpha, x, incx, A, lda) A end @@ -466,7 +466,7 @@ for (fname, elty) in ((:onemklCher2,:ComplexF32), incx = stride(x,1) incy = stride(y,1) lda = max(1,stride(A,2)) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, n, alpha, x, incx, y, incy, A, lda) A end @@ -486,7 +486,7 @@ for (fname, elty) in alpha::Number, x::oneStridedArray{$elty}, y::oneStridedArray{$elty}) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) alpha = $elty(alpha) $fname(sycl_queue(queue), n, alpha, x, stride(x,1), y, stride(y,1)) y @@ -506,7 +506,7 @@ for (fname, elty) in x::oneStridedArray{$elty}, beta::Number, y::oneStridedArray{$elty}) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) alpha = $elty(alpha) beta = $elty(beta) $fname(sycl_queue(queue), n, alpha, x, stride(x,1), beta, y, stride(y,1)) @@ -528,7 +528,7 @@ for (fname, elty, cty, sty, supty) in ((:onemklSrot,:Float32,:Float32,:Float32,: y::oneStridedArray{$elty}, c::Real, s::$supty) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) c = $cty(c) s = $sty(s) $fname(sycl_queue(queue), n, x, stride(x, 1), y, stride(y, 1), c, s) @@ -560,7 +560,7 @@ for (fname, elty) in function scal!(n::Integer, alpha::$elty, x::oneStridedArray{$elty}) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), n, alpha, x, stride(x,1)) x end @@ -586,7 +586,7 @@ for (fname, elty, ret_type) in (:onemklZnrm2, :ComplexF64,:Float64)) @eval begin function nrm2(n::Integer, x::oneStridedArray{$elty}) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) result = oneArray{$ret_type}([0]); $fname(sycl_queue(queue), n, x, stride(x,1), result) res = Array(result) @@ -616,7 +616,7 @@ for (jname, fname, elty) in function $jname(n::Integer, x::oneStridedArray{$elty}, y::oneStridedArray{$elty}) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) result = oneArray{$elty}([0]); $fname(sycl_queue(queue), n, x, stride(x,1), y, stride(y,1), result) res = Array(result) @@ -649,7 +649,7 @@ for (fname, elty) in ((:onemklSsbmv, :Float32), if !(1<=(1+k)<=n) throw(DimensionMismatch("Incorrect number of bands")) end if m < 1+k throw(DimensionMismatch("Array A has fewer than 1+k rows")) end if n != length(x) || n != length(y) throw(DimensionMismatch("")) end - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) lda = max(1, stride(a,2)) incx = stride(x,1) incy = stride(y,1) @@ -676,7 +676,7 @@ for (fname, elty, celty) in ((:onemklCSscal, :Float32, :ComplexF32), function scal!(n::Integer, alpha::$elty, x::oneStridedArray{$celty}) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), n, alpha, x, stride(x,1)) end end @@ -696,7 +696,7 @@ for (fname, elty) in ((:onemklSger, :Float32), m,n = size(a) m == length(x) || throw(DimensionMismatch("")) n == length(y) || throw(DimensionMismatch("")) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), m, n, alpha, x, stride(x,1), y, stride(y,1), a, max(1,stride(a,2))) a end @@ -714,7 +714,7 @@ for (fname, elty) in ((:onemklSspr, :Float32), n = round(Int, (sqrt(8*length(A))-1)/2) length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions")) incx = stride(x,1) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, n, alpha, x, incx, A) A end @@ -738,7 +738,7 @@ for (fname, elty) in ((:onemklSsymv,:Float32), lda = max(1,stride(A,2)) incx = stride(x,1) incy = stride(y,1) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, n, alpha, A, lda, x, incx, beta, y, incy) y end @@ -764,7 +764,7 @@ for (fname, elty) in ((:onemklSsyr,:Float32), length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions")) incx = stride(x,1) lda = max(1,stride(A,2)) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, n, alpha, x, incx, A, lda) A end @@ -786,7 +786,7 @@ for (fname, elty) in function copy!(n::Integer, x::oneStridedArray{$elty}, y::oneStridedArray{$elty}) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), n, x, stride(x, 1), y, stride(y, 1)) y end @@ -807,7 +807,7 @@ for (fname, elty, ret_type) in function asum(n::Integer, x::oneStridedArray{$elty}) result = oneArray{$ret_type}([0]) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), n, x, stride(x, 1), result) res = Array(result) return res[1] @@ -824,7 +824,7 @@ for (fname, elty) in @eval begin function iamax(x::oneStridedArray{$elty}) n = length(x) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) result = oneArray{Int64}([0]); $fname(sycl_queue(queue), n, x, stride(x, 1), result, 'O') return Array(result)[1] @@ -842,7 +842,7 @@ for (fname, elty) in function iamin(x::StridedArray{$elty}) n = length(x) result = oneArray{Int64}([0]); - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue),n, x, stride(x, 1), result, 'O') return Array(result)[1] end @@ -859,7 +859,7 @@ for (fname, elty) in ((:onemklSswap,:Float32), x::oneStridedArray{$elty}, y::oneStridedArray{$elty}) # Assuming both memory allocated on same device & context - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), n, x, stride(x, 1), y, stride(y, 1)) x, y end @@ -885,7 +885,7 @@ for (fname, elty) in ((:onemklSgbmv, :Float32), n = size(a,2) length(x) == (trans == 'N' ? n : m) && length(y) == (trans == 'N' ? m : n) || throw(DimensionMismatch("")) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) lda = max(1, stride(a,2)) incx = stride(x,1) incy = stride(y,1) @@ -903,7 +903,7 @@ function gbmv(trans::Char, x::oneStridedArray{T}) where T n = size(a,2) leny = trans == 'N' ? m : n - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) gbmv!(trans, m, kl, ku, alpha, a, x, zero(T), similar(x, leny)) end function gbmv(trans::Char, @@ -912,7 +912,7 @@ function gbmv(trans::Char, ku::Integer, a::oneStridedArray{T}, x::oneStridedArray{T}) where T - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) gbmv(trans, m, kl, ku, one(T), a, x) end @@ -932,7 +932,7 @@ for (fname, elty) in ((:onemklSspmv, :Float32), end incx = stride(x,1) incy = stride(y,1) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, n, alpha, A, x, incx, beta, y, incy) y end @@ -966,7 +966,7 @@ for (fname, elty) in ((:onemklStbsv, :Float32), if n != length(x) throw(DimensionMismatch("")) end lda = max(1,stride(A,2)) incx = stride(x,1) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, trans, diag, n, k, A, lda, x, incx) x end @@ -996,7 +996,7 @@ for (fname, elty) in ((:onemklStbmv,:Float32), if n != length(x) throw(DimensionMismatch("")) end lda = max(1,stride(A,2)) incx = stride(x,1) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, trans, diag, n, k, A, lda, x, incx) x end @@ -1029,7 +1029,7 @@ for (fname, elty) in ((:onemklStrmv, :Float32), end lda = max(1,stride(A,2)) incx = stride(x,1) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, trans, diag, n, A, lda, x, incx) x end @@ -1061,7 +1061,7 @@ for (fname, elty) in ((:onemklStrsv, :Float32), end lda = max(1,stride(A,2)) incx = stride(x,1) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), uplo, trans, diag, n, A, lda, x, incx) x end @@ -1096,7 +1096,7 @@ for (mmname, smname, elty) in if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end lda = max(1,stride(A,2)) ldb = max(1,stride(B,2)) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) $mmname(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb) B end @@ -1114,7 +1114,7 @@ for (mmname, smname, elty) in if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end lda = max(1,stride(A,2)) ldb = max(1,stride(B,2)) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) $smname(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb) B end @@ -1160,7 +1160,7 @@ for (fname, elty) in ((:onemklZhemm,:ComplexF64), lda = max(1,stride(A,2)) ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) $fname(sycl_queue(queue), side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc) C end @@ -1202,9 +1202,9 @@ for (fname, elty) in ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) - device() == device(B) == device(C) || error("Multi-device GEMM not supported") + device(A) == device(B) == device(C) || error("Multi-device GEMM not supported") context(A) == context(B) == context(C) || error("Multi-context GEMM not supported") - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) alpha = $elty(alpha) beta = $elty(beta) @@ -1249,7 +1249,7 @@ for (fname, elty) in ((:onemklSdgmm, :Float32), lda = max(1,stride(A,2)) incx = stride(X,1) ldc = max(1,stride(C,2)) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) $fname(sycl_queue(queue), mode, m, n, A, lda, X, incx, C, ldc) C end @@ -1292,7 +1292,7 @@ for (fname, elty) in strideB = size(B, 3) == 1 ? 0 : stride(B, 3) strideC = stride(C, 3) batchCount = size(C, 3) - queue = global_queue(context(A), device()) + queue = global_queue(context(A), device(A)) alpha = $elty(alpha) beta = $elty(beta) $fname(sycl_queue(queue), transA, transB, m, n, k, alpha, A, lda, strideA, B, diff --git a/lib/mkl/wrappers_sparse.jl b/lib/mkl/wrappers_sparse.jl index 9a0f5170..6b06c00c 100644 --- a/lib/mkl/wrappers_sparse.jl +++ b/lib/mkl/wrappers_sparse.jl @@ -1,7 +1,16 @@ function sparse_release_matrix_handle(A::oneAbstractSparseMatrix) - queue = global_queue(context(A.nzVal), device(A.nzVal)) - handle_ptr = Ref{matrix_handle_t}(A.handle) - onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr) + if A.handle !== nothing + try + queue = global_queue(context(A.nzVal), device(A.nzVal)) + handle_ptr = Ref{matrix_handle_t}(A.handle) + onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr) + # Only synchronize after successful release to ensure completion + synchronize(queue) + catch err + # Don't let finalizer errors crash the program + @warn "Error releasing sparse matrix handle" exception=err + end + end end for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int32), @@ -13,46 +22,72 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3 (:onemklZsparse_set_csr_data , :ComplexF64, :Int32), (:onemklZsparse_set_csr_data_64, :ComplexF64, :Int64)) @eval begin - function oneSparseMatrixCSR(A::SparseMatrixCSC{$elty, $intty}) + + function oneSparseMatrixCSR( + rowPtr::oneVector{$intty}, colVal::oneVector{$intty}, + nzVal::oneVector{$elty}, dims::NTuple{2, Int} + ) handle_ptr = Ref{matrix_handle_t}() onemklXsparse_init_matrix_handle(handle_ptr) + m, n = dims + nnzA = length(nzVal) + queue = global_queue(context(nzVal), device(nzVal)) + # Don't update handle if matrix is empty + if m != 0 && n != 0 + $fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal) + dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA) + finalizer(sparse_release_matrix_handle, dA) + else + dA = oneSparseMatrixCSR{$elty, $intty}(nothing, rowPtr, colVal, nzVal, (m, n), nnzA) + end + return dA + end + + function oneSparseMatrixCSC( + colPtr::oneVector{$intty}, rowVal::oneVector{$intty}, + nzVal::oneVector{$elty}, dims::NTuple{2, Int} + ) + queue = global_queue(context(nzVal), device(nzVal)) + handle_ptr = Ref{matrix_handle_t}() + onemklXsparse_init_matrix_handle(handle_ptr) + m, n = dims + nnzA = length(nzVal) + # Don't update handle if matrix is empty + if m != 0 && n != 0 + $fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ + dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA) + finalizer(sparse_release_matrix_handle, dA) + else + dA = oneSparseMatrixCSC{$elty, $intty}(nothing, colPtr, rowVal, nzVal, (m, n), nnzA) + end + return dA + end + + + function oneSparseMatrixCSR(A::SparseMatrixCSC{$elty, $intty}) m, n = size(A) At = SparseMatrixCSC(A |> transpose) rowPtr = oneVector{$intty}(At.colptr) colVal = oneVector{$intty}(At.rowval) nzVal = oneVector{$elty}(At.nzval) - nnzA = length(At.nzval) - queue = global_queue(context(nzVal), device()) - $fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal) - dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m,n), nnzA) - finalizer(sparse_release_matrix_handle, dA) - return dA + return oneSparseMatrixCSR(rowPtr, colVal, nzVal, (m, n)) end function SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty}) - handle_ptr = Ref{matrix_handle_t}() At = SparseMatrixCSC(reverse(A.dims)..., Vector(A.rowPtr), Vector(A.colVal), Vector(A.nzVal)) A_csc = SparseMatrixCSC(At |> transpose) return A_csc end function oneSparseMatrixCSC(A::SparseMatrixCSC{$elty, $intty}) - handle_ptr = Ref{matrix_handle_t}() - onemklXsparse_init_matrix_handle(handle_ptr) m, n = size(A) colPtr = oneVector{$intty}(A.colptr) rowVal = oneVector{$intty}(A.rowval) nzVal = oneVector{$elty}(A.nzval) - nnzA = length(A.nzval) - queue = global_queue(context(nzVal), device()) - $fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ - dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA) - finalizer(sparse_release_matrix_handle, dA) - return dA + return oneSparseMatrixCSC(colPtr, rowVal, nzVal, (m, n)) end function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty}) - handle_ptr = Ref{matrix_handle_t}() A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal)) return A_csc end @@ -77,15 +112,18 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3 colInd = oneVector{$intty}(col) nzVal = oneVector{$elty}(val) nnzA = length(val) - queue = global_queue(context(nzVal), device()) - $fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal) - dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA) - finalizer(sparse_release_matrix_handle, dA) + queue = global_queue(context(nzVal), device(nzVal)) + if m != 0 && n != 0 + $fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal) + dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA) + finalizer(sparse_release_matrix_handle, dA) + else + dA = oneSparseMatrixCOO{$elty, $intty}(nothing, rowInd, colInd, nzVal, (m,n), nnzA) + end return dA end function SparseMatrixCSC(A::oneSparseMatrixCOO{$elty, $intty}) - handle_ptr = Ref{matrix_handle_t}() A = sparse(Vector(A.rowInd), Vector(A.colInd), Vector(A.nzVal), A.dims...) return A end @@ -105,7 +143,7 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO) beta::Number, y::oneStridedVector{$elty}) - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), trans, alpha, A.handle, x, beta, y) y end @@ -140,8 +178,11 @@ for SparseMatrix in (:oneSparseMatrixCSC,) beta::Number, y::oneStridedVector{$elty}) - queue = global_queue(context(x), device()) - $fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y) + queue = global_queue(context(x), device(x)) + m, n = size(A) + if m != 0 && n != 0 + $fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y) + end y end end @@ -173,7 +214,7 @@ for SparseMatrix in (:oneSparseMatrixCSC,) beta = conj(beta) end - queue = global_queue(context(x), device()) + queue = global_queue(context(x), device(x)) $fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y) if trans == 'C' @@ -217,7 +258,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32), nrhs = size(B, 2) ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) - queue = global_queue(context(C), device()) + queue = global_queue(context(C), device(C)) $fname(sycl_queue(queue), 'C', transa, transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc) C end @@ -254,7 +295,7 @@ for (fname, elty) in ((:onemklSsparse_gemm, :Float32), nrhs = size(B, 2) ldb = max(1,stride(B,2)) ldc = max(1,stride(C,2)) - queue = global_queue(context(C), device()) + queue = global_queue(context(C), device(C)) $fname(sycl_queue(queue), 'C', flip_trans(transa), transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc) C end @@ -289,7 +330,7 @@ for (fname, elty) in ( nrhs = size(B, 2) ldb = max(1, stride(B, 2)) ldc = max(1, stride(C, 2)) - queue = global_queue(context(C), device()) + queue = global_queue(context(C), device(C)) # Use identity: conj(C_new) = conj(alpha) * S * conj(opB(B)) + conj(beta) * conj(C) # Prepare conj(C) in-place and conj(B) into a temporary if needed @@ -359,7 +400,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32), beta::Number, y::oneStridedVector{$elty}) - queue = global_queue(context(y), device()) + queue = global_queue(context(y), device(y)) $fname(sycl_queue(queue), uplo, alpha, A.handle, x, beta, y) y end @@ -379,7 +420,7 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32), beta::Number, y::oneStridedVector{$elty}) - queue = global_queue(context(y), device()) + queue = global_queue(context(y), device(y)) $fname(sycl_queue(queue), flip_uplo(uplo), alpha, A.handle, x, beta, y) y end @@ -400,7 +441,7 @@ for (fname, elty) in ((:onemklSsparse_trmv, :Float32), beta::Number, y::oneStridedVector{$elty}) - queue = global_queue(context(y), device()) + queue = global_queue(context(y), device(y)) $fname(sycl_queue(queue), uplo, trans, diag, alpha, A.handle, x, beta, y) y end @@ -442,7 +483,7 @@ for (fname, elty) in ( "Convert to oneSparseMatrixCSR format instead." ) ) - queue = global_queue(context(y), device()) + queue = global_queue(context(y), device(y)) $fname(sycl_queue(queue), uplo, flip_trans(trans), diag, alpha, A.handle, x, beta, y) return y end @@ -475,7 +516,7 @@ for (fname, elty) in ((:onemklSsparse_trsv, :Float32), x::oneStridedVector{$elty}, y::oneStridedVector{$elty}) - queue = global_queue(context(y), device()) + queue = global_queue(context(y), device(y)) $fname(sycl_queue(queue), uplo, trans, diag, alpha, A.handle, x, y) y end @@ -512,7 +553,7 @@ for (fname, elty) in ( "Convert to oneSparseMatrixCSR format instead." ) ) - queue = global_queue(context(y), device()) + queue = global_queue(context(y), device(y)) onemklXsparse_optimize_trsv(sycl_queue(queue), uplo, flip_trans(trans), diag, A.handle) return A end @@ -555,7 +596,7 @@ for (fname, elty) in ((:onemklSsparse_trsm, :Float32), nrhs = size(X, 2) ldx = max(1,stride(X,2)) ldy = max(1,stride(Y,2)) - queue = global_queue(context(Y), device()) + queue = global_queue(context(Y), device(Y)) $fname(sycl_queue(queue), 'C', transA, transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy) Y end @@ -614,7 +655,7 @@ for (fname, elty) in ( nrhs = size(X, 2) ldx = max(1, stride(X, 2)) ldy = max(1, stride(Y, 2)) - queue = global_queue(context(Y), device()) + queue = global_queue(context(Y), device(Y)) $fname(sycl_queue(queue), 'C', flip_trans(transA), transX, uplo, diag, alpha, A.handle, X, nrhs, ldx, Y, ldy) return Y end diff --git a/src/indexing.jl b/src/indexing.jl index 661deaaf..d46c67bc 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -20,17 +20,17 @@ function Base.findall(bools::oneArray{Bool}) I = keytype(bools) indices = cumsum(reshape(bools, prod(size(bools)))) - oneL0.synchronize() n = isempty(indices) ? 0 : @allowscalar indices[end] ys = oneArray{I}(undef, n) if n > 0 - @oneapi items = length(bools) _ker!(ys, bools, indices) + kernel = @oneapi launch=false _ker!(ys, bools, indices) + group_size = launch_configuration(kernel) + kernel(ys, bools, indices; items=group_size, groups=cld(length(bools), group_size)) end - oneL0.synchronize() - unsafe_free!(indices) + # unsafe_free!(indices) return ys end diff --git a/src/oneAPI.jl b/src/oneAPI.jl index b7f8b527..2b2fde8a 100644 --- a/src/oneAPI.jl +++ b/src/oneAPI.jl @@ -69,6 +69,7 @@ include("utils.jl") include("oneAPIKernels.jl") import .oneAPIKernels: oneAPIBackend include("accumulate.jl") +include("sorting.jl") include("indexing.jl") export oneAPIBackend diff --git a/src/sorting.jl b/src/sorting.jl new file mode 100644 index 00000000..f83ea078 --- /dev/null +++ b/src/sorting.jl @@ -0,0 +1,3 @@ +Base.sort!(x::oneArray; kwargs...) = (AK.sort!(x; kwargs...); return x) +Base.sortperm!(ix::oneArray, x::oneArray; kwargs...) = (AK.sortperm!(ix, x; kwargs...); return ix) +Base.sortperm(x::oneArray; kwargs...) = sortperm!(oneArray(1:length(x)), x; kwargs...) diff --git a/test/Project.toml b/test/Project.toml index c214ed96..90670d48 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,4 +19,5 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36" diff --git a/test/indexing.jl b/test/indexing.jl index 4f1bfcc7..cbf1b32b 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -17,4 +17,18 @@ using oneAPI data = oneArray(collect(1:6)) mask = oneArray(Bool[true, false, true, false, false, true]) @test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])] + + # Test with array larger than 1024 to trigger multiple groups + large_size = 2048 + large_mask = oneArray(rand(Bool, large_size)) + large_result_gpu = Array(findall(large_mask)) + large_result_cpu = findall(Array(large_mask)) + @test large_result_gpu == large_result_cpu + + # Test with even larger array to ensure robustness + very_large_size = 5000 + very_large_mask = oneArray(fill(true, very_large_size)) # all true for predictable result + very_large_result_gpu = Array(findall(very_large_mask)) + very_large_result_cpu = findall(fill(true, very_large_size)) + @test very_large_result_gpu == very_large_result_cpu end diff --git a/test/onemkl.jl b/test/onemkl.jl index 24410dac..d702333e 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -1090,6 +1090,10 @@ end B = oneSparseMatrixCSR(A) A2 = SparseMatrixCSC(B) @test A == A2 + C = oneSparseMatrixCSR(B.rowPtr, B.colVal, B.nzVal, size(B)) + A3 = SparseMatrixCSC(C) + @test A == A3 + D = oneSparseMatrixCSR(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix end end @@ -1101,6 +1105,10 @@ end B = oneSparseMatrixCSC(A) A2 = SparseMatrixCSC(B) @test A == A2 + C = oneSparseMatrixCSC(A.colptr |> oneVector, A.rowval |> oneVector, A.nzval |> oneVector, size(A)) + A3 = SparseMatrixCSC(C) + @test A == A3 + D = oneSparseMatrixCSC(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix end end diff --git a/test/sorting.jl b/test/sorting.jl new file mode 100644 index 00000000..21ffc97b --- /dev/null +++ b/test/sorting.jl @@ -0,0 +1,17 @@ +using Test +using oneAPI + +@testset "sorting" begin + data = oneArray([3, 1, 4, 1, 5]) + sort!(data) + @test Array(data) == [1, 1, 3, 4, 5] + + data_rev = oneArray([3, 1, 4, 1, 5]) + sort!(data_rev, rev = true) + @test Array(data_rev) == [5, 4, 3, 1, 1] + data = oneArray([3, 1, 4, 1, 5]) + @test Array(sortperm(data)) == sortperm([3, 1, 4, 1, 5]) + + data_rev = oneArray([3, 1, 4, 1, 5]) + @test Array(sortperm(data_rev, rev = true)) == sortperm([3, 1, 4, 1, 5], rev = true) +end