92
92
93
93
94
94
"""
    alloc_matmul_product(A::AbstractArray, B::AbstractMatrix)

Allocate an uninitialized destination for the product of `A` and `B` and return
`(C, (M, K, N))`, where `(M, K)` comes from `size(A)` and `(K, N)` from `size(B)`.

The element type of `C` is `promote_type(TA, TB)`, i.e. the promotion of the element
types of `A` and `B`. When `M` is identical to `StaticInt(1)` the destination is a
transposed `Vector` of length `N`; otherwise it is an `M × N` `Matrix`.

An `@assert` fails when the inner dimensions of `A` and `B` differ.
"""
@inline function alloc_matmul_product(
    A::AbstractArray{TA}, B::AbstractMatrix{TB}
) where {TA,TB}
    # TODO: if `M` and `N` are statically sized, shouldn't return a `Matrix`.
    M, KA = size(A)
    KB, N = size(B)
    @assert KA == KB " Size mismatch."
    if M === StaticInt(1)
        # Single (statically known) row: back the result with a plain vector and
        # present it as a 1 × N array through a lazy transpose.
        return transpose(Vector{promote_type(TA,TB)}(undef, N)), (M, KA, N)
    else
        return Matrix{promote_type(TA,TB)}(undef, M, N), (M, KA, N)
    end
end
101
105
@inline function alloc_matmul_product (A:: AbstractArray{TA} , B:: AbstractVector{TB} ) where {TA,TB}
102
106
# TODO : if `M` and `N` are statically sized, shouldn't return a `Matrix`.
@@ -105,10 +109,11 @@ end
105
109
@assert KA == KB " Size mismatch."
106
110
Vector {promote_type(TA,TB)} (undef, M), (M, KA, One ())
107
111
end
112
+
108
113
"""
    matmul_serial(A::AbstractMatrix, B::AbstractVecOrMat)

Allocate a destination with `alloc_matmul_product(A, B)`, fill it through
`matmul_serial!` with `One()` and `Zero()` as the `α` and `β` arguments, and return it.
"""
@inline function matmul_serial(A::AbstractMatrix, B::AbstractVecOrMat)
    C, (M,K,N) = alloc_matmul_product(A, B)
    # Forward the sizes computed during allocation together with the contiguous
    # axis of `C`, so the seven-argument `matmul_serial!` methods can dispatch on it.
    matmul_serial!(C, A, B, One(), Zero(), (M,K,N), ArrayInterface.contiguous_axis(C))
    return C
end
113
118
114
119
@@ -132,12 +137,16 @@ end
132
137
# Five-argument entry point: no precomputed `(M, K, N)` is available, so `nothing`
# is passed in its place, and the contiguous axis of `C` is appended so the
# seven-argument methods can dispatch on it.
# NOTE(review): presumably computes `C = α * A * B + β * C`; the kernel
# `_matmul_serial!` is defined outside this view — confirm there.
@inline function matmul_serial!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β
)
    return matmul_serial!(C, A, B, α, β, nothing, ArrayInterface.contiguous_axis(C))
end
135
# Method selected when the last argument is `StaticInt{2}` (callers in this file pass
# `ArrayInterface.contiguous_axis(C)` there) and no sizes were supplied (`nothing`).
# The product is rewritten as transpose(C) = transpose(B) * transpose(A) before being
# handed to `_matmul_serial!`. `transpose` is used rather than `'` (adjoint) so that
# elements are not conjugated.
@inline function matmul_serial!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β,
    ::Nothing, ::StaticInt{2}
)
    _matmul_serial!(transpose(C), transpose(B), transpose(A), α, β, nothing)
    return C
end
144
# Method selected when the last argument is `StaticInt{2}` and the sizes are supplied
# as a 3-tuple of integers `(M, K, N)`. The product is rewritten as
# transpose(C) = transpose(B) * transpose(A), so the sizes handed to
# `_matmul_serial!` are reordered to `(N, K, M)` to describe the transposed problem.
@inline function matmul_serial!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β,
    (M,K,N)::Tuple{Vararg{Integer,3}}, ::StaticInt{2}
)
    _matmul_serial!(transpose(C), transpose(B), transpose(A), α, β, (N,K,M))
    return C
end
139
148
# Fallback for any other `StaticInt` in the last position: the arguments are passed
# to `_matmul_serial!` unchanged, including `MKN` (either `nothing` or the
# precomputed sizes), and `C` is returned.
@inline function matmul_serial!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β,
    MKN, ::StaticInt
)
    _matmul_serial!(C, A, B, α, β, MKN)
    return C
end
143
152
212
221
Multiply matrices `A` and `B`.
213
222
"""
214
223
@inline function matmul(A::AbstractMatrix, B::AbstractVecOrMat)
    # Allocate the destination and reuse the sizes computed while doing so.
    C, (M,K,N) = alloc_matmul_product(A, B)
    # `One()` / `Zero()` fill the `α` / `β` slots and `nothing` fills `nthread`;
    # the contiguous axis of `C` is appended so `matmul!` can dispatch on it.
    matmul!(C, A, B, One(), Zero(), nothing, (M,K,N), ArrayInterface.contiguous_axis(C))
    return C
end
219
228
220
229
"""
@@ -235,13 +244,17 @@ end
235
244
# Six-argument entry point: no precomputed `(M, K, N)` is available, so `nothing`
# is passed in its place, and the contiguous axis of `C` is appended so the
# eight-argument methods can dispatch on it.
# NOTE(review): the meaning of `nthread` is decided by `_matmul!`, which is
# defined outside this view — confirm there.
@inline function matmul!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread
)
    return matmul!(C, A, B, α, β, nthread, nothing, ArrayInterface.contiguous_axis(C))
end
238
# Method selected when the last argument is `StaticInt{2}` (callers in this file pass
# `ArrayInterface.contiguous_axis(C)` there) and no sizes were supplied (`nothing`).
# The product is rewritten as transpose(C) = transpose(B) * transpose(A) before being
# handed to `_matmul!`. `transpose` is used rather than `'` (adjoint) so that
# elements are not conjugated.
@inline function matmul!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread,
    ::Nothing, ::StaticInt{2}
)
    _matmul!(transpose(C), transpose(B), transpose(A), α, β, nthread, nothing)
    return C
end
251
# Method selected when the last argument is `StaticInt{2}` and the sizes are supplied
# as a 3-tuple of integers `(M, K, N)`. The product is rewritten as
# transpose(C) = transpose(B) * transpose(A), so the sizes handed to `_matmul!`
# are reordered to `(N, K, M)` to describe the transposed problem.
@inline function matmul!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread,
    (M,K,N)::Tuple{Vararg{Integer,3}}, ::StaticInt{2}
)
    _matmul!(transpose(C), transpose(B), transpose(A), α, β, nthread, (N,K,M))
    return C
end
242
255
# Fallback for any other `StaticInt` in the last position: the arguments are passed
# to `_matmul!` unchanged, including `MKN` (either `nothing` or the precomputed
# sizes), and `C` is returned.
@inline function matmul!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread,
    MKN, ::StaticInt
)
    _matmul!(C, A, B, α, β, nthread, MKN)
    return C
end
246
259
247
260
@inline function dontpack (pA:: AbstractStridedPointer{Ta} , M, K, :: StaticInt{mc} , :: StaticInt{kc} , :: Type{Tc} , nspawn) where {mc, kc, Tc, Ta}
0 commit comments