Skip to content

Commit 01ecc9a

Browse files
authored
Support broadcasting with more bands when result is zero (#259)
* Support broadcasting with more bands when result is zero * Update Project.toml * add tests
1 parent 0e4dd06 commit 01ecc9a

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "BandedMatrices"
22
uuid = "aae01518-5342-5314-be14-df237901396f"
3-
version = "0.17.1"
3+
version = "0.17.2"
44

55
[deps]
66
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"

src/generic/broadcast.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,24 @@ function _banded_broadcast!(dest::AbstractMatrix, f, src::AbstractMatrix{T}, _1,
9696

9797
d_l, d_u = bandwidths(dest)
9898
s_l, s_u = bandwidths(src)
99-
(d_l  min(s_l,m-1) && d_u min(s_u,n-1)) || throw(BandError(dest))
99+
if d_l < min(s_l,m-1)
100+
# check zeros
101+
for j = 1:n, k = max(1,j+d_l+1):min(j+s_l,j+d_l,m)
102+
iszero(f(inbands_getindex(src, k, j))) || throw(BandError(dest))
103+
end
104+
end
105+
if d_u < min(s_u,n-1)
106+
# check zeros
107+
for j = 1:n, k = max(1,j-d_u,j-s_u):min(j-d_u-1,m)
108+
iszero(f(inbands_getindex(src, k, j))) || throw(BandError(dest))
109+
end
110+
end
100111

101112
for j=1:n
102113
for k = max(1,j-d_u):min(j-s_u-1,m)
103114
inbands_setindex!(dest, z, k, j)
104115
end
105-
for k = max(1,j-s_u):min(j+s_l,m)
116+
for k = max(1,j-s_u,j-d_u):min(j+s_l,j+d_l,m)
106117
inbands_setindex!(dest, f(inbands_getindex(src, k, j)), k, j)
107118
end
108119
for k = max(1,j+s_l+1):min(j+d_l,m)

test/test_broadcasting.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,4 +560,13 @@ import BandedMatrices: BandedStyle, BandedRows
560560
@test_throws DimensionMismatch A .* Ones(3)
561561
@test_throws DimensionMismatch A .* Ones(3,4)
562562
end
563+
564+
@testset "degenerate bands" begin
565+
A = BandedMatrix{Float64}(undef, (5, 5), (1,-1)); A.data .= NaN
566+
B = BandedMatrix{Float64}(undef, (5, 5), (-1,1)); B.data .= NaN
567+
Z = Diagonal(Zeros(5))
568+
copyto!(A, Z)
569+
copyto!(B, Z)
570+
@test A == B == zeros(5,5)
571+
end
563572
end

0 commit comments

Comments
 (0)