@@ -50,7 +50,7 @@ function matmul_st_pack_A_and_B!(
50
50
for k ∈ CloseOpen (Kiter)
51
51
ksize = ifelse (k < Krem, Kblock_Krem, Kblock)
52
52
_B = default_zerobased_stridedpointer (L3ptr, (One (),ksize))
53
- unsafe_copyto_avx ! (_B, B, ksize, nsize)
53
+ unsafe_copyto_turbo ! (_B, B, ksize, nsize)
54
54
let A = A, C = C, B = _B
55
55
for m in CloseOpen (Miter)
56
56
msize = ifelse ((m+ 1 ) == Miter, Mremfinal, ifelse (m < Mrem, Mblock_Mrem, Mblock))
91
91
end
92
92
93
93
94
- @inline function alloc_matmul_product (A:: AbstractArray{TA} , B:: AbstractArray {TB} ) where {TA,TB}
94
+ @inline function alloc_matmul_product (A:: AbstractArray{TA} , B:: AbstractMatrix {TB} ) where {TA,TB}
95
95
# TODO : if `M` and `N` are statically sized, shouldn't return a `Matrix`.
96
96
M, KA = size (A)
97
97
KB, N = size (B)
98
98
@assert KA == KB " Size mismatch."
99
99
Matrix {promote_type(TA,TB)} (undef, M, N), (M, KA, N)
100
100
end
101
- @inline function matmul_serial (A:: AbstractMatrix , B:: AbstractMatrix )
101
+ @inline function alloc_matmul_product (A:: AbstractArray{TA} , B:: AbstractVector{TB} ) where {TA,TB}
102
+ # TODO : if `M` and `N` are statically sized, shouldn't return a `Matrix`.
103
+ M, KA = size (A)
104
+ KB = length (B)
105
+ @assert KA == KB " Size mismatch."
106
+ Vector {promote_type(TA,TB)} (undef, M), (M, KA, One ())
107
+ end
108
+ @inline function matmul_serial (A:: AbstractMatrix , B:: AbstractVecOrMat )
102
109
C, (M,K,N) = alloc_matmul_product (A, B)
103
110
_matmul_serial! (C, A, B, One (), Zero (), (M,K,N))
104
111
return C
@@ -116,20 +123,20 @@ function maybeinline(::StaticInt{M}, ::StaticInt{N}, ::Type{T}, ::Val{false}) wh
116
123
end
117
124
118
125
119
- @inline function matmul_serial! (C:: AbstractMatrix{T} , A:: AbstractMatrix , B:: AbstractMatrix ) where {T}
126
+ @inline function matmul_serial! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat )
120
127
matmul_serial! (C, A, B, One (), Zero (), nothing , ArrayInterface. contiguous_axis (C))
121
128
end
122
- @inline function matmul_serial! (C:: AbstractMatrix , A:: AbstractMatrix , B:: AbstractMatrix , α)
129
+ @inline function matmul_serial! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α)
123
130
matmul_serial! (C, A, B, α, Zero (), nothing , ArrayInterface. contiguous_axis (C))
124
131
end
125
- @inline function matmul_serial! (C:: AbstractMatrix , A:: AbstractMatrix , B:: AbstractMatrix , α, β)
132
+ @inline function matmul_serial! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α, β)
126
133
matmul_serial! (C, A, B, α, β, nothing , ArrayInterface. contiguous_axis (C))
127
134
end
128
- @inline function matmul_serial! (C:: AbstractMatrix , A:: AbstractMatrix , B:: AbstractMatrix , α, β, MKN, :: StaticInt{2} )
135
+ @inline function matmul_serial! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α, β, MKN, :: StaticInt{2} )
129
136
_matmul_serial! (C' , B' , A' , α, β, nothing )
130
137
return C
131
138
end
132
- @inline function matmul_serial! (C:: AbstractMatrix , A:: AbstractMatrix , B:: AbstractMatrix , α, β, MKN, :: StaticInt )
139
+ @inline function matmul_serial! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α, β, MKN, :: StaticInt )
133
140
_matmul_serial! (C, A, B, α, β, nothing )
134
141
return C
135
142
end
@@ -149,7 +156,7 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
149
156
"""
150
157
@inline function _matmul_serial! (
151
158
C:: AbstractMatrix{T} , A:: AbstractMatrix , B:: AbstractMatrix , α, β, MKN
152
- ) where {T}
159
+ ) where {T<: Real }
153
160
M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
154
161
if M * N == 0
155
162
return
@@ -175,13 +182,13 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
175
182
end # function
176
183
177
184
function matmul_only_β! (C:: AbstractMatrix{T} , β:: StaticInt{0} ) where T
178
- @avx for i= 1 : length (C)
185
+ @turbo for i= 1 : length (C)
179
186
C[i] = zero (T)
180
187
end
181
188
end
182
189
183
190
function matmul_only_β! (C:: AbstractMatrix{T} , β) where T
184
- @avx for i= 1 : length (C)
191
+ @turbo for i= 1 : length (C)
185
192
C[i] = β * C[i]
186
193
end
187
194
end
204
211
205
212
Multiply matrices `A` and `B`.
206
213
"""
207
- @inline function matmul (A:: AbstractMatrix , B:: AbstractMatrix )
214
+ @inline function matmul (A:: AbstractMatrix , B:: AbstractVecOrMat )
208
215
C, (M,K,N) = alloc_matmul_product (A, B)
209
216
_matmul! (C, A, B, One (), Zero (), nothing , (M,K,N))
210
217
return C
@@ -213,26 +220,26 @@ end
213
220
"""
214
221
matmul!(C, A, B[, α, β, max_threads])
215
222
216
- Calculates `C = α * A * B + β * C` in place, overwriting the contents of `A `.
223
+ Calculates `C = α * A * B + β * C` in place, overwriting the contents of `C `.
217
224
It may use up to `max_threads` threads. It will not use threads when nested in other threaded code.
218
225
"""
219
- @inline function matmul! (C:: AbstractMatrix{T} , A:: AbstractMatrix , B:: AbstractMatrix ) where {T}
226
+ @inline function matmul! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat )
220
227
matmul! (C, A, B, One (), Zero (), nothing , nothing , ArrayInterface. contiguous_axis (C))
221
228
end
222
- @inline function matmul! (C:: AbstractMatrix , A:: AbstractMatrix , B:: AbstractMatrix , α)
229
+ @inline function matmul! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α)
223
230
matmul! (C, A, B, α, Zero (), nothing , nothing , ArrayInterface. contiguous_axis (C))
224
231
end
225
- @inline function matmul! (C:: AbstractMatrix , A:: AbstractMatrix , B:: AbstractMatrix , α, β)
232
+ @inline function matmul! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α, β)
226
233
matmul! (C, A, B, α, β, nothing , nothing , ArrayInterface. contiguous_axis (C))
227
234
end
228
- @inline function matmul! (C:: AbstractMatrix , A:: AbstractMatrix , B:: AbstractMatrix , α, β, nthread)
235
+ @inline function matmul! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α, β, nthread)
229
236
matmul! (C, A, B, α, β, nthread, nothing , ArrayInterface. contiguous_axis (C))
230
237
end
231
- @inline function matmul! (C:: AbstractMatrix , A:: AbstractMatrix , B:: AbstractMatrix , α, β, nthread, MKN, :: StaticInt{2} )
238
+ @inline function matmul! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α, β, nthread, MKN, :: StaticInt{2} )
232
239
_matmul! (C' , B' , A' , α, β, nthread, MKN)
233
240
return C
234
241
end
235
- @inline function matmul! (C:: AbstractMatrix , A:: AbstractMatrix , B:: AbstractMatrix , α, β, nthread, MKN, :: StaticInt )
242
+ @inline function matmul! (C:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , α, β, nthread, MKN, :: StaticInt )
236
243
_matmul! (C, A, B, α, β, nthread, MKN)
237
244
return C
238
245
end
243
250
end
244
251
245
252
# passing MKN directly would let osmeone skip the size check.
246
- @inline function _matmul! (C:: AbstractMatrix{T} , A, B, α, β, nthread, MKN) where {T} # ::Union{Nothing,Tuple{Vararg{Integer,3}}}) where {T }
253
+ @inline function _matmul! (C:: AbstractMatrix{T} , A, B, α, β, nthread, MKN) where {T<: Real }
247
254
M, K, N = MKN === nothing ? matmul_sizes (C, A, B) : MKN
248
255
if M * N == 0
249
256
return
@@ -443,7 +450,7 @@ function sync_mul!(
443
450
for k ∈ CloseOpen (Kiter)
444
451
ksize = ifelse (k < Krem, Kblock_Krem, Kblock)
445
452
_B = default_zerobased_stridedpointer (bc, (One (), ksize))
446
- unsafe_copyto_avx ! (gesp (_B, (Zero (), pack_offset)), gesp (B, (Zero (), pack_offset)), ksize, pack_len)
453
+ unsafe_copyto_turbo ! (gesp (_B, (Zero (), pack_offset)), gesp (B, (Zero (), pack_offset)), ksize, pack_len)
447
454
# synchronize before starting the multiplication, to ensure `B` is packed
448
455
_mv = _atomic_add! (myp, one (UInt))
449
456
sync_iters += one (UInt)
@@ -489,3 +496,26 @@ function sync_mul!(
489
496
end
490
497
nothing
491
498
end
499
+
500
+ function _matmul! (y:: AbstractVector{T} , A:: AbstractMatrix , x:: AbstractVector , α, β, MKN, contig_axis) where {T<: Real }
501
+ @tturbo for m ∈ indices ((A,y),1 )
502
+ yₘ = zero (T)
503
+ for n ∈ indices ((A,x),(2 ,1 ))
504
+ yₘ += A[m,n]* x[n]
505
+ end
506
+ y[m] = α * yₘ + β * y[m]
507
+ end
508
+ return y
509
+ end
510
+ function _matmul_serial! (y:: AbstractVector{T} , A:: AbstractMatrix , x:: AbstractVector , α, β, MKN) where {T<: Real }
511
+ @turbo for m ∈ indices ((A,y),1 )
512
+ yₘ = zero (T)
513
+ for n ∈ indices ((A,x),(2 ,1 ))
514
+ yₘ += A[m,n]* x[n]
515
+ end
516
+ y[m] = α * yₘ + β * y[m]
517
+ end
518
+ return y
519
+ end
520
+
521
+
0 commit comments