Skip to content

Commit 4a5f26d

Browse files
authored
Fix 5-term mul with non-zero beta (#395)
* Fix 5-term mul with non-zero beta * Add tests * import ArrayLayouts
1 parent 20aada1 commit 4a5f26d

File tree

3 files changed

+38
-33
lines changed

3 files changed

+38
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BandedMatrices"
22
uuid = "aae01518-5342-5314-be14-df237901396f"
3-
version = "0.17.36"
3+
version = "0.17.37"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/generic/matmul.jl

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ end
4040

4141
function _banded_muladd!(α, A, x::AbstractVector, β, y)
4242
m, n = size(A)
43-
(length(y) ≠ m || length(x) ≠ n) && throw(DimensionMismatch("*"))
4443
l, u = bandwidths(A)
4544
if -l > u # no bands
4645
_fill_lmul!(β, y)
@@ -51,40 +50,43 @@ function _banded_muladd!(α, A, x::AbstractVector, β, y)
5150
elseif u < 0 # with -l <= u < 0, that is, all bands lie below the diagonal.
5251
# E.g. (l,u) = (2,-1)
5352
# set lview = l + u >= 0 and uview = 0
54-
y[1:-u] .= zero(eltype(y))
53+
_fill_lmul!(β, @view(y[1:-u]))
5554
_banded_gbmv!('N', α, view(A, 1-u:m, :), x, β, view(y, 1-u:m))
5655
y
5756
else
5857
_banded_gbmv!('N', α, A, x, β, y)
5958
end
6059
end
6160

62-
materialize!(M::BlasMatMulVecAdd{<:BandedColumnMajor,<:AbstractStridedLayout,<:AbstractStridedLayout,T}) where T<:BlasFloat =
61+
function materialize!(M::BlasMatMulVecAdd{<:BandedColumnMajor,<:AbstractStridedLayout,<:AbstractStridedLayout,<:BlasFloat})
62+
checkdimensions(M)
6363
_banded_muladd!(M.α, M.A, M.B, M.β, M.C)
64+
end
6465

6566
function _banded_muladd_row!(tA, α, At, x, β, y)
6667
n, m = size(At)
67-
(length(y) ≠ m || length(x) ≠ n) && throw(DimensionMismatch("*"))
6868
u, l = bandwidths(At)
6969
if -l > u # no bands
7070
_fill_lmul!(β, y)
7171
elseif l < 0
7272
_banded_gbmv!(tA, α, view(At, 1-l:n, :,), view(x, 1-l:n), β, y)
7373
elseif u < 0
74-
y[1:-u] .= zero(eltype(y))
74+
_fill_lmul!(β, @view(y[1:-u]))
7575
_banded_gbmv!(tA, α, view(At, :, 1-u:m), x, β, view(y, 1-u:m))
7676
y
7777
else
7878
_banded_gbmv!(tA, α, At, x, β, y)
7979
end
8080
end
8181

82-
function materialize!(M::BlasMatMulVecAdd{<:BandedRowMajor,<:AbstractStridedLayout,<:AbstractStridedLayout,T}) where T<:BlasFloat
82+
function materialize!(M::BlasMatMulVecAdd{<:BandedRowMajor,<:AbstractStridedLayout,<:AbstractStridedLayout,<:BlasFloat})
83+
checkdimensions(M)
8384
α, A, x, β, y = M.α, M.A, M.B, M.β, M.C
8485
_banded_muladd_row!('T', α, transpose(A), x, β, y)
8586
end
8687

87-
function materialize!(M::BlasMatMulVecAdd{<:ConjLayout{<:BandedRowMajor},<:AbstractStridedLayout,<:AbstractStridedLayout,T}) where T<:BlasComplex
88+
function materialize!(M::BlasMatMulVecAdd{<:ConjLayout{<:BandedRowMajor},<:AbstractStridedLayout,<:AbstractStridedLayout,<:BlasComplex})
89+
checkdimensions(M)
8890
α, A, x, β, y = M.α, M.A, M.B, M.β, M.C
8991
_banded_muladd_row!('C', α, A', x, β, y)
9092
end
@@ -173,25 +175,17 @@ const ConjOrBandedLayout = Union{AbstractBandedLayout,ConjLayout{<:AbstractBande
173175
const ConjOrBandedColumnMajor = Union{<:BandedColumnMajor,ConjLayout{<:BandedColumnMajor}}
174176

175177
function _banded_muladd!(α::T, A, B::AbstractMatrix, β, C) where T
176-
Am, An = size(A)
177-
Bm, Bn = size(B)
178-
if An != Bm || size(C, 1) != Am || size(C, 2) != Bn
179-
throw(DimensionMismatch("*"))
180-
end
181-
182-
Al, Au = bandwidths(A)
183-
Bl, Bu = bandwidths(B)
184-
185178
gbmm!('N', 'N', α, A, B, β, C)
186-
187179
C
188180
end
189181

190182
materialize!(M::BlasMatMulMatAdd{<:AbstractBandedLayout,<:AbstractBandedLayout,<:BandedColumnMajor}) =
191183
materialize!(MulAdd(M.α, convert(DefaultBandedMatrix,M.A), convert(DefaultBandedMatrix,M.B), M.β, M.C))
192184

193-
materialize!(M::BlasMatMulMatAdd{<:BandedColumnMajor,<:BandedColumnMajor,<:BandedColumnMajor}) =
185+
function materialize!(M::BlasMatMulMatAdd{<:BandedColumnMajor,<:BandedColumnMajor,<:BandedColumnMajor})
186+
checkdimensions(M)
194187
_banded_muladd!(M.α, M.A, M.B, M.β, M.C)
188+
end
195189

196190

197191
# function generally_banded_matmatmul!(C::AbstractMatrix{T}, tA::Val, tB::Val, A::AbstractMatrix{U}, B::AbstractMatrix{V}) where {T, U, V}
@@ -246,13 +240,9 @@ end
246240
### BandedMatrix * dense matrix
247241

248242
function materialize!(M::MatMulMatAdd{<:BandedColumns, <:AbstractColumnMajor, <:AbstractColumnMajor})
243+
checkdimensions(M)
249244
α, β, A, B, C = M.α, M.β, M.A, M.B, M.C
250245

251-
mA, nA = size(A)
252-
mB, nB = size(B)
253-
mC, nC = size(C)
254-
(nA == mB && mC == mA && nC == nB) || throw(DimensionMismatch("A has size ($mA, $nA), B has size ($mB, $nB), C has size ($mC, $nC)"))
255-
256246
if iszero(α)
257247
lmul!(β, C)
258248
else
@@ -265,11 +255,8 @@ function materialize!(M::MatMulMatAdd{<:BandedColumns, <:AbstractColumnMajor, <:
265255
end
266256

267257
function materialize!(M::MatMulMatAdd{<:AbstractColumnMajor, <:BandedColumns, <:AbstractColumnMajor})
258+
checkdimensions(M)
268259
α, β, A, B, C = M.α, M.β, M.A, M.B, M.C
269-
mA, nA = size(A)
270-
mB, nB = size(B)
271-
mC, nC = size(C)
272-
(nA == mB && mC == mA && nC == nB) || throw(DimensionMismatch("A has size ($mA, $nA), B has size ($mB, $nB), C has size ($mC, $nC)"))
273260

274261
if iszero(α)
275262
lmul!(β, C)

test/test_banded.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
using BandedMatrices, FillArrays, Test, LinearAlgebra, SparseArrays
1+
using ArrayLayouts
2+
using BandedMatrices
23
import BandedMatrices: _BandedMatrix
3-
4+
using FillArrays
5+
using LinearAlgebra
6+
using SparseArrays
7+
using Test
48

59
# used to test general matrix backends
610
struct MyMatrix{T} <: AbstractMatrix{T}
@@ -105,24 +109,38 @@ Base.similar(::MyMatrix, ::Type{T}, m::Int, n::Int) where T = MyMatrix{T}(undef,
105109
end
106110

107111
@testset "BandedMatrix * Vector" begin
108-
let A=brand(10,12,2,3), v=rand(12), w=rand(10)
109-
@test A*v ≈ Matrix(A)*v
110-
@test A'*w ≈ Matrix(A)'*w
112+
let v=rand(12), w=rand(10)
113+
for (l,u) in ((2,3), (-2,2), (2,-2), (2,-3))
114+
A=brand(length(w),length(v),l,u)
115+
@test A*v ≈ Matrix(A)*v
116+
# the left-side uses BLAS, while the right doesn't
117+
@test mul!(ones(size(A,1)), A, v, 1.0, 2.0) ≈ mul!(ones(size(A,1)), A, v, 1, 2)
118+
@test A'*w ≈ Matrix(A)'*w
119+
@test mul!(ones(size(A,2)), A', w, 1.0, 2.0) ≈ mul!(ones(size(A,2)), A', w, 1, 2)
120+
# explicitly test materialize!
121+
@test materialize!(MulAdd(1.0, A', w, 2.0, ones(size(A,2)))) ≈ materialize!(MulAdd(1, A', w, 2, ones(size(A,2))))
122+
end
111123
end
112124

113125
let A=brand(Float64,5,3,2,2), v=rand(ComplexF64,3), w=rand(ComplexF64,5)
114126
@test A*v ≈ Matrix(A)*v
127+
@test mul!(ones(ComplexF64,size(A,1)), A, v, 1.0, 2.0) ≈ mul!(ones(ComplexF64,size(A,1)), A, v, 1, 2)
115128
@test A'*w ≈ Matrix(A)'*w
129+
@test mul!(ones(ComplexF64,size(A,2)), A', w, 1.0, 2.0) ≈ mul!(ones(ComplexF64,size(A,2)), A', w, 1, 2)
116130
end
117131

118132
let A=brand(ComplexF64,5,3,2,2), v=rand(ComplexF64,3), w=rand(ComplexF64,5)
119133
@test A*v ≈ Matrix(A)*v
134+
@test mul!(ones(ComplexF64,size(A,1)), A, v, 1.0, 2.0) ≈ mul!(ones(ComplexF64,size(A,1)), A, v, 1, 2)
120135
@test A'*w ≈ Matrix(A)'*w
136+
@test mul!(ones(ComplexF64,size(A,2)), A', w, 1.0, 2.0) ≈ mul!(ones(ComplexF64,size(A,2)), A', w, 1, 2)
121137
end
122138

123139
let A=brand(ComplexF64,5,3,2,2), v=rand(Float64,3), w=rand(Float64,5)
124140
@test A*v ≈ Matrix(A)*v
141+
@test mul!(ones(ComplexF64,size(A,1)), A, v, 1.0, 2.0) ≈ mul!(ones(ComplexF64,size(A,1)), A, v, 1, 2)
125142
@test A'*w ≈ Matrix(A)'*w
143+
@test mul!(ones(ComplexF64,size(A,2)), A', w, 1.0, 2.0) ≈ mul!(ones(ComplexF64,size(A,2)), A', w, 1, 2)
126144
end
127145

128146
@testset "empty" begin

0 commit comments

Comments
 (0)