Skip to content

Commit ba60608

Browse files
authored
Complex Matrix Multiplication (unpacked) (#85)
* add unpacked complex matrix multiplication * increment version * Complex (unpacked) matrix multiplication * move compat up to v1.6 * remove unneeded `using` * test complex alpha, beta in 5-arg-mul * remove unneeded using IfElse * fix 5-arg mul test * restore accidentally deleted F32 test
1 parent 925d48f commit ba60608

File tree

6 files changed

+188
-11
lines changed

6 files changed

+188
-11
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ jobs:
2828
- '1'
2929
- '3' # GitHub runners have 2 cores, so `NUM_CORES+1` is 3
3030
version:
31-
- '1.5'
3231
- '1' # automatically expands to the latest stable 1.x release of Julia
3332
exclude:
3433
- os: macOS-latest

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Octavian"
22
uuid = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
33
authors = ["Mason Protter", "Chris Elrod", "Dilum Aluthge", "contributors"]
4-
version = "0.2.15"
4+
version = "0.2.16"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -12,13 +12,13 @@ ThreadingUtilities = "8290d209-cae3-49c0-8002-c8c24d57dab5"
1212
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1313

1414
[compat]
15-
ArrayInterface = "3"
15+
ArrayInterface = "3.1.14"
1616
LoopVectorization = "0.12"
1717
Static = "0.2"
1818
StrideArraysCore = "0.1.5"
1919
ThreadingUtilities = "0.4"
20-
VectorizationBase = "0.19,0.20"
21-
julia = "1.5"
20+
VectorizationBase = "0.20.5"
21+
julia = "1.6"
2222

2323
[extras]
2424
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
@@ -31,4 +31,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3131
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
3232

3333
[targets]
34-
test = ["Aqua", "BenchmarkTools", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "Random", "VectorizationBase", "Test"]
34+
test = ["Aqua", "BenchmarkTools", "InteractiveUtils", "LinearAlgebra", "LoopVectorization", "Random", "VectorizationBase", "Test"]

src/Octavian.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include("funcptrs.jl")
3232
include("macrokernels.jl")
3333
include("utils.jl")
3434
include("matmul.jl")
35+
include("complex_matmul.jl")
3536

3637
include("init.jl") # `Octavian.__init__()` is defined in this file
3738

src/complex_matmul.jl

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
real_rep(a::AbstractArray{Complex{T}, N}) where {T, N} = reinterpret(reshape, T, a)
2+
#PtrArray(Ptr{T}(pointer(a)), (StaticInt(2), size(a)...))
3+
4+
@inline function _matmul!(_C::AbstractMatrix{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractMatrix{Complex{V}},
5+
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
6+
C, A, B = real_rep.((_C, _A, _B))
7+
8+
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
9+
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
10+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
11+
ηθ = η*θ
12+
13+
@avxt for n indices((C, B), 3), m indices((C, A), 2)
14+
Cmn_re = zero(T)
15+
Cmn_im = zero(T)
16+
for k indices((A, B), (3, 2))
17+
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
18+
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
19+
end
20+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
21+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
22+
end
23+
_C
24+
end
25+
26+
@inline function _matmul!(_C::AbstractMatrix{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractMatrix{Complex{V}},
27+
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
28+
C, B = real_rep.((_C, _B))
29+
30+
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
31+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
32+
33+
@avxt for n indices((C, B), 3), m indices((C, A), (2, 1))
34+
Cmn_re = zero(T)
35+
Cmn_im = zero(T)
36+
for k indices((A, B), (2, 2))
37+
Cmn_re += A[m, k] * B[1, k, n]
38+
Cmn_im += θ * A[m, k] * B[2, k, n]
39+
end
40+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
41+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
42+
end
43+
_C
44+
end
45+
46+
@inline function _matmul!(_C::AbstractMatrix{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractMatrix{V},
47+
α=One(), β=Zero(), nthread::Nothing=nothing, MKN=nothing, contig_axis=nothing) where {T,U,V}
48+
C, A = real_rep.((_C, _A))
49+
50+
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
51+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
52+
53+
@avxt for n indices((C, B), (3, 2)), m indices((C, A), 2)
54+
Cmn_re = zero(T)
55+
Cmn_im = zero(T)
56+
for k indices((A, B), (3, 1))
57+
Cmn_re += A[1, m, k] * B[k, n]
58+
Cmn_im += η * A[2, m, k] * B[k, n]
59+
end
60+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
61+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
62+
end
63+
_C
64+
end
65+
66+
67+
68+
69+
70+
@inline function _matmul_serial!(_C::AbstractMatrix{Complex{T}}, _A::AbstractMatrix{Complex{U}}, _B::AbstractMatrix{Complex{V}},
71+
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
72+
C, A, B = real_rep.((_C, _A, _B))
73+
74+
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
75+
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
76+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
77+
ηθ = η*θ
78+
@avxt for n indices((C, B), 3), m indices((C, A), 2)
79+
Cmn_re = zero(T)
80+
Cmn_im = zero(T)
81+
for k indices((A, B), (3, 2))
82+
Cmn_re += A[1, m, k] * B[1, k, n] - ηθ * A[2, m, k] * B[2, k, n]
83+
Cmn_im += θ * A[1, m, k] * B[2, k, n] + η * A[2, m, k] * B[1, k, n]
84+
end
85+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
86+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
87+
end
88+
_C
89+
end
90+
91+
@inline function _matmul_serial!(_C::AbstractMatrix{Complex{T}}, A::AbstractMatrix{U}, _B::AbstractMatrix{Complex{V}},
92+
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
93+
C, B = real_rep.((_C, _B))
94+
95+
θ = ifelse(ArrayInterface.is_lazy_conjugate(_B), StaticInt(-1), StaticInt(1))
96+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
97+
98+
@avx for n indices((C, B), 3), m indices((C, A), (2, 1))
99+
Cmn_re = zero(T)
100+
Cmn_im = zero(T)
101+
for k indices((A, B), (2, 2))
102+
Cmn_re += A[m, k] * B[1, k, n]
103+
Cmn_im += θ * A[m, k] * B[2, k, n]
104+
end
105+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
106+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
107+
end
108+
_C
109+
end
110+
111+
@inline function _matmul_serial!(_C::AbstractMatrix{Complex{T}}, _A::AbstractMatrix{Complex{U}}, B::AbstractMatrix{V},
112+
α=One(), β=Zero(), MKN=nothing, contig_axis=nothing) where {T,U,V}
113+
C, A = real_rep.((_C, _A))
114+
115+
η = ifelse(ArrayInterface.is_lazy_conjugate(_A), StaticInt(-1), StaticInt(1))
116+
(+ᶻ, -ᶻ) = ifelse(ArrayInterface.is_lazy_conjugate(_C), (-, +), (+, -))
117+
118+
@avx for n indices((C, B), (3, 2)), m indices((C, A), 2)
119+
Cmn_re = zero(T)
120+
Cmn_im = zero(T)
121+
for k indices((A, B), (3, 1))
122+
Cmn_re += A[1, m, k] * B[k, n]
123+
Cmn_im += η * A[2, m, k] * B[k, n]
124+
end
125+
C[1,m,n] = (real(α) * Cmn_re -imag(α) * Cmn_im) + (real(β) * C[1,m,n] -imag(β) * C[2,m,n])
126+
C[2,m,n] = (imag(α) * Cmn_re +real(α) * Cmn_im) + (imag(β) * C[1,m,n] +real(β) * C[2,m,n])
127+
end
128+
_C
129+
end

src/matmul.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,6 @@ end
237237
return C
238238
end
239239

240-
241240
@inline function dontpack(pA::AbstractStridedPointer{Ta}, M, K, ::StaticInt{mc}, ::StaticInt{kc}, ::Type{Tc}, nspawn) where {mc, kc, Tc, Ta}
242241
# TODO: perhaps consider K vs kc by themselves?
243242
(contiguousstride1(pA) && ((M * K) (mc * kc) * nspawn >>> 1))

test/_matmul.jl

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,58 @@
33
# `n_values`
44
# `k_values`
55
# `m_values`
6+
for T (ComplexF32, ComplexF64, Complex{Int}, Complex{Int32})
7+
@time @testset "Matrix Multiply $T $(testset_name_suffix)" begin
8+
for n n_values
9+
for k k_values
10+
for m m_values
11+
A = rand(T, m, k)
12+
B = rand(T, k, n)
613

7-
@time @testset "Matrix Multiply Float32 $(testset_name_suffix)" begin
8-
T = Float32
14+
Are = real.(A)
15+
Bre = real.(B)
16+
17+
A′ = permutedims(A)'
18+
B′ = permutedims(B)'
19+
AB = A * B;
20+
A′B = A′*B
21+
AB′ = A*B′
22+
A′B′= A′*B′
23+
24+
AreB = Are*B
25+
ABre = A*Bre
26+
27+
@info "" T n k m
28+
@test @time(Octavian.matmul(A, B)) AB
29+
@test @time(Octavian.matmul(A, Bre)) ABre
30+
@test @time(Octavian.matmul(Are, B)) AreB
31+
@test @time(Octavian.matmul(A′, B)) A′B
32+
@test @time(Octavian.matmul(A, B′)) AB′
33+
@test @time(Octavian.matmul(A′, B′)) A′B′
34+
35+
36+
@test @time(Octavian.matmul_serial(A, B)) AB
37+
@test @time(Octavian.matmul_serial(A, Bre)) ABre
38+
@test @time(Octavian.matmul_serial(Are, B)) AreB
39+
@test @time(Octavian.matmul_serial(A′, B)) A′B
40+
@test @time(Octavian.matmul_serial(A, B′)) AB′
41+
@test @time(Octavian.matmul_serial(A′, B′)) A′B′
42+
43+
C = Matrix{T}(undef, n, m)'
44+
@test @time(Octavian.matmul!(C, A, B)) AB
45+
46+
C1 = rand(T, m, n)
47+
C2 = copy(C1)
48+
α, β = T(1 - 2im), T(3 + 4im)
49+
@test @time(Octavian.matmul!(C1, A, B, α, β)) Octavian.matmul!(C2, A, B, α, β)
50+
end
51+
end
52+
end
53+
end
54+
end
55+
56+
@time @testset "Matrix Multiply Float64 $(testset_name_suffix)" begin
57+
T = Float64
958
for n n_values
1059
for k k_values
1160
for m m_values
@@ -38,8 +87,8 @@
3887
@test matmul_pack_ab!(similar(AB), A′, B′) AB
3988
end
4089

41-
@time @testset "Matrix Multiply Float64 $(testset_name_suffix)" begin
42-
T = Float64
90+
@time @testset "Matrix Multiply Float32 $(testset_name_suffix)" begin
91+
T = Float32
4392
for n n_values
4493
for k k_values
4594
for m m_values

0 commit comments

Comments
 (0)