Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions lib/mkl/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# interfacing with LinearAlgebra standard library

import LinearAlgebra
using LinearAlgebra: Transpose, Adjoint,
using LinearAlgebra: Transpose, Adjoint, AdjOrTrans,
Hermitian, Symmetric,
LowerTriangular, UnitLowerTriangular,
UpperTriangular, UnitUpperTriangular,
MulAddMul, wrap
UpperOrLowerTriangular, MulAddMul, wrap

#
# BLAS 1
Expand Down Expand Up @@ -163,12 +163,50 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end

# Either a `oneStridedMatrix{T}` itself, or an `Adjoint`/`Transpose` wrapper around a
# `oneStridedMatrix` whose element type is a subtype of `T`. Used below to constrain the
# parent type of a triangular operand.
const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T,<:oneStridedMatrix}}

# Triangular-times-triangular product written into `C`, dispatched to the out-of-place
# `trmm!` (the method taking a separate output matrix).
#
# Arguments
# - `C`: destination matrix.
# - `uplocA`, `isunitcA`, `tfunA`, `A`: description of the left triangular operand as
#   passed by LinearAlgebra (uplo char, unit-diagonal char, `identity`/`transpose`/
#   `adjoint`, and the underlying strided matrix).
# - `triB`: the right operand, still wrapped in its triangular type; its parent may be a
#   plain matrix or an `Adjoint`/`Transpose` of one.
#
# Only a fixed set of uplo/transpose combinations is handled; anything else throws an
# `ArgumentError`.
#
# NOTE(review): each handled branch calls `triu!`/`tril!` on one of the *input* operands
# before the multiplication, so that operand is modified in place. Presumably this zeroes
# the triangle that the triangular wrapper ignores, so the matrix can be handed to `trmm!`
# as the dense operand -- confirm that callers tolerate the mutation.
function LinearAlgebra.generic_trimatmul!(
        C::oneStridedMatrix{T}, uplocA, isunitcA,
        tfunA::Function, A::oneStridedMatrix{T},
        triB::UpperOrLowerTriangular{T, <: AdjOrTransOroneMatrix{T}},
    ) where {T<:onemklFloat}
    uplocB = LinearAlgebra.uplo_char(triB)
    isunitcB = LinearAlgebra.isunit_char(triB)
    B = parent(triB)
    # `identity` for a plain matrix, `transpose`/`adjoint` for a wrapped one
    tfunB = LinearAlgebra.wrapperop(B)
    transa = tfunA === identity ? 'N' : tfunA === transpose ? 'T' : 'C'
    transb = tfunB === identity ? 'N' : tfunB === transpose ? 'T' : 'C'
    if uplocA == 'L' && tfunA === identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' # lower * upper
        triu!(B)
        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
    elseif uplocA == 'U' && tfunA === identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' # upper * lower
        tril!(B)
        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
    elseif uplocA == 'U' && tfunA === identity && tfunB !== identity && uplocB == 'U' && isunitcA == 'N'
        # operation is reversed to avoid executing the transpose: B's parent becomes the
        # triangular operand applied from the right, A the dense one
        triu!(A)
        trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
    elseif uplocA == 'L' && tfunA !== identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N'
        tril!(B)
        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
    elseif uplocA == 'U' && tfunA !== identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N'
        triu!(B)
        trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
    elseif uplocA == 'L' && tfunA === identity && tfunB !== identity && uplocB == 'L' && isunitcA == 'N'
        # same reversal as above, for the lower/lower case
        tril!(A)
        trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
    else
        # Throw a real exception type instead of a bare String.
        throw(ArgumentError("mixed triangular-triangular multiplication")) # TODO: rethink
    end
    return C
end

# triangular
# Left-multiply by the triangular matrix `A` using the in-place style `trmm!` call:
# `B` is copied into `C` first unless they are the very same array.
function LinearAlgebra.generic_trimatmul!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::oneStridedMatrix{T}, B::oneStridedMatrix{T},
    ) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : (tfun === transpose ? 'T' : 'C')
    dest = C === B ? C : copyto!(C, B)
    return trmm!('L', uploc, trans, isunitc, one(T), A, dest)
end
trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
# Right-multiply by the triangular matrix `B` using the in-place style `trmm!` call:
# `A` is copied into `C` first unless they are the very same array.
function LinearAlgebra.generic_mattrimul!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::oneStridedMatrix{T}, B::oneStridedMatrix{T},
    ) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : (tfun === transpose ? 'T' : 'C')
    dest = C === A ? C : copyto!(C, A)
    return trmm!('R', uploc, trans, isunitc, one(T), B, dest)
end
# Left triangular solve with `A` using the in-place style `trsm!` call:
# `B` is copied into `C` first unless they are the very same array.
function LinearAlgebra.generic_trimatdiv!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::oneStridedMatrix{T}, B::oneStridedMatrix{T},
    ) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : (tfun === transpose ? 'T' : 'C')
    dest = C === B ? C : copyto!(C, B)
    return trsm!('L', uploc, trans, isunitc, one(T), A, dest)
end
# Right triangular solve with `B` using the in-place style `trsm!` call:
# `A` is copied into `C` first unless they are the very same array.
function LinearAlgebra.generic_mattridiv!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::oneStridedMatrix{T}, B::oneStridedMatrix{T},
    ) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : (tfun === transpose ? 'T' : 'C')
    dest = C === A ? C : copyto!(C, A)
    return trsm!('R', uploc, trans, isunitc, one(T), B, dest)
end
trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
# Left triangular solve with `A`, forwarding `B` and the separate destination `C`
# to the `trsm!` method that takes an output matrix.
function LinearAlgebra.generic_trimatdiv!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::oneStridedMatrix{T}, B::AbstractMatrix{T},
    ) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : (tfun === transpose ? 'T' : 'C')
    return trsm!('L', uploc, trans, isunitc, one(T), A, B, C)
end
# Right triangular solve with `B`, forwarding `A` and the separate destination `C`
# to the `trsm!` method that takes an output matrix.
function LinearAlgebra.generic_mattridiv!(
        C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function,
        A::AbstractMatrix{T}, B::oneStridedMatrix{T},
    ) where {T<:onemklFloat}
    trans = tfun === identity ? 'N' : (tfun === transpose ? 'T' : 'C')
    return trsm!('R', uploc, trans, isunitc, one(T), B, A, C)
end
70 changes: 70 additions & 0 deletions lib/mkl/wrappers_blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,76 @@ function trsm(side::Char,
trsm!(side, uplo, transa, diag, alpha, A, copy(B))
end

# Generate, for each supported element type, the `trmm!`/`trsm!` methods that take a
# separate output matrix `C` and a `beta` scalar, backed by the `*_variant` oneMKL
# entry points named in the table below.
#
# Both methods:
# - require `A` to be square and to match the row count (`side == 'L'`) or column count
#   (otherwise) of `B`, throwing `DimensionMismatch` when it does not;
# - require `C` to have the same size as `B`, since the `m`/`n` passed to the native
#   routine are taken from `B`;
# - return `C`, the matrix that is passed to the native routine as the output argument.
for (mmname_variant, smname_variant, elty) in
        ((:onemklDtrmm_variant, :onemklDtrsm_variant, :Float64),
         (:onemklStrmm_variant, :onemklStrsm_variant, :Float32),
         (:onemklZtrmm_variant, :onemklZtrsm_variant, :ComplexF64),
         (:onemklCtrmm_variant, :onemklCtrsm_variant, :ComplexF32))
    @eval begin
        function trmm!(side::Char,
                       uplo::Char,
                       transa::Char,
                       diag::Char,
                       alpha::Number,
                       beta::Number,
                       A::oneStridedMatrix{$elty},
                       B::oneStridedMatrix{$elty},
                       C::oneStridedMatrix{$elty})
            m, n = size(B)
            mA, nA = size(A)
            if mA != nA throw(DimensionMismatch("A must be square")) end
            if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end
            # C was previously passed to the native call unchecked
            if size(C) != (m, n)
                throw(DimensionMismatch("trmm!: C must have the same size as B"))
            end
            # leading dimensions; clamp to 1 for degenerate strides
            lda = max(1,stride(A,2))
            ldb = max(1,stride(B,2))
            ldc = max(1,stride(C,2))
            queue = global_queue(context(A), device())
            $mmname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
            # the result lives in C, not in the input B
            return C
        end

        function trsm!(side::Char,
                       uplo::Char,
                       transa::Char,
                       diag::Char,
                       alpha::Number,
                       beta::Number,
                       A::oneStridedMatrix{$elty},
                       B::oneStridedMatrix{$elty},
                       C::oneStridedMatrix{$elty})
            m, n = size(B)
            mA, nA = size(A)
            if mA != nA throw(DimensionMismatch("A must be square")) end
            if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end
            # C was previously passed to the native call unchecked
            if size(C) != (m, n)
                throw(DimensionMismatch("trsm!: C must have the same size as B"))
            end
            # leading dimensions; clamp to 1 for degenerate strides
            lda = max(1,stride(A,2))
            ldb = max(1,stride(B,2))
            ldc = max(1,stride(C,2))
            queue = global_queue(context(A), device())
            $smname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
            # the result lives in C, not in the input B
            return C
        end
    end
end
# Convenience method without an explicit `beta`: forwards to the 9-argument `trmm!`
# with `beta = zero(T)`.
function trmm!(side::Char, uplo::Char, transa::Char, diag::Char, alpha::Number,
               A::oneStridedMatrix{T}, B::oneStridedMatrix{T},
               C::oneStridedMatrix{T}) where {T}
    beta = zero(T)
    return trmm!(side, uplo, transa, diag, alpha, beta, A, B, C)
end
# Convenience method without an explicit `beta`: forwards to the 9-argument `trsm!`
# with `beta = zero(T)`.
function trsm!(side::Char, uplo::Char, transa::Char, diag::Char, alpha::Number,
               A::oneStridedMatrix{T}, B::oneStridedMatrix{T},
               C::oneStridedMatrix{T}) where {T}
    beta = zero(T)
    return trsm!(side, uplo, transa, diag, alpha, beta, A, B, C)
end

## hemm
for (fname, elty) in ((:onemklZhemm,:ComplexF64),
(:onemklChemm,:ComplexF32))
Expand Down
24 changes: 24 additions & 0 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,14 @@ end
# move to host and compare
h_C = Array(dB)
@test C ≈ h_C

C = rand(T,m,n)
dC = oneArray(C)
beta = zero(T) # rand(T)
oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC)
h_C = Array(dC)
D = alpha*A*B + beta*C
@test D ≈ h_C
end

@testset "trmm" begin
Expand All @@ -684,6 +692,14 @@ end
dC = copy(dB)
oneMKL.trsm!('L','U','N','N',alpha,dA,dC)
@test C ≈ Array(dC)

C = rand(T,m,n)
dC = oneArray(C)
beta = rand(T)
oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC)
h_C = Array(dC)
D = alpha*(A\B) + beta*C
@test D ≈ h_C
end

@testset "left trsm" begin
Expand Down Expand Up @@ -725,6 +741,14 @@ end
dC = copy(dA)
oneMKL.trsm!('R','U','N','N',alpha,dB,dC)
@test C ≈ Array(dC)

C = rand(T,m,m)
dC = oneArray(C)
beta = rand(T)
oneMKL.trsm!('R','U','N','N',alpha,beta,dA,dB,dC)
h_C = Array(dC)
D = alpha*(A/B) + beta*C
@test D ≈ h_C
end

@testset "right trsm" begin
Expand Down