Skip to content

Commit b694678

Browse files
authored
Zero beta (#122)
* fix alpha != 1 * Bump version * Ensure zero-betas ignore NaNs in dest, fixes #121. * Bump version
1 parent cd0be03 commit b694678

File tree

2 files changed

+60
-52
lines changed

2 files changed

+60
-52
lines changed

src/matmul.jl

Lines changed: 56 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,8 @@ end
146146
return C
147147
end
148148
@inline function matmul_serial!(C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, MKN, ::StaticInt)
149-
_matmul_serial!(C, A, B, α, β, MKN)
150-
return C
149+
_matmul_serial!(C, A, B, α, β, MKN)
150+
return C
151151
end
152152

153153
"""
@@ -164,30 +164,32 @@ If the arrays are small and statically sized, it will dispatch to an inlined mul
164164
Otherwise, based on the array's size, whether they are transposed, and whether the columns are already aligned, it decides to not pack at all, to pack only `A`, or to pack both arrays `A` and `B`.
165165
"""
166166
@inline function _matmul_serial!(
167-
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
167+
C::AbstractMatrix{T}, A::AbstractMatrix, B::AbstractMatrix, α, β, MKN
168168
) where {T}
169-
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
170-
if M * N == 0
171-
return
172-
elseif K == 0
173-
matmul_only_β!(C, β)
174-
return
175-
end
176-
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
177-
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
178-
Mc, Kc, Nc = block_sizes(Val(T)); mᵣ, nᵣ = matmul_params(Val(T));
179-
GC.@preserve Cb Ab Bb begin
180-
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
181-
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
182-
return
183-
elseif (nᵣ ≥ N) || dontpack(pA, M, K, Mc, Kc, T)
184-
loopmul!(pC, pA, pB, α, β, M, K, N)
185-
return
186-
else
187-
matmul_st_pack_dispatcher!(pC, pA, pB, α, β, M, K, N)
188-
return
189-
end
169+
((β ≢ Zero()) && iszero(β)) && return _matmul_serial!(C, A, B, α, Zero(), MKN)
170+
(β isa Bool) && return _matmul_serial!(C, A, B, α, One(), MKN)
171+
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
172+
if M * N == 0
173+
return
174+
elseif K == 0
175+
matmul_only_β!(C, β)
176+
return
177+
end
178+
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
179+
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
180+
Mc, Kc, Nc = block_sizes(Val(T)); mᵣ, nᵣ = matmul_params(Val(T));
181+
GC.@preserve Cb Ab Bb begin
182+
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
183+
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
184+
return
185+
elseif (nᵣ ≥ N) || dontpack(pA, M, K, Mc, Kc, T)
186+
loopmul!(pC, pA, pB, α, β, M, K, N)
187+
return
188+
else
189+
matmul_st_pack_dispatcher!(pC, pA, pB, α, β, M, K, N)
190+
return
190191
end
192+
end
191193
end # function
192194

193195
function matmul_only_β!(C::AbstractMatrix{T}, β::StaticInt{0}) where T
@@ -266,35 +268,37 @@ end
266268

267269
# passing MKN directly would let someone skip the size check.
268270
@inline function _matmul!(C::AbstractMatrix{T}, A, B, α, β, nthread, MKN) where {T}
269-
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
270-
if M * N == 0
271-
return
272-
elseif K == 0
273-
matmul_only_β!(C, β)
274-
return
275-
end
276-
W = pick_vector_width(T)
277-
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
278-
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
279-
mᵣ, nᵣ = matmul_params(Val(T))
280-
GC.@preserve Cb Ab Bb begin
281-
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
282-
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
283-
return
284-
else
285-
(nᵣ ≥ N) && @goto LOOPMUL
286-
if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686)
287-
(M*K*N < (StaticInt{4_096}() * W)) && @goto LOOPMUL
288-
else
289-
(M*K*N < (StaticInt{32_000}() * W)) && @goto LOOPMUL
290-
end
291-
__matmul!(pC, pA, pB, α, β, M, K, N, nthread)
292-
return
293-
@label LOOPMUL
294-
loopmul!(pC, pA, pB, α, β, M, K, N)
295-
return
296-
end
271+
((β ≢ Zero()) && iszero(β)) && return _matmul!(C, A, B, α, Zero(), nthread, MKN)
272+
(β isa Bool) && return _matmul!(C, A, B, α, One(), nthread, MKN)
273+
M, K, N = MKN === nothing ? matmul_sizes(C, A, B) : MKN
274+
if M * N == 0
275+
return
276+
elseif K == 0
277+
matmul_only_β!(C, β)
278+
return
279+
end
280+
W = pick_vector_width(T)
281+
pA = zstridedpointer(A); pB = zstridedpointer(B); pC = zstridedpointer(C);
282+
Cb = preserve_buffer(C); Ab = preserve_buffer(A); Bb = preserve_buffer(B);
283+
mᵣ, nᵣ = matmul_params(Val(T))
284+
GC.@preserve Cb Ab Bb begin
285+
if maybeinline(M, N, T, ArrayInterface.is_column_major(A)) # check MUST be compile-time resolvable
286+
inlineloopmul!(pC, pA, pB, One(), Zero(), M, K, N)
287+
return
288+
else
289+
(nᵣ ≥ N) && @goto LOOPMUL
290+
if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686)
291+
(M*K*N < (StaticInt{4_096}() * W)) && @goto LOOPMUL
292+
else
293+
(M*K*N < (StaticInt{32_000}() * W)) && @goto LOOPMUL
294+
end
295+
__matmul!(pC, pA, pB, α, β, M, K, N, nthread)
296+
return
297+
@label LOOPMUL
298+
loopmul!(pC, pA, pB, α, β, M, K, N)
299+
return
297300
end
301+
end
298302
end
299303

300304
# This function is sort of a `pun`. It splits aggressively (it does a lot of "splitin'"), which often means it will split-N.

test/matmul_main.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,7 @@ for T ∈ (Float64,Float32,Int64,Int32)
99
@time test_real(T, m_values, k_values, n_values, testset_name_suffix)
1010
end
1111

12+
A = rand(2,2); B = rand(2,2); AB = A*B; C = fill(NaN, 2, 2);
13+
@test Octavian.matmul!(C, A, B, true, false) ≈ AB
14+
@test Octavian.matmul!(C, A, B, true, true) ≈ 2AB
15+

0 commit comments

Comments
 (0)