Skip to content

Commit bcc4638

Browse files
authored
Bounds-checking in triangular indexing branches (#1305)
1 parent cdd135e commit bcc4638

File tree

2 files changed

+160
-38
lines changed

2 files changed

+160
-38
lines changed

src/triangular.jl

Lines changed: 83 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,22 @@ Base.isassigned(A::UpperOrLowerTriangular, i::Int, j::Int) =
238238
Base.isstored(A::UpperOrLowerTriangular, i::Int, j::Int) =
239239
_shouldforwardindex(A, i, j) ? Base.isstored(A.data, i, j) : false
240240

241-
@propagate_inbounds getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T} =
242-
_shouldforwardindex(A, i, j) ? A.data[i,j] : ifelse(i == j, oneunit(T), zero(T))
243-
@propagate_inbounds getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int) =
244-
_shouldforwardindex(A, i, j) ? A.data[i,j] : diagzero(A,i,j)
241+
@propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, i::Int, j::Int) where {T}
242+
if _shouldforwardindex(A, i, j)
243+
A.data[i,j]
244+
else
245+
@boundscheck checkbounds(A, i, j)
246+
ifelse(i == j, oneunit(T), zero(T))
247+
end
248+
end
249+
@propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, i::Int, j::Int)
250+
if _shouldforwardindex(A, i, j)
251+
A.data[i,j]
252+
else
253+
@boundscheck checkbounds(A, i, j)
254+
@inbounds diagzero(A,i,j)
255+
end
256+
end
245257

246258
_shouldforwardindex(U::UpperTriangular, b::BandIndex) = b.band >= 0
247259
_shouldforwardindex(U::LowerTriangular, b::BandIndex) = b.band <= 0
@@ -250,62 +262,97 @@ _shouldforwardindex(U::UnitLowerTriangular, b::BandIndex) = b.band < 0
250262

251263
# these specialized getindex methods enable constant-propagation of the band
252264
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{UnitLowerTriangular{T}, UnitUpperTriangular{T}}, b::BandIndex) where {T}
253-
_shouldforwardindex(A, b) ? A.data[b] : ifelse(b.band == 0, oneunit(T), zero(T))
265+
if _shouldforwardindex(A, b)
266+
A.data[b]
267+
else
268+
@boundscheck checkbounds(A, b)
269+
ifelse(b.band == 0, oneunit(T), zero(T))
270+
end
254271
end
255272
Base.@constprop :aggressive @propagate_inbounds function getindex(A::Union{LowerTriangular, UpperTriangular}, b::BandIndex)
256-
_shouldforwardindex(A, b) ? A.data[b] : diagzero(A.data, b)
273+
if _shouldforwardindex(A, b)
274+
A.data[b]
275+
else
276+
@boundscheck checkbounds(A, b)
277+
@inbounds diagzero(A, b)
278+
end
257279
end
258280

259-
_zero_triangular_half_str(T::Type) = T <: UpperOrUnitUpperTriangular ? "lower" : "upper"
260-
261-
@noinline function throw_nonzeroerror(T::DataType, @nospecialize(x), i, j)
262-
Ts = _zero_triangular_half_str(T)
263-
Tn = nameof(T)
281+
@noinline function throw_nonzeroerror(Tn::Symbol, @nospecialize(x), i, j)
282+
zero_half = Tn in (:UpperTriangular, :UnitUpperTriangular) ? "lower" : "upper"
283+
nstr = Tn === :UpperTriangular ? "n" : ""
264284
throw(ArgumentError(
265-
lazy"cannot set index in the $Ts triangular part ($i, $j) of an $Tn matrix to a nonzero value ($x)"))
285+
LazyString(
286+
lazy"cannot set index ($i, $j) in the $zero_half triangular part ",
287+
lazy"of a$nstr $Tn matrix to a nonzero value ($x)")
288+
)
289+
)
266290
end
267-
@noinline function throw_nononeerror(T::DataType, @nospecialize(x), i, j)
268-
Tn = nameof(T)
291+
@noinline function throw_nonuniterror(Tn::Symbol, @nospecialize(x), i, j)
269292
throw(ArgumentError(
270-
lazy"cannot set index on the diagonal ($i, $j) of an $Tn matrix to a non-unit value ($x)"))
293+
lazy"cannot set index ($i, $j) on the diagonal of a $Tn matrix to a non-unit value ($x)"))
271294
end
272295

273296
@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
274-
if i > j
275-
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
276-
else
297+
if _shouldforwardindex(A, i, j)
277298
A.data[i,j] = x
299+
else
300+
@boundscheck checkbounds(A, i, j)
301+
# the value must be convertible to the eltype for setindex! to be meaningful
302+
# however, the converted value is unused, and the compiler is free to remove
303+
# the conversion if the call is guaranteed to succeed
304+
convert(eltype(A), x)
305+
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
278306
end
279307
return A
280308
end
281309

282310
@propagate_inbounds function setindex!(A::UnitUpperTriangular, x, i::Integer, j::Integer)
283-
if i > j
284-
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
285-
elseif i == j
286-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
287-
else
311+
if _shouldforwardindex(A, i, j)
288312
A.data[i,j] = x
313+
else
314+
@boundscheck checkbounds(A, i, j)
315+
# the value must be convertible to the eltype for setindex! to be meaningful
316+
# however, the converted value is unused, and the compiler is free to remove
317+
# the conversion if the call is guaranteed to succeed
318+
convert(eltype(A), x)
319+
if i == j # diagonal
320+
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
321+
else
322+
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
323+
end
289324
end
290325
return A
291326
end
292327

293328
@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
294-
if i < j
295-
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
296-
else
329+
if _shouldforwardindex(A, i, j)
297330
A.data[i,j] = x
331+
else
332+
@boundscheck checkbounds(A, i, j)
333+
# the value must be convertible to the eltype for setindex! to be meaningful
334+
# however, the converted value is unused, and the compiler is free to remove
335+
# the conversion if the call is guaranteed to succeed
336+
convert(eltype(A), x)
337+
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
298338
end
299339
return A
300340
end
301341

302342
@propagate_inbounds function setindex!(A::UnitLowerTriangular, x, i::Integer, j::Integer)
303-
if i < j
304-
iszero(x) || throw_nonzeroerror(typeof(A), x, i, j)
305-
elseif i == j
306-
x == oneunit(x) || throw_nononeerror(typeof(A), x, i, j)
307-
else
343+
if _shouldforwardindex(A, i, j)
308344
A.data[i,j] = x
345+
else
346+
@boundscheck checkbounds(A, i, j)
347+
# the value must be convertible to the eltype for setindex! to be meaningful
348+
# however, the converted value is unused, and the compiler is free to remove
349+
# the conversion if the call is guaranteed to succeed
350+
convert(eltype(A), x)
351+
if i == j # diagonal
352+
x == oneunit(eltype(A)) || throw_nonuniterror(nameof(typeof(A)), x, i, j)
353+
else
354+
iszero(x) || throw_nonzeroerror(nameof(typeof(A)), x, i, j)
355+
end
309356
end
310357
return A
311358
end
@@ -559,7 +606,7 @@ for (T, UT) in ((:UpperTriangular, :UnitUpperTriangular), (:LowerTriangular, :Un
559606
@eval @inline function _copy!(A::$UT, B::$T)
560607
for dind in diagind(A, IndexStyle(A))
561608
if A[dind] != B[dind]
562-
throw_nononeerror(typeof(A), B[dind], Tuple(dind)...)
609+
throw_nonuniterror(nameof(typeof(A)), B[dind], Tuple(dind)...)
563610
end
564611
end
565612
_copy!($T(parent(A)), B)
@@ -740,7 +787,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, B::UnitUpperTriangular, c::Nu
740787
checksize1(A, B)
741788
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
742789
for j in axes(B.data,2)
743-
@inbounds _modify!(_add, c, A, (j,j))
790+
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
744791
for i in firstindex(B.data,1):(j - 1)
745792
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
746793
end
@@ -751,7 +798,7 @@ function _triscale!(A::UpperOrUnitUpperTriangular, c::Number, B::UnitUpperTriang
751798
checksize1(A, B)
752799
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
753800
for j in axes(B.data,2)
754-
@inbounds _modify!(_add, c, A, (j,j))
801+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
755802
for i in firstindex(B.data,1):(j - 1)
756803
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
757804
end
@@ -782,7 +829,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, B::UnitLowerTriangular, c::Nu
782829
checksize1(A, B)
783830
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
784831
for j in axes(B.data,2)
785-
@inbounds _modify!(_add, c, A, (j,j))
832+
@inbounds _modify!(_add, B[BandIndex(0,j)] * c, A, (j,j))
786833
for i in (j + 1):lastindex(B.data,1)
787834
@inbounds _modify!(_add, B.data[i,j] * c, A.data, (i,j))
788835
end
@@ -793,7 +840,7 @@ function _triscale!(A::LowerOrUnitLowerTriangular, c::Number, B::UnitLowerTriang
793840
checksize1(A, B)
794841
_iszero_alpha(_add) && return _rmul_or_fill!(A, _add.beta)
795842
for j in axes(B.data,2)
796-
@inbounds _modify!(_add, c, A, (j,j))
843+
@inbounds _modify!(_add, c * B[BandIndex(0,j)], A, (j,j))
797844
for i in (j + 1):lastindex(B.data,1)
798845
@inbounds _modify!(_add, c * B.data[i,j], A.data, (i,j))
799846
end

test/triangular.jl

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,11 +641,11 @@ end
641641
@testset "error message" begin
642642
A = UpperTriangular(Ap)
643643
B = UpperTriangular(Bp)
644-
@test_throws "cannot set index in the lower triangular part" copyto!(A, B)
644+
@test_throws "cannot set index (3, 1) in the lower triangular part" copyto!(A, B)
645645

646646
A = LowerTriangular(Ap)
647647
B = LowerTriangular(Bp)
648-
@test_throws "cannot set index in the upper triangular part" copyto!(A, B)
648+
@test_throws "cannot set index (1, 2) in the upper triangular part" copyto!(A, B)
649649
end
650650
end
651651

@@ -950,6 +950,10 @@ end
950950
@test 2\U == 2\M
951951
@test U*2 == M*2
952952
@test 2*U == 2*M
953+
954+
U2 = copy(U)
955+
@test rmul!(U, 1) == U2
956+
@test lmul!(1, U) == U2
953957
end
954958

955959
@testset "scaling partly initialized unit triangular" begin
@@ -966,6 +970,77 @@ end
966970
end
967971
end
968972

973+
@testset "indexing checks" begin
974+
P = [1 2; 3 4]
975+
@testset "getindex" begin
976+
U = UnitUpperTriangular(P)
977+
@test_throws BoundsError U[0,0]
978+
@test_throws BoundsError U[1,0]
979+
@test_throws BoundsError U[BandIndex(0,0)]
980+
@test_throws BoundsError U[BandIndex(-1,0)]
981+
982+
U = UpperTriangular(P)
983+
@test_throws BoundsError U[1,0]
984+
@test_throws BoundsError U[BandIndex(-1,0)]
985+
986+
L = UnitLowerTriangular(P)
987+
@test_throws BoundsError L[0,0]
988+
@test_throws BoundsError L[0,1]
989+
@test_throws BoundsError U[BandIndex(0,0)]
990+
@test_throws BoundsError U[BandIndex(1,0)]
991+
992+
L = LowerTriangular(P)
993+
@test_throws BoundsError L[0,1]
994+
@test_throws BoundsError L[BandIndex(1,0)]
995+
end
996+
@testset "setindex!" begin
997+
A = SizedArrays.SizedArray{(2,2)}(P)
998+
M = fill(A, 2, 2)
999+
U = UnitUpperTriangular(M)
1000+
@test_throws "Cannot `convert` an object of type $Int" U[1,1] = 1
1001+
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitUpperTriangular matrix to a non-unit value"
1002+
@test_throws non_unit_msg U[1,1] = A
1003+
L = UnitLowerTriangular(M)
1004+
@test_throws "Cannot `convert` an object of type $Int" L[1,1] = 1
1005+
non_unit_msg = "cannot set index $((1,1)) on the diagonal of a UnitLowerTriangular matrix to a non-unit value"
1006+
@test_throws non_unit_msg L[1,1] = A
1007+
1008+
for UT in (UnitUpperTriangular, UpperTriangular)
1009+
U = UT(M)
1010+
@test_throws "Cannot `convert` an object of type $Int" U[2,1] = 0
1011+
end
1012+
for LT in (UnitLowerTriangular, LowerTriangular)
1013+
L = LT(M)
1014+
@test_throws "Cannot `convert` an object of type $Int" L[1,2] = 0
1015+
end
1016+
1017+
U = UnitUpperTriangular(P)
1018+
@test_throws BoundsError U[0,0] = 1
1019+
@test_throws BoundsError U[1,0] = 0
1020+
1021+
U = UpperTriangular(P)
1022+
@test_throws BoundsError U[1,0] = 0
1023+
1024+
L = UnitLowerTriangular(P)
1025+
@test_throws BoundsError L[0,0] = 1
1026+
@test_throws BoundsError L[0,1] = 0
1027+
1028+
L = LowerTriangular(P)
1029+
@test_throws BoundsError L[0,1] = 0
1030+
end
1031+
end
1032+
1033+
@testset "unit triangular l/rdiv!" begin
1034+
A = rand(3,3)
1035+
@testset for (UT,T) in ((UnitUpperTriangular, UpperTriangular),
1036+
(UnitLowerTriangular, LowerTriangular))
1037+
UnitTri = UT(A)
1038+
Tri = T(LinearAlgebra.full(UnitTri))
1039+
@test 2 \ UnitTri 2 \ Tri
1040+
@test UnitTri / 2 Tri / 2
1041+
end
1042+
end
1043+
9691044
@testset "fillband!" begin
9701045
@testset for TT in (UpperTriangular, UnitUpperTriangular)
9711046
U = TT(zeros(4,4))

0 commit comments

Comments
 (0)