Skip to content

Commit d51f130

Browse files
authored
Faster broadcasting for band view (#368)
* faster broadcasting for band view * bump version to 0.17.27 * broadcasting with band view on the right * remove inbounds in getindex
1 parent 23efecb commit d51f130

File tree

3 files changed

+34
-22
lines changed

3 files changed

+34
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BandedMatrices"
22
uuid = "aae01518-5342-5314-be14-df237901396f"
3-
version = "0.17.26"
3+
version = "0.17.27"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/banded/BandedMatrix.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -447,8 +447,18 @@ function dataview(V::BandedMatrixBand)
447447
view(A.data, A.u - b + 1, max(b,0)+1:min(n,m+b))
448448
end
449449

450+
@propagate_inbounds function Base.getindex(B::BandedMatrixBand, i::Int)
451+
A = parent(parent(B))
452+
b = band(B)
453+
if -A.l band(B) A.u
454+
dataview(B)[i]
455+
else
456+
zero(eltype(B))
457+
end
458+
end
459+
450460
# B[band(i)]
451-
function copyto!(v::AbstractVector, B::BandedMatrixBand)
461+
@inline function copyto!(v::AbstractVector, B::BandedMatrixBand)
452462
A = parent(parent(B))
453463
if -A.l band(B) A.u
454464
copyto!(v, dataview(B))
@@ -460,7 +470,7 @@ function copyto!(v::AbstractVector, B::BandedMatrixBand)
460470
end
461471

462472
# B[band(i)] .= x::Number
463-
function fill!(Bv::BandedMatrixBand, x)
473+
@inline function fill!(Bv::BandedMatrixBand, x)
464474
b = band(Bv)
465475
A = parent(parent(Bv))
466476
l, u = bandwidths(A)
@@ -472,29 +482,23 @@ function fill!(Bv::BandedMatrixBand, x)
472482
Bv
473483
end
474484

475-
# B[band(i)] .= V::AbstractVector
476-
function _copyto!(Bv::BandedMatrixBand, V::AbstractVector)
477-
isempty(V) && return Bv
478-
A = parent(parent(Bv))
479-
if -A.l band(Bv) A.u
480-
copyto!(dataview(Bv), V)
485+
@noinline throwdm(destaxes, srcaxes) =
486+
throw(DimensionMismatch("destination axes $destaxes do not match source axes $srcaxes"))
487+
488+
# more complicated broadcating
489+
# e.g. B[band(i)] .= a .* x .+ v
490+
@inline function copyto!(dest::BandedMatrixBand, bc::Broadcasted{Nothing})
491+
axes(dest) == axes(bc) || throwdm(axes(dest), axes(src))
492+
493+
A = parent(parent(dest))
494+
if -A.l band(dest) A.u
495+
copyto!(dataview(dest), bc)
481496
else
482-
# bounds-checking to work around axis offset of V
483-
destinds, srcinds = LinearIndices(Bv), LinearIndices(V)
484-
idf, isf = first(destinds), first(srcinds)
485-
Δi = idf - isf
486-
(checkbounds(Bool, destinds, isf+Δi) &
487-
checkbounds(Bool, destinds, last(srcinds)+Δi)) ||
488-
throw(BoundsError(dest, srcinds))
489-
490-
all(iszero, V) || throw(BandError(A, band(Bv)))
497+
all(iszero, bc) || throw(BandError(A, band(dest)))
491498
end
492-
return Bv
499+
return dest
493500
end
494501

495-
copyto!(Bv::BandedMatrixBand, V::AbstractVector) = _copyto!(Bv, V)
496-
copyto!(Bv::BandedMatrixBand, V::BandedMatrixBand) = _copyto!(Bv, V)
497-
498502
# ~ indexing along a row
499503

500504

test/test_broadcasting.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,10 @@ import BandedMatrices: BandedStyle, BandedRows, BandError
736736
@test all(==(10), diag(B, 2))
737737
@test all(==(10), B[band(2)])
738738

739+
B = BandedMatrix{Int}(1=>1:4, -1=>2:5)
740+
B[band(1)] .= 2 .* view(B, band(-1)) .+ 4
741+
@test B[band(1)] == 2 .* (2:5) .+ 4
742+
739743
B = brand(Int8, 2, 4, 1, 1)
740744
B[band(-1)] .= 2
741745
B[band(0)] .= 3
@@ -749,8 +753,12 @@ import BandedMatrices: BandedStyle, BandedRows, BandError
749753
@test (@view B[band(0)]) == 4:5
750754
B[band(1)] .= 3:4
751755
@test (@view B[band(1)]) == 3:4
756+
B[band(1)] .= [3,4] .* 2 .+ 4
757+
@test (@view B[band(1)]) == [10,12]
752758
B[band(2)] .= [0,0]
753759
@test all(iszero, @view B[band(2)])
760+
B[band(2)] .= [0,0] .* 2
761+
@test all(iszero, @view B[band(2)])
754762

755763
@test_throws BandError B[band(100)] .= 10
756764
@test_throws BandError B[band(-100)] .= 10

0 commit comments

Comments
 (0)