Skip to content

Commit 529be5d

Browse files
authored
consistent band indexing for structured matrices (#376)
* consistent band indexing for structured matrices * Add test for Eye * bump FillArrays compat bound
1 parent 29c38df commit 529be5d

File tree

3 files changed

+26
-42
lines changed

3 files changed

+26
-42
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1313
Aqua = "0.6"
1414
ArrayLayouts = "1"
1515
Documenter = "0.27"
16-
FillArrays = "1.0.1"
16+
FillArrays = "1.3"
1717
PrecompileTools = "1"
1818
julia = "1.6"
1919

src/interfaceimpl.jl

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -67,27 +67,6 @@ function rot180(A::AbstractBandedMatrix)
6767
_BandedMatrix(bandeddata(A)[end:-1:1,end:-1:1], m, u+sh,l-sh)
6868
end
6969

70-
function getindex(D::Diagonal{T,V}, b::Band) where {T,V}
71-
iszero(b.i) && return copy(D.diag)
72-
convert(V, Zeros{T}(size(D,1)-abs(b.i)))
73-
end
74-
75-
function getindex(D::Tridiagonal{T,V}, b::Band) where {T,V}
76-
b.i == -1 && return copy(D.dl)
77-
iszero(b.i) && return copy(D.d)
78-
b.i == 1 && return copy(D.du)
79-
convert(V, Zeros{T}(size(D,1)-abs(b.i)))
80-
end
81-
82-
function getindex(D::SymTridiagonal{T,V}, b::Band) where {T,V}
83-
iszero(b.i) && return copy(D.dv)
84-
abs(b.i) == 1 && return copy(D.ev)
85-
convert(V, Zeros{T}(size(D,1)-abs(b.i)))
86-
end
87-
88-
function getindex(D::Bidiagonal{T,V}, b::Band) where {T,V}
89-
iszero(b.i) && return copy(D.dv)
90-
D.uplo == 'L' && b.i == -1 && return copy(D.ev)
91-
D.uplo == 'U' && b.i == 1 && return copy(D.ev)
92-
convert(V, Zeros{T}(size(D,1)-abs(b.i)))
70+
for MT in (:Diagonal, :SymTridiagonal, :Tridiagonal, :Bidiagonal)
71+
@eval getindex(D::$MT, b::Band) = diag(D, b.i)
9372
end

test/test_interface.jl

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ LinearAlgebra.fill!(A::PseudoBandedMatrix, v) = fill!(A.data,v)
5858
@test B * Eye(5) == B
5959
@test muladd!(2.0, Eye(5), B, 0.0, zeros(5,5)) == 2B
6060
@test muladd!(2.0, B, Eye(5), 0.0, zeros(5,5)) == 2B
61+
62+
E = Eye(4)
63+
@test (@inferred E[band(0)]) == Ones(4)
64+
@test (@inferred E[band(1)]) == Zeros(3)
65+
@test (@inferred E[band(-1)]) == Zeros(3)
6166
end
6267

6368
@testset "Diagonal" begin
@@ -79,9 +84,9 @@ LinearAlgebra.fill!(A::PseudoBandedMatrix, v) = fill!(A.data,v)
7984
@test A[band(0)] == [2; ones(4)]
8085

8186
B = Diagonal(Fill(1,5))
82-
@test B[band(0)] Fill(1,5)
83-
@test B[band(1)] B[band(-1)] Fill(0,4)
84-
@test B[band(2)] B[band(-2)] Fill(0,3)
87+
@test (@inferred B[band(0)]) == Fill(1,5)
88+
@test (@inferred B[band(1)]) == B[band(-1)] == Fill(0,4)
89+
@test (@inferred B[band(2)]) == B[band(-2)] == Fill(0,3)
8590
end
8691

8792
@testset "SymTridiagonal" begin
@@ -93,32 +98,32 @@ LinearAlgebra.fill!(A::PseudoBandedMatrix, v) = fill!(A.data,v)
9398
@test A[1,1] == 2
9499

95100
B = SymTridiagonal(Fill(1,5), Fill(2,4))
96-
@test B[band(0)] Fill(1,5)
97-
@test B[band(1)] B[band(-1)] Fill(2,4)
98-
@test B[band(2)] B[band(-2)] Fill(0,3)
101+
@test (@inferred B[band(0)]) == Fill(1,5)
102+
@test (@inferred B[band(1)]) == B[band(-1)] == Fill(2,4)
103+
@test (@inferred B[band(2)]) == B[band(-2)] == Fill(0,3)
99104
end
100105

101106
@testset "Tridiagonal" begin
102107
B = Tridiagonal(Fill(1,4), Fill(2,5), Fill(3,4))
103-
@test B[band(0)] Fill(2,5)
104-
@test B[band(1)] Fill(3,4)
105-
@test B[band(-1)] Fill(1,4)
106-
@test B[band(2)] B[band(-2)] Fill(0,3)
108+
@test (@inferred B[band(0)]) == Fill(2,5)
109+
@test (@inferred B[band(1)]) == Fill(3,4)
110+
@test (@inferred B[band(-1)]) == Fill(1,4)
111+
@test B[band(2)] == B[band(-2)] == Fill(0,3)
107112
end
108113

109114
@testset "Bidiagonal" begin
110115
L = Bidiagonal(Fill(2,5), Fill(1,4), :L)
111-
@test L[band(0)] Fill(2,5)
112-
@test L[band(1)] Fill(0,4)
113-
@test L[band(-1)] Fill(1,4)
114-
@test L[band(2)] L[band(-2)] Fill(0,3)
116+
@test (@inferred L[band(0)]) == Fill(2,5)
117+
@test (@inferred L[band(1)]) == Fill(0,4)
118+
@test (@inferred L[band(-1)]) == Fill(1,4)
119+
@test (@inferred L[band(2)]) == L[band(-2)] == Fill(0,3)
115120
@test BandedMatrix(L) == L
116121

117122
U = Bidiagonal(Fill(2,5), Fill(1,4), :U)
118-
@test U[band(0)] Fill(2,5)
119-
@test U[band(1)] Fill(1,4)
120-
@test U[band(-1)] Fill(0,4)
121-
@test U[band(2)] U[band(-2)] Fill(0,3)
123+
@test (@inferred U[band(0)]) == Fill(2,5)
124+
@test (@inferred U[band(1)]) == Fill(1,4)
125+
@test (@inferred U[band(-1)]) == Fill(0,4)
126+
@test (@inferred U[band(2)]) == U[band(-2)] == Fill(0,3)
122127
@test BandedMatrix(U) == U
123128
end
124129

0 commit comments

Comments
 (0)