Skip to content

Commit cefc49e

Browse files
authored
Add mat-vec mul. Fixes #83. (#95)
* Add mat-vec mul. Fixes #83. * Fix documentation.
1 parent 0292d3f commit cefc49e

File tree

7 files changed

+113
-62
lines changed

7 files changed

+113
-62
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Octavian"
22
uuid = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
33
authors = ["Mason Protter", "Chris Elrod", "Dilum Aluthge", "contributors"]
4-
version = "0.2.17"
4+
version = "0.2.18"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -13,11 +13,11 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1313

1414
[compat]
1515
ArrayInterface = "3.1.14"
16-
LoopVectorization = "0.12.26"
16+
LoopVectorization = "0.12.32"
1717
Static = "0.2"
1818
StrideArraysCore = "0.1.11"
1919
ThreadingUtilities = "0.4"
20-
VectorizationBase = "0.20.11"
20+
VectorizationBase = "0.20.15"
2121
julia = "1.6"
2222

2323
[extras]

src/complex_matmul.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
real_rep(a::AbstractArray{Complex{T}, N}) where {T, N} = reinterpret(reshape, T, a)
22
#PtrArray(Ptr{T}(pointer(a)), (StaticInt(2), size(a)...))
33

4-
@inline function _matmul!(_C::AbstractMatrix{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractMatrix{Complex{V}},
4+
@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractVecOrMat{Complex{V}},
55
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
6-
C, A, B = real_rep.((_C, _A, _B))
6+
C, A, B = map(real_rep, (_C, _A, _B))
77

88
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
99
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
1010
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
1111
ηθ = η*θ
1212

13-
@avxt for n indices((C, B), 3), m indices((C, A), 2)
13+
@tturbo for n indices((C, B), 3), m indices((C, A), 2)
1414
Cmn_re = zero(T)
1515
Cmn_im = zero(T)
1616
for k indices((A, B), (3, 2))
@@ -23,14 +23,14 @@ real_rep(a::AbstractArray{Complex{T}, N}) where {T, N} = reinterpret(reshape, T,
2323
_C
2424
end
2525

26-
@inline function _matmul!(_C::AbstractMatrix{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractMatrix{Complex{V}},
26+
@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractVecOrMat{Complex{V}},
2727
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
28-
C, B = real_rep.((_C, _B))
28+
C, B = map(real_rep, (_C, _B))
2929

3030
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
3131
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
3232

33-
@avxt for n indices((C, B), 3), m indices((C, A), (2, 1))
33+
@tturbo for n indices((C, B), 3), m indices((C, A), (2, 1))
3434
Cmn_re = zero(T)
3535
Cmn_im = zero(T)
3636
for k indices((A, B), (2, 2))
@@ -43,14 +43,14 @@ end
4343
_C
4444
end
4545

46-
@inline function _matmul!(_C::AbstractMatrix{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractMatrix{V},
46+
@inline function _matmul!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractVecOrMat{V},
4747
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
48-
C, A = real_rep.((_C, _A))
48+
C, A = map(real_rep, (_C, _A))
4949

5050
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
5151
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
5252

53-
@avxt for n indices((C, B), (3, 2)), m indices((C, A), 2)
53+
@tturbo for n indices((C, B), (3, 2)), m indices((C, A), 2)
5454
Cmn_re = zero(T)
5555
Cmn_im = zero(T)
5656
for k indices((A, B), (3, 1))
@@ -67,15 +67,15 @@ end
6767

6868

6969

70-
@inline function _matmul_serial!(_C::AbstractMatrix{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractMatrix{Complex{V}},
70+
@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractVecOrMat{Complex{V}},
7171
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
72-
C, A, B = real_rep.((_C, _A, _B))
72+
C, A, B = map(real_rep, (_C, _A, _B))
7373

7474
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
7575
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
7676
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
7777
ηθ = η*θ
78-
@avx for n indices((C, B), 3), m indices((C, A), 2)
78+
@turbo for n indices((C, B), 3), m indices((C, A), 2)
7979
Cmn_re = zero(T)
8080
Cmn_im = zero(T)
8181
for k indices((A, B), (3, 2))
@@ -88,14 +88,14 @@ end
8888
_C
8989
end
9090

91-
@inline function _matmul_serial!(_C::AbstractMatrix{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractMatrix{Complex{V}},
91+
@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractVecOrMat{Complex{V}},
9292
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
93-
C, B = real_rep.((_C, _B))
93+
C, B = map(real_rep, (_C, _B))
9494

9595
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
9696
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
9797

98-
@avx for n indices((C, B), 3), m indices((C, A), (2, 1))
98+
@turbo for n indices((C, B), 3), m indices((C, A), (2, 1))
9999
Cmn_re = zero(T)
100100
Cmn_im = zero(T)
101101
for k indices((A, B), (2, 2))
@@ -108,14 +108,14 @@ end
108108
_C
109109
end
110110

111-
@inline function _matmul_serial!(_C::AbstractMatrix{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractMatrix{V},
111+
@inline function _matmul_serial!(_C::AbstractVecOrMat{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractVecOrMat{V},
112112
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
113-
C, A = real_rep.((_C, _A))
113+
C, A = map(real_rep, (_C, _A))
114114

115115
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
116116
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
117117

118-
@avx for n indices((C, B), (3, 2)), m indices((C, A), 2)
118+
@turbo for n indices((C, B), (3, 2)), m indices((C, A), 2)
119119
Cmn_re = zero(T)
120120
Cmn_im = zero(T)
121121
for k indices((A, B), (3, 1))

src/macrokernels.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ macro kernel(pack::Bool, ex::Expr)
2727
else
2828
Ainit = areconstruct = Expr[]
2929
end
30-
lvkern = esc(:(@avx inline=true $ex))
30+
lvkern = esc(:(@turbo inline=true $ex))
3131

3232
loopnest = quote
3333
let ãₚ = ãₚ, c = c, $(esc(:B)) = VectorizationBase.reconstruct_ptr(B, b), m = m
@@ -81,7 +81,7 @@ macro kernel(pack::Bool, ex::Expr)
8181
Expr(:block, preheader, loopnest)
8282
end
8383
@inline function loopmul!(C, A, B, α, β, M, K, N)
84-
@avx for n CloseOpen(N), m CloseOpen(M)
84+
@turbo for n CloseOpen(N), m CloseOpen(M)
8585
Cₘₙ = zero(eltype(C))
8686
for k CloseOpen(K)
8787
Cₘₙ += A[m,k] * B[k,n]
@@ -142,7 +142,7 @@ function packaloopmul!(
142142
end
143143

144144
@inline function inlineloopmul!(C, A, B, α, β, M, K, N)
145-
@avx inline=true for m CloseOpen(M), n CloseOpen(N)
145+
@turbo inline=true for m CloseOpen(M), n CloseOpen(N)
146146
Cₘₙ = zero(eltype(C))
147147
for k CloseOpen(K)
148148
Cₘₙ += A[m,k] * B[k,n]

src/matmul.jl

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function matmul_st_pack_A_and_B!(
5050
for k CloseOpen(Kiter)
5151
ksize = ifelse(k < Krem, Kblock_Krem, Kblock)
5252
_B = default_zerobased_stridedpointer(L3ptr, (One(),ksize))
53-
unsafe_copyto_avx!(_B, B, ksize, nsize)
53+
unsafe_copyto_turbo!(_B, B, ksize, nsize)
5454
let A = A, C = C, B = _B
5555
for m in CloseOpen(Miter)
5656
msize = ifelse((m+1) == Miter, Mremfinal, ifelse(m < Mrem, Mblock_Mrem, Mblock))
@@ -91,14 +91,21 @@ end
9191
end
9292

9393

94-
@inline function alloc_matmul_product(A::AbstractArray{TA}, B::AbstractArray{TB}) where {TA,TB}
94+
@inline function alloc_matmul_product(A::AbstractArray{TA}, B::AbstractMatrix{TB}) where {TA,TB}
9595
# TODO: if `M` and `N` are statically sized, shouldn't return a `Matrix`.
9696
M, KA = size(A)
9797
KB, N = size(B)
9898
@assert KA == KB "Size mismatch."
9999
Matrix{promote_type(TA,TB)}(undef, M, N), (M, KA, N)
100100
end
101-
@inline function matmul_serial(A::AbstractMatrix, B::AbstractMatrix)
101+
"""
    alloc_matmul_product(A, B::AbstractVector)

Allocate the uninitialized destination for the matrix-vector product `A * B`.

Returns a tuple `(y, (M, K, N))` where `y` is a `Vector` of length `M` whose element
type is `promote_type(eltype(A), eltype(B))`, `M, K` are the first two sizes of `A`,
and `N` is the static `One()`. Fails the assertion with "Size mismatch." when the
second size of `A` differs from `length(B)`.
"""
@inline function alloc_matmul_product(
    A::AbstractArray{TA}, B::AbstractVector{TB}
) where {TA,TB}
    # TODO: if `M` is statically sized, shouldn't return a `Vector`.
    nrow, inner = size(A)
    @assert inner == length(B) "Size mismatch."
    y = Vector{promote_type(TA, TB)}(undef, nrow)
    return y, (nrow, inner, One())
end
108+
@inline function matmul_serial(A::AbstractMatrix, B::AbstractVecOrMat)
102109
C, (M,K,N) = alloc_matmul_product(A, B)
103110
_matmul_serial!(C, A, B, One(), Zero(), (M,K,N))
104111
return C
@@ -116,20 +123,20 @@ function maybeinline(::StaticInt{M}, ::StaticInt{N}, ::Type{T}, ::Val{false}) wh
116123
end
117124

118125

119-
@inline function matmul_serial!(C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix) where {T}
126+
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat)
120127
matmul_serial!(C, A, B, One(), Zero(), nothing, ArrayInterface.contiguous_axis(C))
121128
end
122-
@inline function matmul_serial!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α)
129+
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α)
123130
matmul_serial!(C, A, B, α, Zero(), nothing, ArrayInterface.contiguous_axis(C))
124131
end
125-
@inline function matmul_serial!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α, β)
132+
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β)
126133
matmul_serial!(C, A, B, α, β, nothing, ArrayInterface.contiguous_axis(C))
127134
end
128-
@inline function matmul_serial!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN, ::StaticInt{2})
135+
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, MKN, ::StaticInt{2})
129136
_matmul_serial!(C', B', A', α, β, nothing)
130137
return C
131138
end
132-
@inline function matmul_serial!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN, ::StaticInt)
139+
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, MKN, ::StaticInt)
133140
_matmul_serial!(C, A, B, α, β, nothing)
134141
return C
135142
end
@@ -149,7 +156,7 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
149156
"""
150157
@inline function _matmul_serial!(
151158
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
152-
) where {T}
159+
) where {T<:Real}
153160
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
154161
if M * N == 0
155162
return
@@ -175,13 +182,13 @@ Otherwise, based on the array's size, whether they are transposed, and whether t
175182
end # function
176183

177184
function matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where T
178-
@avx for i=1:length(C)
185+
@turbo for i=1:length(C)
179186
C[i] = zero(T)
180187
end
181188
end
182189

183190
function matmul_only_β!(C::AbstractMatrix{T}, β) where T
184-
@avx for i=1:length(C)
191+
@turbo for i=1:length(C)
185192
C[i] = β * C[i]
186193
end
187194
end
@@ -204,7 +211,7 @@ end
204211
205212
Multiply matrices `A` and `B`.
206213
"""
207-
@inline function matmul(A::AbstractMatrix, B::AbstractMatrix)
214+
@inline function matmul(A::AbstractMatrix, B::AbstractVecOrMat)
208215
C, (M,K,N) = alloc_matmul_product(A, B)
209216
_matmul!(C, A, B, One(), Zero(), nothing, (M,K,N))
210217
return C
@@ -213,26 +220,26 @@ end
213220
"""
214221
matmul!(C, A, B[, α, β, max_threads])
215222
216-
Calculates `C = α * A * B + β * C` in place, overwriting the contents of `A`.
223+
Calculates `C = α * A * B + β * C` in place, overwriting the contents of `C`.
217224
It may use up to `max_threads` threads. It will not use threads when nested in other threaded code.
218225
"""
219-
@inline function matmul!(C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix) where {T}
226+
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat)
220227
matmul!(C, A, B, One(), Zero(), nothing, nothing, ArrayInterface.contiguous_axis(C))
221228
end
222-
@inline function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α)
229+
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α)
223230
matmul!(C, A, B, α, Zero(), nothing, nothing, ArrayInterface.contiguous_axis(C))
224231
end
225-
@inline function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α, β)
232+
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β)
226233
matmul!(C, A, B, α, β, nothing, nothing, ArrayInterface.contiguous_axis(C))
227234
end
228-
@inline function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α, β, nthread)
235+
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread)
229236
matmul!(C, A, B, α, β, nthread, nothing, ArrayInterface.contiguous_axis(C))
230237
end
231-
@inline function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α, β, nthread, MKN, ::StaticInt{2})
238+
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread, MKN, ::StaticInt{2})
232239
_matmul!(C', B', A', α, β, nthread, MKN)
233240
return C
234241
end
235-
@inline function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α, β, nthread, MKN, ::StaticInt)
242+
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread, MKN, ::StaticInt)
236243
_matmul!(C, A, B, α, β, nthread, MKN)
237244
return C
238245
end
@@ -243,7 +250,7 @@ end
243250
end
244251

245252
# passing MKN directly would let someone skip the size check.
246-
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}#::Union{Nothing,Tuple{Vararg{Integer,3}}}) where {T}
253+
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T<:Real}
247254
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
248255
if M * N == 0
249256
return
@@ -443,7 +450,7 @@ function sync_mul!(
443450
for k CloseOpen(Kiter)
444451
ksize = ifelse(k < Krem, Kblock_Krem, Kblock)
445452
_B = default_zerobased_stridedpointer(bc, (One(), ksize))
446-
unsafe_copyto_avx!(gesp(_B, (Zero(), pack_offset)), gesp(B, (Zero(), pack_offset)), ksize, pack_len)
453+
unsafe_copyto_turbo!(gesp(_B, (Zero(), pack_offset)), gesp(B, (Zero(), pack_offset)), ksize, pack_len)
447454
# synchronize before starting the multiplication, to ensure `B` is packed
448455
_mv = _atomic_add!(myp, one(UInt))
449456
sync_iters += one(UInt)
@@ -489,3 +496,26 @@ function sync_mul!(
489496
end
490497
nothing
491498
end
499+
500+
"""
    _matmul!(y, A, x, α, β, nthread, MKN)

Threaded matrix-vector kernel for real-valued `y`: overwrites `y` with
`α * A * x + β * y` and returns `y`.

The two trailing positional arguments exist only so this method matches the
seven-argument `_matmul!(C, A, B, α, β, nthread, MKN)` calls issued by `matmul` and
`matmul!`; neither is read here, so the requested thread count does not influence
this kernel (`@tturbo` is always used).
"""
function _matmul!(
    y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, nthread, MKN
) where {T<:Real}
    # `indices` yields the loop range shared by the listed axes of both arrays
    # (presumably checking that they agree — see LoopVectorization docs).
    @tturbo for m ∈ indices((A, y), 1)
        yₘ = zero(T)
        for n ∈ indices((A, x), (2, 1))
            yₘ += A[m, n] * x[n]
        end
        y[m] = α * yₘ + β * y[m]
    end
    return y
end
510+
"""
    _matmul_serial!(y, A, x, α, β, MKN)

Single-threaded matrix-vector kernel for real-valued `y`: overwrites `y` with
`α * A * x + β * y` and returns `y`.

`MKN` is accepted so the method matches the six-argument `_matmul_serial!` calls made
by `matmul_serial` and `matmul_serial!`; it is not read here because the loop ranges
are taken from the arrays themselves.
"""
function _matmul_serial!(
    y::AbstractVector{T}, A::AbstractMatrix, x::AbstractVector, α, β, MKN
) where {T<:Real}
    # `indices` yields the loop range shared by the listed axes of both arrays
    # (presumably checking that they agree — see LoopVectorization docs).
    @turbo for m ∈ indices((A, y), 1)
        yₘ = zero(T)
        for n ∈ indices((A, x), (2, 1))
            yₘ += A[m, n] * x[n]
        end
        y[m] = α * yₘ + β * y[m]
    end
    return y
end
520+
521+

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ given a mix of static and dynamic sizes.
1616
(_select(MA, MC), _select(KA, KB), _select(NB, NC))
1717
end
1818

19-
function unsafe_copyto_avx!(pB, pA, M, N)
20-
LoopVectorization.@avx for n CloseOpen(N), m CloseOpen(M)
19+
function unsafe_copyto_turbo!(pB, pA, M, N)
20+
LoopVectorization.@turbo for n CloseOpen(N), m CloseOpen(M)
2121
pB[m,n] = pA[m,n]
2222
end
2323
end

0 commit comments

Comments
 (0)