Skip to content
Open
10 changes: 6 additions & 4 deletions lib/cusparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function LinearAlgebra.kron(A::CuSparseMatrixCOO{T, Ti}, B::CuSparseMatrixCOO{T,
sparse(row, col, data, out_shape..., fmt = :coo)
end

function LinearAlgebra.kron(A::CuSparseMatrixCOO{T, Ti}, B::Diagonal) where {Ti, T}
function LinearAlgebra.kron(A::CuSparseMatrixCOO{T, Ti}, B::Diagonal{TB}) where {Ti, T, TB}
mA,nA = size(A)
mB,nB = size(B)
out_shape = (mA * mB, nA * nB)
Expand All @@ -102,12 +102,13 @@ function LinearAlgebra.kron(A::CuSparseMatrixCOO{T, Ti}, B::Diagonal) where {Ti,
row .+= CuVector(repeat(0:nB-1, outer = Annz)) .+ 1
col .+= CuVector(repeat(0:nB-1, outer = Annz)) .+ 1

data .*= repeat(CUDA.ones(T, nB), outer = Annz)
Bdiag = (TB == Bool) ? CUDA.ones(T, nB) : B.diag
data .*= repeat(Bdiag, outer = Annz)

sparse(row, col, data, out_shape..., fmt = :coo)
end

function LinearAlgebra.kron(A::Diagonal, B::CuSparseMatrixCOO{T, Ti}) where {Ti, T}
function LinearAlgebra.kron(A::Diagonal{TA}, B::CuSparseMatrixCOO{T, Ti}) where {Ti, T, TA}
mA,nA = size(A)
mB,nB = size(B)
out_shape = (mA * mB, nA * nB)
Expand All @@ -122,7 +123,8 @@ function LinearAlgebra.kron(A::Diagonal, B::CuSparseMatrixCOO{T, Ti}) where {Ti,
row = CuVector(repeat(row, inner = Bnnz))
col = (0:nA-1) .* nB
col = CuVector(repeat(col, inner = Bnnz))
data = repeat(CUDA.ones(T, nA), inner = Bnnz)
Adiag = (TA == Bool) ? CUDA.ones(T, nA) : A.diag
data = repeat(Adiag, inner = Bnnz)

row .+= repeat(B.rowInd .- 1, outer = Annz) .+ 1
col .+= repeat(B.colInd .- 1, outer = Annz) .+ 1
Expand Down
8 changes: 8 additions & 0 deletions test/libraries/cusparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ m = 10
B = sprand(T, m, m, 0.3)
ZA = spzeros(T, m, m)
C = I(div(m, 2))
D = Diagonal(rand(T, m))
@testset "type = $typ" for typ in [CuSparseMatrixCSR, CuSparseMatrixCSC]
dA = typ(A)
dB = typ(B)
dZA = typ(ZA)
dD = Diagonal(CuArray(D.diag))
@testset "opnorm and norm" begin
@test opnorm(A, Inf) ≈ opnorm(dA, Inf)
@test opnorm(A, 1) ≈ opnorm(dA, 1)
Expand Down Expand Up @@ -42,6 +44,12 @@ m = 10
@test collect(kron(opa(dZA), C)) ≈ kron(opa(ZA), C)
@test collect(kron(C, opa(dZA))) ≈ kron(C, opa(ZA))
end
@testset "kronecker product with Diagonal opa = $opa" for opa in (identity, transpose, adjoint)
@test collect(kron(opa(dA), dD)) ≈ kron(opa(A), D)
@test collect(kron(dD, opa(dA))) ≈ kron(D, opa(A))
@test collect(kron(opa(dZA), dD)) ≈ kron(opa(ZA), D)
@test collect(kron(dD, opa(dZA))) ≈ kron(D, opa(ZA))
end
end
end

Expand Down