You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: lib/mkl/interfaces.jl
+245-4Lines changed: 245 additions & 4 deletions
Original file line number
Diff line number
Diff line change
@@ -7,8 +7,13 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
7
7
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
8
8
end
9
9
10
-
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <:BlasReal
11
-
tA = tA in ('S', 's', 'H', 'h') ?'T':flip_trans(tA)
10
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasFloat}
11
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
12
+
returnsparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
13
+
end
14
+
15
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCOO{T}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasFloat}
16
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
12
17
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
13
18
end
14
19
@@ -18,8 +23,14 @@ function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseM
18
23
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
19
24
end
20
25
21
-
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <:BlasReal
22
-
tA = tA in ('S', 's', 'H', 'h') ?'T':flip_trans(tA)
26
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <:BlasFloat}
27
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
28
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
29
+
returnsparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
30
+
end
31
+
32
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCOO{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <:BlasFloat}
33
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
23
34
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
24
35
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
25
36
end
@@ -31,3 +42,233 @@ end
31
42
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <:BlasFloat
32
43
sparse_trsm!(uploc, tfun === identity ?'N': tfun === transpose ?'T':'C', 'N', isunitc, one(T), A, B, C)
33
44
end
45
+
46
+
# Handle Transpose and Adjoint wrappers for sparse matrices
47
+
# Let the low-level wrappers handle the CSC->CSR conversion and flip_trans logic
48
+
49
+
# Matrix-vector multiplication with transpose/adjoint
50
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasFloat}
51
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
52
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
53
+
returnsparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
54
+
end
55
+
56
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasFloat}
57
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
58
+
if tA =='T'
59
+
alpha = _add.alpha
60
+
beta = _add.beta
61
+
B .=conj.(B)
62
+
C .=conj.(C)
63
+
sparse_gemv!('N', conj(alpha), A.parent, B, conj(beta), C)
64
+
C .=conj.(C)
65
+
B .=conj.(B)
66
+
else
67
+
tA_final = tA =='N'?'C':'N'
68
+
sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
69
+
end
70
+
return C
71
+
end
72
+
73
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasFloat}
74
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
75
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
76
+
returnsparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
77
+
end
78
+
79
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasFloat}
80
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
81
+
if tA =='T'
82
+
alpha = _add.alpha
83
+
beta = _add.beta
84
+
B .=conj.(B)
85
+
C .=conj.(C)
86
+
sparse_gemv!('N', conj(alpha), A.parent, B, conj(beta), C)
87
+
C .=conj.(C)
88
+
B .=conj.(B)
89
+
else
90
+
tA_final = tA =='N'?'C':'N'
91
+
sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
92
+
end
93
+
return C
94
+
end
95
+
96
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasFloat}
97
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
98
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
99
+
returnsparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
100
+
end
101
+
102
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasFloat}
103
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
104
+
if tA =='T'
105
+
alpha = _add.alpha
106
+
beta = _add.beta
107
+
B .=conj.(B)
108
+
C .=conj.(C)
109
+
sparse_gemv!('N', conj(alpha), A.parent, B, conj(beta), C)
110
+
C .=conj.(C)
111
+
B .=conj.(B)
112
+
else
113
+
tA_final = tA =='N'?'C':'N'
114
+
sparse_gemv!(tA_final, _add.alpha, A.parent, B, _add.beta, C)
115
+
end
116
+
return C
117
+
end
118
+
119
+
# Handle Transpose{T, Adjoint{T, ...}} for complex matrices
120
+
# transpose(adjoint(A)) for complex matrices needs special handling
121
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasComplex}
122
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
123
+
# transpose(adjoint(A)) = conj(A), so we need to conjugate
124
+
alpha = _add.alpha
125
+
beta = _add.beta
126
+
B .=conj.(B)
127
+
C .=conj.(C)
128
+
if tA =='N'
129
+
sparse_gemv!('N', conj(alpha), A.parent.parent, B, conj(beta), C)
130
+
elseif tA =='T'
131
+
sparse_gemv!('T', conj(alpha), A.parent.parent, B, conj(beta), C)
132
+
else# tA == 'C'
133
+
sparse_gemv!('C', conj(alpha), A.parent.parent, B, conj(beta), C)
134
+
end
135
+
C .=conj.(C)
136
+
B .=conj.(B)
137
+
return C
138
+
end
139
+
140
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasComplex}
141
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
142
+
# transpose(adjoint(A)) = conj(A), so we need to conjugate
143
+
alpha = _add.alpha
144
+
beta = _add.beta
145
+
B .=conj.(B)
146
+
C .=conj.(C)
147
+
if tA =='N'
148
+
sparse_gemv!('N', conj(alpha), A.parent.parent, B, conj(beta), C)
149
+
elseif tA =='T'
150
+
sparse_gemv!('T', conj(alpha), A.parent.parent, B, conj(beta), C)
151
+
else# tA == 'C'
152
+
sparse_gemv!('C', conj(alpha), A.parent.parent, B, conj(beta), C)
153
+
end
154
+
C .=conj.(C)
155
+
B .=conj.(B)
156
+
return C
157
+
end
158
+
159
+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, B::oneVector{T}, _add::MulAddMul) where {T <:BlasComplex}
160
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
161
+
# transpose(adjoint(A)) = conj(A), so we need to conjugate
162
+
alpha = _add.alpha
163
+
beta = _add.beta
164
+
B .=conj.(B)
165
+
C .=conj.(C)
166
+
if tA =='N'
167
+
sparse_gemv!('N', conj(alpha), A.parent.parent, B, conj(beta), C)
168
+
elseif tA =='T'
169
+
sparse_gemv!('T', conj(alpha), A.parent.parent, B, conj(beta), C)
170
+
else# tA == 'C'
171
+
sparse_gemv!('C', conj(alpha), A.parent.parent, B, conj(beta), C)
172
+
end
173
+
C .=conj.(C)
174
+
B .=conj.(B)
175
+
return C
176
+
end
177
+
178
+
# Custom * operators for Transpose{T, Adjoint{T, ...}} to ensure correct output size allocation
179
+
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSR{T}}}, x::oneVector{T}) where {T <:BlasComplex}
180
+
m, n =size(A)
181
+
y =similar(x, T, m)
182
+
LinearAlgebra.generic_matvecmul!(y, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)))
183
+
return y
184
+
end
185
+
186
+
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCSC{T}}}, x::oneVector{T}) where {T <:BlasComplex}
187
+
m, n =size(A)
188
+
y =similar(x, T, m)
189
+
LinearAlgebra.generic_matvecmul!(y, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)))
190
+
return y
191
+
end
192
+
193
+
function Base.:*(A::Transpose{T, <:Adjoint{T, <:oneSparseMatrixCOO{T}}}, x::oneVector{T}) where {T <:BlasComplex}
194
+
m, n =size(A)
195
+
y =similar(x, T, m)
196
+
LinearAlgebra.generic_matvecmul!(y, 'N', A, x, LinearAlgebra.MulAddMul(one(T), zero(T)))
197
+
return y
198
+
end
199
+
200
+
# Matrix-matrix multiplication with transpose/adjoint
201
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <:BlasFloat}
202
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
203
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
204
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
205
+
returnsparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
206
+
end
207
+
208
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSR{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <:BlasFloat}
209
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
210
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
211
+
if tA =='T'
212
+
alpha = _add.alpha
213
+
beta = _add.beta
214
+
B .=conj.(B)
215
+
C .=conj.(C)
216
+
sparse_gemm!('N', tB, conj(alpha), A.parent, B, conj(beta), C)
217
+
C .=conj.(C)
218
+
B .=conj.(B)
219
+
else
220
+
tA_final = tA =='N'?'C':'N'
221
+
sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
222
+
end
223
+
return C
224
+
end
225
+
226
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <:BlasFloat}
227
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
228
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
229
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
230
+
returnsparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
231
+
end
232
+
233
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCSC{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <:BlasFloat}
234
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
235
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
236
+
if tA =='T'
237
+
alpha = _add.alpha
238
+
beta = _add.beta
239
+
B .=conj.(B)
240
+
C .=conj.(C)
241
+
sparse_gemm!('N', tB, conj(alpha), A.parent, B, conj(beta), C)
242
+
C .=conj.(C)
243
+
B .=conj.(B)
244
+
else
245
+
tA_final = tA =='N'?'C':'N'
246
+
sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
247
+
end
248
+
return C
249
+
end
250
+
251
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Transpose{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <:BlasFloat}
252
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
253
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
254
+
tA_final = tA =='N'?'T': (tA =='T'?'N':'C')
255
+
returnsparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
256
+
end
257
+
258
+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::Adjoint{T, <:oneSparseMatrixCOO{T}}, B::oneMatrix{T}, _add::MulAddMul) where {T <:BlasFloat}
259
+
tA = tA in ('S', 's', 'H', 'h') ?'N': tA
260
+
tB = tB in ('S', 's', 'H', 'h') ?'N': tB
261
+
if tA =='T'
262
+
alpha = _add.alpha
263
+
beta = _add.beta
264
+
B .=conj.(B)
265
+
C .=conj.(C)
266
+
sparse_gemm!('N', tB, conj(alpha), A.parent, B, conj(beta), C)
267
+
C .=conj.(C)
268
+
B .=conj.(B)
269
+
else
270
+
tA_final = tA =='N'?'C':'N'
271
+
sparse_gemm!(tA_final, tB, _add.alpha, A.parent, B, _add.beta, C)
0 commit comments