Skip to content

Commit 5ded095

Browse files
authored
fix performance of vecmatmul (#97)
1 parent fb0d625 commit 5ded095

File tree

3 files changed

+41
-20
lines changed

3 files changed

+41
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Octavian"
22
uuid = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
33
authors = ["Mason Protter", "Chris Elrod", "Dilum Aluthge", "contributors"]
4-
version = "0.2.18"
4+
version = "0.2.19"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/matmul.jl

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,15 @@ end
9292

9393

9494
@inline function alloc_matmul_product(A::AbstractArray{TA}, B::AbstractMatrix{TB}) where {TA,TB}
95-
# TODO: if `M` and `N` are statically sized, shouldn't return a `Matrix`.
96-
M, KA = size(A)
97-
KB, N = size(B)
98-
@assert KA == KB "Size mismatch."
95+
# TODO: if `M` and `N` are statically sized, shouldn't return a `Matrix`.
96+
M, KA = size(A)
97+
KB, N = size(B)
98+
@assert KA == KB "Size mismatch."
99+
if M === StaticInt(1)
100+
transpose(Vector{promote_type(TA,TB)}(undef, N)), (M, KA, N)
101+
else
99102
Matrix{promote_type(TA,TB)}(undef, M, N), (M, KA, N)
103+
end
100104
end
101105
@inline function alloc_matmul_product(A::AbstractArray{TA}, B::AbstractVector{TB}) where {TA,TB}
102106
# TODO: if `M` and `N` are statically sized, shouldn't return a `Matrix`.
@@ -105,10 +109,11 @@ end
105109
@assert KA == KB "Size mismatch."
106110
Vector{promote_type(TA,TB)}(undef, M), (M, KA, One())
107111
end
112+
108113
@inline function matmul_serial(A::AbstractMatrix, B::AbstractVecOrMat)
109-
C, (M,K,N) = alloc_matmul_product(A, B)
110-
_matmul_serial!(C, A, B, One(), Zero(), (M,K,N))
111-
return C
114+
C, (M,K,N) = alloc_matmul_product(A, B)
115+
matmul_serial!(C, A, B, One(), Zero(), (M,K,N), ArrayInterface.contiguous_axis(C))
116+
return C
112117
end
113118

114119

@@ -132,12 +137,16 @@ end
132137
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β)
133138
matmul_serial!(C, A, B, α, β, nothing, ArrayInterface.contiguous_axis(C))
134139
end
135-
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, MKN, ::StaticInt{2})
136-
_matmul_serial!(C', B', A', α, β, nothing)
137-
return C
140+
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, ::Nothing, ::StaticInt{2})
141+
_matmul_serial!(transpose(C), transpose(B), transpose(A), α, β, nothing)
142+
return C
143+
end
144+
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, (M,K,N)::Tuple{Vararg{Integer,3}}, ::StaticInt{2})
145+
_matmul_serial!(transpose(C), transpose(B), transpose(A), α, β, (N,K,M))
146+
return C
138147
end
139148
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, MKN, ::StaticInt)
140-
_matmul_serial!(C, A, B, α, β, nothing)
149+
_matmul_serial!(C, A, B, α, β, MKN)
141150
return C
142151
end
143152

@@ -212,9 +221,9 @@ end
212221
Multiply matrices `A` and `B`.
213222
"""
214223
@inline function matmul(A::AbstractMatrix, B::AbstractVecOrMat)
215-
C, (M,K,N) = alloc_matmul_product(A, B)
216-
_matmul!(C, A, B, One(), Zero(), nothing, (M,K,N))
217-
return C
224+
C, (M,K,N) = alloc_matmul_product(A, B)
225+
matmul!(C, A, B, One(), Zero(), nothing, (M,K,N), ArrayInterface.contiguous_axis(C))
226+
return C
218227
end
219228

220229
"""
@@ -235,13 +244,17 @@ end
235244
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread)
236245
matmul!(C, A, B, α, β, nthread, nothing, ArrayInterface.contiguous_axis(C))
237246
end
238-
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread, MKN, ::StaticInt{2})
239-
_matmul!(C', B', A', α, β, nthread, MKN)
240-
return C
247+
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread, ::Nothing, ::StaticInt{2})
248+
_matmul!(transpose(C), transpose(B), transpose(A), α, β, nthread, nothing)
249+
return C
250+
end
251+
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread, (M,K,N)::Tuple{Vararg{Integer,3}}, ::StaticInt{2})
252+
_matmul!(transpose(C), transpose(B), transpose(A), α, β, nthread, (N,K,M))
253+
return C
241254
end
242255
@inline function matmul!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread, MKN, ::StaticInt)
243-
_matmul!(C, A, B, α, β, nthread, MKN)
244-
return C
256+
_matmul!(C, A, B, α, β, nthread, MKN)
257+
return C
245258
end
246259

247260
@inline function dontpack(pA::AbstractStridedPointer{Ta}, M, K, ::StaticInt{mc}, ::StaticInt{kc}, ::Type{Tc}, nspawn) where {mc, kc, Tc, Ta}

test/_matmul.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@ for T ∈ (ComplexF32, ComplexF64, Complex{Int64}, Complex{Int32})
3939
@test @time(Octavian.matmul(A′, B′)) A′B′
4040

4141
@test @time(Octavian.matmul(A, b)) Ab
42+
@test transpose(@time(Octavian.matmul(transpose(b), transpose(A)))) Ab
4243
@test @time(Octavian.matmul(A, bre)) Abre
4344
@test @time(Octavian.matmul(Are, b)) Areb
4445
@test @time(Octavian.matmul(A′, b)) A′b
46+
@test transpose(@time(Octavian.matmul(transpose(b), transpose(A′)))) A′b
4547

4648
@test @time(Octavian.matmul_serial(A, B)) AB
4749
@test @time(Octavian.matmul_serial(A, Bre)) ABre
@@ -51,9 +53,11 @@ for T ∈ (ComplexF32, ComplexF64, Complex{Int64}, Complex{Int32})
5153
@test @time(Octavian.matmul_serial(A′, B′)) A′B′
5254

5355
@test @time(Octavian.matmul_serial(A, b)) Ab
56+
@test transpose(@time(Octavian.matmul_serial(transpose(b), transpose(A)))) Ab
5457
@test @time(Octavian.matmul_serial(A, bre)) Abre
5558
@test @time(Octavian.matmul_serial(Are, b)) Areb
5659
@test @time(Octavian.matmul_serial(A′, b)) A′b
60+
@test transpose(@time(Octavian.matmul_serial(transpose(b), transpose(A′)))) A′b
5761

5862
C = Matrix{T}(undef, n, m)'
5963
@test @time(Octavian.matmul!(C, A, B)) AB
@@ -92,8 +96,12 @@ end
9296
@test @time(Octavian.matmul_serial(A′, B′)) AB
9397
@test @time(Octavian.matmul(A, b)) Ab
9498
@test @time(Octavian.matmul(A′, b)) Ab
99+
@test @time(Octavian.matmul(b', A'))' Ab
100+
@test @time(Octavian.matmul(b', A′'))' Ab
95101
@test @time(Octavian.matmul_serial(A, b)) Ab
96102
@test @time(Octavian.matmul_serial(A′, b)) Ab
103+
@test @time(Octavian.matmul_serial(b', A'))' Ab
104+
@test @time(Octavian.matmul_serial(b', A′'))' Ab
97105
end
98106
end
99107
end

0 commit comments

Comments
 (0)