Skip to content

Commit 5c833a7

Browse files
committed
Fix sparse gemm and gemv
1 parent ff1602e commit 5c833a7

File tree

3 files changed

+256
-7
lines changed

3 files changed

+256
-7
lines changed

lib/mkl/interfaces.jl

Lines changed: 245 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,13 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
77
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
88
end
99

10-
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
11-
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
10+
# In-place product `C = alpha * op(A) * B + beta * C` for a CSC sparse matrix and a dense
# device vector. Symmetric/Hermitian flags ('S', 's', 'H', 'h') are treated as the plain
# operation; every other flag is forwarded untouched to `sparse_gemv!`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    if tA in ('S', 's', 'H', 'h')
        tA = 'N'
    end
    scale_ab, scale_c = _add.alpha, _add.beta
    return sparse_gemv!(tA, scale_ab, A, B, scale_c, C)
end
14+
15+
# In-place product `C = alpha * op(A) * B + beta * C` for a COO sparse matrix and a dense
# device vector. Symmetric/Hermitian flags ('S', 's', 'H', 'h') are treated as the plain
# operation; every other flag is forwarded untouched to `sparse_gemv!`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCOO{T}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    if tA in ('S', 's', 'H', 'h')
        tA = 'N'
    end
    scale_ab, scale_c = _add.alpha, _add.beta
    return sparse_gemv!(tA, scale_ab, A, B, scale_c, C)
end
1419

@@ -18,8 +23,14 @@ function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseM
1823
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
1924
end
2025

21-
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
22-
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
26+
# In-place product `C = alpha * opA(A) * opB(B) + beta * C` for a CSC sparse matrix and a
# dense device matrix. Symmetric/Hermitian flags on either operand are replaced by 'N'
# before the flags are handed to `sparse_gemm!`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    structured = ('S', 's', 'H', 'h')
    if tA in structured
        tA = 'N'
    end
    if tB in structured
        tB = 'N'
    end
    return sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
end
31+
32+
# In-place product `C = alpha * opA(A) * opB(B) + beta * C` for a COO sparse matrix and a
# dense device matrix. Symmetric/Hermitian flags on either operand are replaced by 'N'
# before the flags are handed to `sparse_gemm!`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCOO{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    structured = ('S', 's', 'H', 'h')
    if tA in structured
        tA = 'N'
    end
    if tB in structured
        tB = 'N'
    end
    return sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
end
@@ -31,3 +42,233 @@ end
3142
# Triangular solve with a CSR sparse matrix and a dense right-hand-side matrix, delegated
# to `sparse_trsm!`. The wrapper function `tfun` is translated to an operation flag:
# `identity` -> 'N', `transpose` -> 'T', anything else -> 'C'. The dense operand is always
# passed with flag 'N' and a unit scaling factor.
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <: BlasFloat
    transa = if tfun === identity
        'N'
    elseif tfun === transpose
        'T'
    else
        'C'
    end
    return sparse_trsm!(uploc, transa, 'N', isunitc, one(T), A, B, C)
end
45+
46+
# Handle Transpose and Adjoint wrappers for sparse matrices
47+
# Let the low-level wrappers handle the CSC->CSR conversion and flip_trans logic
48+
49+
# Matrix-vector multiplication with transpose/adjoint
50+
# In-place `C = alpha * op(A) * B + beta * C` where `A` is a lazy transpose of a CSR
# matrix `P` and `op` is selected by `tA`. The wrapper is folded into the flag passed to
# `sparse_gemv!`, which then works on `P` directly:
#   'N' -> transpose(P)                -> flag 'T'
#   'T' -> transpose(transpose(P)) = P -> flag 'N'
#   'C' -> adjoint(transpose(P)) = conj(P)
# The last case has no flag of its own. Passing 'C' would compute adjoint(P), which has
# the wrong values and the wrong shape for a non-square `P`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if tA == 'N'
        return sparse_gemv!('T', _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to the plain product.
        return sparse_gemv!('N', _add.alpha, P, B, _add.beta, C)
    end
    # conj(P) * B == conj(P * conj(B)): conjugate the operands in place, run the plain
    # product with conjugated scalars, then undo the conjugation.
    B .= conj.(B)
    C .= conj.(C)
    try
        sparse_gemv!('N', conj(_add.alpha), P, B, conj(_add.beta), C)
    finally
        # Always restore the caller's data, even if the kernel call throws.
        C .= conj.(C)
        B .= conj.(B)
    end
    return C
end
55+
56+
# In-place `C = alpha * op(A) * B + beta * C` where `A` is a lazy adjoint of a CSR matrix
# `P` and `op` is selected by `tA`:
#   'N'            -> adjoint(P)                -> flag 'C'
#   'T'            -> transpose(adjoint(P)) = conj(P)
#   anything else  -> adjoint(adjoint(P)) = P   -> flag 'N'
# Returns `C`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if tA != 'T'
        sparse_gemv!(tA == 'N' ? 'C' : 'N', _add.alpha, P, B, _add.beta, C)
    elseif T <: Real
        # conj(P) == P for real element types; skip the no-op conjugation passes.
        sparse_gemv!('N', _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) * B == conj(P * conj(B)): conjugate the operands in place, run the
        # plain product with conjugated scalars, then undo the conjugation.
        B .= conj.(B)
        C .= conj.(C)
        try
            sparse_gemv!('N', conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Always restore the caller's data, even if the kernel call throws.
            C .= conj.(C)
            B .= conj.(B)
        end
    end
    return C
end
72+
73+
# In-place `C = alpha * op(A) * B + beta * C` where `A` is a lazy transpose of a CSC
# matrix `P` and `op` is selected by `tA`. The wrapper is folded into the flag passed to
# `sparse_gemv!`, which then works on `P` directly:
#   'N' -> transpose(P)                -> flag 'T'
#   'T' -> transpose(transpose(P)) = P -> flag 'N'
#   'C' -> adjoint(transpose(P)) = conj(P)
# The last case has no flag of its own. Passing 'C' would compute adjoint(P), which has
# the wrong values and the wrong shape for a non-square `P`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if tA == 'N'
        return sparse_gemv!('T', _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to the plain product.
        return sparse_gemv!('N', _add.alpha, P, B, _add.beta, C)
    end
    # conj(P) * B == conj(P * conj(B)): conjugate the operands in place, run the plain
    # product with conjugated scalars, then undo the conjugation.
    B .= conj.(B)
    C .= conj.(C)
    try
        sparse_gemv!('N', conj(_add.alpha), P, B, conj(_add.beta), C)
    finally
        # Always restore the caller's data, even if the kernel call throws.
        C .= conj.(C)
        B .= conj.(B)
    end
    return C
end
78+
79+
# In-place `C = alpha * op(A) * B + beta * C` where `A` is a lazy adjoint of a CSC matrix
# `P` and `op` is selected by `tA`:
#   'N'            -> adjoint(P)                -> flag 'C'
#   'T'            -> transpose(adjoint(P)) = conj(P)
#   anything else  -> adjoint(adjoint(P)) = P   -> flag 'N'
# Returns `C`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if tA != 'T'
        sparse_gemv!(tA == 'N' ? 'C' : 'N', _add.alpha, P, B, _add.beta, C)
    elseif T <: Real
        # conj(P) == P for real element types; skip the no-op conjugation passes.
        sparse_gemv!('N', _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) * B == conj(P * conj(B)): conjugate the operands in place, run the
        # plain product with conjugated scalars, then undo the conjugation.
        B .= conj.(B)
        C .= conj.(C)
        try
            sparse_gemv!('N', conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Always restore the caller's data, even if the kernel call throws.
            C .= conj.(C)
            B .= conj.(B)
        end
    end
    return C
end
95+
96+
# In-place `C = alpha * op(A) * B + beta * C` where `A` is a lazy transpose of a COO
# matrix `P` and `op` is selected by `tA`. The wrapper is folded into the flag passed to
# `sparse_gemv!`, which then works on `P` directly:
#   'N' -> transpose(P)                -> flag 'T'
#   'T' -> transpose(transpose(P)) = P -> flag 'N'
#   'C' -> adjoint(transpose(P)) = conj(P)
# The last case has no flag of its own. Passing 'C' would compute adjoint(P), which has
# the wrong values and the wrong shape for a non-square `P`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if tA == 'N'
        return sparse_gemv!('T', _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to the plain product.
        return sparse_gemv!('N', _add.alpha, P, B, _add.beta, C)
    end
    # conj(P) * B == conj(P * conj(B)): conjugate the operands in place, run the plain
    # product with conjugated scalars, then undo the conjugation.
    B .= conj.(B)
    C .= conj.(C)
    try
        sparse_gemv!('N', conj(_add.alpha), P, B, conj(_add.beta), C)
    finally
        # Always restore the caller's data, even if the kernel call throws.
        C .= conj.(C)
        B .= conj.(B)
    end
    return C
end
101+
102+
# In-place `C = alpha * op(A) * B + beta * C` where `A` is a lazy adjoint of a COO matrix
# `P` and `op` is selected by `tA`:
#   'N'            -> adjoint(P)                -> flag 'C'
#   'T'            -> transpose(adjoint(P)) = conj(P)
#   anything else  -> adjoint(adjoint(P)) = P   -> flag 'N'
# Returns `C`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    P = A.parent
    if tA != 'T'
        sparse_gemv!(tA == 'N' ? 'C' : 'N', _add.alpha, P, B, _add.beta, C)
    elseif T <: Real
        # conj(P) == P for real element types; skip the no-op conjugation passes.
        sparse_gemv!('N', _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) * B == conj(P * conj(B)): conjugate the operands in place, run the
        # plain product with conjugated scalars, then undo the conjugation.
        B .= conj.(B)
        C .= conj.(C)
        try
            sparse_gemv!('N', conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Always restore the caller's data, even if the kernel call throws.
            C .= conj.(C)
            B .= conj.(B)
        end
    end
    return C
end
118+
119+
# Handle Transpose{T, Adjoint{T, ...}} for complex matrices
120+
# transpose(adjoint(A)) for complex matrices needs special handling
121+
# In-place `C = alpha * op(A) * B + beta * C` for `A = transpose(adjoint(P))` with `P` a
# complex CSR matrix. Since `A == conj(P)`, the product is evaluated as
# `conj(op(P) * conj(B))`: operands are conjugated in place, `sparse_gemv!` runs on `P`
# with conjugated scalars, and the conjugation is undone afterwards. Returns `C`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasComplex}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    # 'N' and 'T' pass through unchanged; every other flag is treated as 'C'.
    op = (tA == 'N' || tA == 'T') ? tA : 'C'
    B .= conj.(B)
    C .= conj.(C)
    try
        sparse_gemv!(op, conj(_add.alpha), A.parent.parent, B, conj(_add.beta), C)
    finally
        # Always restore the caller's data, even if the kernel call throws.
        C .= conj.(C)
        B .= conj.(B)
    end
    return C
end
139+
140+
# In-place `C = alpha * op(A) * B + beta * C` for `A = transpose(adjoint(P))` with `P` a
# complex CSC matrix. Since `A == conj(P)`, the product is evaluated as
# `conj(op(P) * conj(B))`: operands are conjugated in place, `sparse_gemv!` runs on `P`
# with conjugated scalars, and the conjugation is undone afterwards. Returns `C`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasComplex}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    # 'N' and 'T' pass through unchanged; every other flag is treated as 'C'.
    op = (tA == 'N' || tA == 'T') ? tA : 'C'
    B .= conj.(B)
    C .= conj.(C)
    try
        sparse_gemv!(op, conj(_add.alpha), A.parent.parent, B, conj(_add.beta), C)
    finally
        # Always restore the caller's data, even if the kernel call throws.
        C .= conj.(C)
        B .= conj.(B)
    end
    return C
end
158+
159+
# In-place `C = alpha * op(A) * B + beta * C` for `A = transpose(adjoint(P))` with `P` a
# complex COO matrix. Since `A == conj(P)`, the product is evaluated as
# `conj(op(P) * conj(B))`: operands are conjugated in place, `sparse_gemv!` runs on `P`
# with conjugated scalars, and the conjugation is undone afterwards. Returns `C`.
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasComplex}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    # 'N' and 'T' pass through unchanged; every other flag is treated as 'C'.
    op = (tA == 'N' || tA == 'T') ? tA : 'C'
    B .= conj.(B)
    C .= conj.(C)
    try
        sparse_gemv!(op, conj(_add.alpha), A.parent.parent, B, conj(_add.beta), C)
    finally
        # Always restore the caller's data, even if the kernel call throws.
        C .= conj.(C)
        B .= conj.(B)
    end
    return C
end
177+
178+
# Custom * operators for Transpose{T, Adjoint{T, ...}} to ensure correct output size allocation
179+
# `transpose(adjoint(A)) * x` for a complex CSR matrix. The result vector is allocated
# here with one entry per row of the wrapped matrix, then filled by the in-place
# multiply using alpha = 1 and beta = 0.
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, x::oneVector{T}) where {T <: BlasComplex}
    nrows = size(A, 1)
    out = similar(x, T, nrows)
    scaling = LinearAlgebra.MulAddMul(one(T), zero(T))
    LinearAlgebra.generic_matvecmul!(out, 'N', A, x, scaling)
    return out
end
185+
186+
# `transpose(adjoint(A)) * x` for a complex CSC matrix. The result vector is allocated
# here with one entry per row of the wrapped matrix, then filled by the in-place
# multiply using alpha = 1 and beta = 0.
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, x::oneVector{T}) where {T <: BlasComplex}
    nrows = size(A, 1)
    out = similar(x, T, nrows)
    scaling = LinearAlgebra.MulAddMul(one(T), zero(T))
    LinearAlgebra.generic_matvecmul!(out, 'N', A, x, scaling)
    return out
end
192+
193+
# `transpose(adjoint(A)) * x` for a complex COO matrix. The result vector is allocated
# here with one entry per row of the wrapped matrix, then filled by the in-place
# multiply using alpha = 1 and beta = 0.
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, x::oneVector{T}) where {T <: BlasComplex}
    nrows = size(A, 1)
    out = similar(x, T, nrows)
    scaling = LinearAlgebra.MulAddMul(one(T), zero(T))
    LinearAlgebra.generic_matvecmul!(out, 'N', A, x, scaling)
    return out
end
199+
200+
# Matrix-matrix multiplication with transpose/adjoint
201+
# In-place `C = alpha * opA(A) * opB(B) + beta * C` where `A` is a lazy transpose of a
# CSR matrix `P`. The wrapper is folded into the flag passed to `sparse_gemm!`:
#   'N' -> transpose(P)                -> flag 'T'
#   'T' -> transpose(transpose(P)) = P -> flag 'N'
#   'C' -> adjoint(transpose(P)) = conj(P)
# The last case has no flag of its own. Passing 'C' would compute adjoint(P), which has
# the wrong values and the wrong shape for a non-square `P`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    P = A.parent
    if tA == 'N'
        return sparse_gemm!('T', tB, _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to the plain product.
        return sparse_gemm!('N', tB, _add.alpha, P, B, _add.beta, C)
    end
    # conj(P) * opB(B) == conj(P * opB(conj(B))): conjugate the operands in place, run
    # the plain product with conjugated scalars, then undo the conjugation.
    B .= conj.(B)
    C .= conj.(C)
    try
        sparse_gemm!('N', tB, conj(_add.alpha), P, B, conj(_add.beta), C)
    finally
        # Always restore the caller's data, even if the kernel call throws.
        C .= conj.(C)
        B .= conj.(B)
    end
    return C
end
207+
208+
# In-place `C = alpha * opA(A) * opB(B) + beta * C` where `A` is a lazy adjoint of a CSR
# matrix `P` and `opA` is selected by `tA`:
#   'N'            -> adjoint(P)                -> flag 'C'
#   'T'            -> transpose(adjoint(P)) = conj(P)
#   anything else  -> adjoint(adjoint(P)) = P   -> flag 'N'
# Returns `C`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    P = A.parent
    if tA != 'T'
        sparse_gemm!(tA == 'N' ? 'C' : 'N', tB, _add.alpha, P, B, _add.beta, C)
    elseif T <: Real
        # conj(P) == P for real element types; skip the no-op conjugation passes.
        sparse_gemm!('N', tB, _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) * opB(B) == conj(P * opB(conj(B))): conjugate the operands in place,
        # run the plain product with conjugated scalars, then undo the conjugation.
        B .= conj.(B)
        C .= conj.(C)
        try
            sparse_gemm!('N', tB, conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Always restore the caller's data, even if the kernel call throws.
            C .= conj.(C)
            B .= conj.(B)
        end
    end
    return C
end
225+
226+
# In-place `C = alpha * opA(A) * opB(B) + beta * C` where `A` is a lazy transpose of a
# CSC matrix `P`. The wrapper is folded into the flag passed to `sparse_gemm!`:
#   'N' -> transpose(P)                -> flag 'T'
#   'T' -> transpose(transpose(P)) = P -> flag 'N'
#   'C' -> adjoint(transpose(P)) = conj(P)
# The last case has no flag of its own. Passing 'C' would compute adjoint(P), which has
# the wrong values and the wrong shape for a non-square `P`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    P = A.parent
    if tA == 'N'
        return sparse_gemm!('T', tB, _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to the plain product.
        return sparse_gemm!('N', tB, _add.alpha, P, B, _add.beta, C)
    end
    # conj(P) * opB(B) == conj(P * opB(conj(B))): conjugate the operands in place, run
    # the plain product with conjugated scalars, then undo the conjugation.
    B .= conj.(B)
    C .= conj.(C)
    try
        sparse_gemm!('N', tB, conj(_add.alpha), P, B, conj(_add.beta), C)
    finally
        # Always restore the caller's data, even if the kernel call throws.
        C .= conj.(C)
        B .= conj.(B)
    end
    return C
end
232+
233+
# In-place `C = alpha * opA(A) * opB(B) + beta * C` where `A` is a lazy adjoint of a CSC
# matrix `P` and `opA` is selected by `tA`:
#   'N'            -> adjoint(P)                -> flag 'C'
#   'T'            -> transpose(adjoint(P)) = conj(P)
#   anything else  -> adjoint(adjoint(P)) = P   -> flag 'N'
# Returns `C`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    P = A.parent
    if tA != 'T'
        sparse_gemm!(tA == 'N' ? 'C' : 'N', tB, _add.alpha, P, B, _add.beta, C)
    elseif T <: Real
        # conj(P) == P for real element types; skip the no-op conjugation passes.
        sparse_gemm!('N', tB, _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) * opB(B) == conj(P * opB(conj(B))): conjugate the operands in place,
        # run the plain product with conjugated scalars, then undo the conjugation.
        B .= conj.(B)
        C .= conj.(C)
        try
            sparse_gemm!('N', tB, conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Always restore the caller's data, even if the kernel call throws.
            C .= conj.(C)
            B .= conj.(B)
        end
    end
    return C
end
250+
251+
# In-place `C = alpha * opA(A) * opB(B) + beta * C` where `A` is a lazy transpose of a
# COO matrix `P`. The wrapper is folded into the flag passed to `sparse_gemm!`:
#   'N' -> transpose(P)                -> flag 'T'
#   'T' -> transpose(transpose(P)) = P -> flag 'N'
#   'C' -> adjoint(transpose(P)) = conj(P)
# The last case has no flag of its own. Passing 'C' would compute adjoint(P), which has
# the wrong values and the wrong shape for a non-square `P`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    P = A.parent
    if tA == 'N'
        return sparse_gemm!('T', tB, _add.alpha, P, B, _add.beta, C)
    elseif tA == 'T' || T <: Real
        # For real element types conj(P) == P, so 'C' collapses to the plain product.
        return sparse_gemm!('N', tB, _add.alpha, P, B, _add.beta, C)
    end
    # conj(P) * opB(B) == conj(P * opB(conj(B))): conjugate the operands in place, run
    # the plain product with conjugated scalars, then undo the conjugation.
    B .= conj.(B)
    C .= conj.(C)
    try
        sparse_gemm!('N', tB, conj(_add.alpha), P, B, conj(_add.beta), C)
    finally
        # Always restore the caller's data, even if the kernel call throws.
        C .= conj.(C)
        B .= conj.(B)
    end
    return C
end
257+
258+
# In-place `C = alpha * opA(A) * opB(B) + beta * C` where `A` is a lazy adjoint of a COO
# matrix `P` and `opA` is selected by `tA`:
#   'N'            -> adjoint(P)                -> flag 'C'
#   'T'            -> transpose(adjoint(P)) = conj(P)
#   anything else  -> adjoint(adjoint(P)) = P   -> flag 'N'
# Returns `C`.
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasFloat}
    tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
    tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
    P = A.parent
    if tA != 'T'
        sparse_gemm!(tA == 'N' ? 'C' : 'N', tB, _add.alpha, P, B, _add.beta, C)
    elseif T <: Real
        # conj(P) == P for real element types; skip the no-op conjugation passes.
        sparse_gemm!('N', tB, _add.alpha, P, B, _add.beta, C)
    else
        # conj(P) * opB(B) == conj(P * opB(conj(B))): conjugate the operands in place,
        # run the plain product with conjugated scalars, then undo the conjugation.
        B .= conj.(B)
        C .= conj.(C)
        try
            sparse_gemm!('N', tB, conj(_add.alpha), P, B, conj(_add.beta), C)
        finally
            # Always restore the caller's data, even if the kernel call throws.
            C .= conj.(C)
            B .= conj.(B)
        end
    end
    return C
end

lib/mkl/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,5 @@ end
113113
ptrs = pointer.(batch)
114114
return oneArray(ptrs)
115115
end
116-
117116
# Swap the operation flag: 'N' becomes 'T'; every other flag (including 'C') becomes 'N'.
flip_trans(trans::Char) = trans != 'N' ? 'N' : 'T'
118117
# Swap the triangle flag: 'L' becomes 'U'; every other flag becomes 'L'.
flip_uplo(uplo::Char) = uplo != 'L' ? 'L' : 'U'

test/onemkl.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,9 +1130,16 @@ end
11301130
oneMKL.sparse_optimize_gemv!(transa, dA)
11311131
oneMKL.sparse_gemv!(transa, alpha, dA, dx, beta, dy)
11321132
@test alpha * opa(A) * x + beta * y ≈ collect(dy)
1133-
end
1134-
end
1133+
dy = oneVector{T}(y)
1134+
@test alpha * opa(A) * x + beta * y ≈ Array(alpha * opa(dA) * dx + beta * dy)
1135+
tx = transa == 'N' ? rand(T, 20) : rand(T, 10)
1136+
ty = transa == 'N' ? rand(T, 10) : rand(T, 20)
1137+
dtx = oneVector{T}(tx)
1138+
dty = oneVector{T}(ty)
1139+
t = @test alpha * opa(A') * tx + beta * ty ≈ Array(alpha * opa(dA') * dtx + beta * dty)
1140+
end
11351141
end
1142+
end
11361143

11371144
@testset "sparse gemm" begin
11381145
@testset "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
@@ -1153,6 +1160,8 @@ end
11531160
oneMKL.sparse_gemm!(transa, transb, alpha, dA, dB, beta, dC)
11541161

11551162
@test alpha * opa(A) * opb(B) + beta * C ≈ collect(dC)
1163+
dC = oneMatrix{T}(C)
1164+
@test alpha * opa(A) * opb(B) + beta * C ≈ Array(alpha * opa(dA) * opb(dB) + beta * dC)
11561165
oneMKL.sparse_optimize_gemm!(transa, transb, 2, dA)
11571166
end
11581167
end

0 commit comments

Comments
 (0)