Skip to content

Commit bbf9bff

Browse files
authored
Better definition of isapprox (#55)
1 parent 9f6111a commit bbf9bff

File tree

5 files changed

+146
-57
lines changed

5 files changed

+146
-57
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.8"
4+
version = "0.2.9"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/kroneckerarray.jl

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,11 +357,53 @@ function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
357357
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
358358
end
359359

360-
# TODO: this definition doesn't fully retain the original meaning:
361-
# ‖a - b‖ < atol could be true even if the following check isn't
362-
function Base.isapprox(a::AbstractKroneckerArray, b::AbstractKroneckerArray; kwargs...)
363-
return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...)
360+
# norm(a - b) = norm(a1 ⊗ a2 - b1 ⊗ b2)
361+
# = norm((a1 - b1) ⊗ a2 + b1 ⊗ (a2 - b2) + (a1 - b1) ⊗ (a2 - b2))
362+
function dist_kronecker(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
363+
a1, a2 = arg1(a), arg2(a)
364+
b1, b2 = arg1(b), arg2(b)
365+
diff1 = a1 - b1
366+
diff2 = a2 - b2
367+
# x = (a1 - b1) ⊗ a2
368+
# y = b1 ⊗ (a2 - b2)
369+
# z = (a1 - b1) ⊗ (a2 - b2)
370+
xx = norm(diff1)^2 * norm(a2)^2
371+
yy = norm(b1)^2 * norm(diff2)^2
372+
zz = norm(diff1)^2 * norm(diff2)^2
373+
xy = real(dot(diff1, b1) * dot(a2, diff2))
374+
xz = real(dot(diff1, diff1) * dot(a2, diff2))
375+
yz = real(dot(b1, diff1) * dot(diff2, diff2))
376+
# `abs` is used in case there are negative values due to floating point roundoff errors.
377+
return sqrt(abs(xx + yy + zz + 2 * (xy + xz + yz)))
378+
end
379+
380+
using LinearAlgebra: dot, promote_leaf_eltypes
381+
function Base.isapprox(
382+
a::AbstractKroneckerArray, b::AbstractKroneckerArray; atol::Real = 0,
383+
rtol::Real = Base.rtoldefault(promote_leaf_eltypes(a), promote_leaf_eltypes(b), atol),
384+
)
385+
a1, a2 = arg1(a), arg2(a)
386+
b1, b2 = arg1(b), arg2(b)
387+
if a1 == b1
388+
return isapprox(a2, b2; atol = atol / norm(a1), rtol)
389+
elseif a2 == b2
390+
return isapprox(a1, b1; atol = atol / norm(a2), rtol)
391+
else
392+
# This could be defined as:
393+
# ```julia
394+
# d = KroneckerArrays.dist_kronecker(a, b)
395+
# iszero(rtol) ? d <= atol : d <= max(atol, rtol * max(norm(a), norm(b)))
396+
# ```
397+
# but that might have numerical precision issues so for now we just error.
398+
throw(
399+
ArgumentError(
400+
"`isapprox` not implemented for KroneckerArrays where both arguments differ. " *
401+
"In those cases, you can use `isapprox(collect(a), collect(b); kwargs...)`."
402+
)
403+
)
404+
end
364405
end
406+
365407
function Base.iszero(a::AbstractKroneckerArray)
366408
return iszero(arg1(a)) || iszero(arg2(a))
367409
end

src/linearalgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function LinearAlgebra.tr(a::AbstractKroneckerArray)
5858
end
5959

6060
using LinearAlgebra: norm
61-
function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Int = 2)
61+
function LinearAlgebra.norm(a::AbstractKroneckerArray, p::Real = 2)
6262
return norm(arg1(a), p) * norm(arg2(a), p)
6363
end
6464

test/test_basics.jl

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,9 @@ using DerivableInterfaces: zero!
44
using DiagonalArrays: diagonal
55
using GPUArraysCore: @allowscalar
66
using JLArrays: JLArray
7-
using KroneckerArrays:
8-
KroneckerArrays,
9-
KroneckerArray,
10-
KroneckerStyle,
11-
CartesianProductUnitRange,
12-
CartesianProductVector,
13-
,
14-
×,
15-
arg1,
16-
arg2,
17-
cartesianproduct,
18-
cartesianrange,
19-
kron_nd,
20-
unproduct
7+
using KroneckerArrays: KroneckerArrays, KroneckerArray, KroneckerStyle,
8+
CartesianProductUnitRange, CartesianProductVector, , ×, arg1, arg2, cartesianproduct,
9+
cartesianrange, kron_nd, unproduct
2110
using LinearAlgebra: Diagonal, I, det, eigen, eigvals, lq, norm, pinv, qr, svd, svdvals, tr
2211
using StableRNGs: StableRNG
2312
using Test: @test, @test_broken, @test_throws, @testset
@@ -219,10 +208,11 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
219208

220209
a = randn(elt, 2, 2) randn(elt, 3, 3)
221210
b = randn(elt, 2, 2) randn(elt, 3, 3)
222-
c = a.arg1 b.arg2
211+
c = arg1(a) arg2(b)
223212
U, S, V = svd(a)
224213
@test collect(U * diagonal(S) * V') collect(a)
225-
@test svdvals(a) S
214+
@test arg1(svdvals(a)) arg1(S)
215+
@test arg2(svdvals(a)) arg2(S)
226216
@test sort(collect(S); rev = true) svdvals(collect(a))
227217
@test collect(U'U) I
228218
@test collect(V * V') I
@@ -246,4 +236,48 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
246236
@test_throws ArgumentError $f($a)
247237
end
248238
end
239+
240+
# isapprox
241+
242+
rng = StableRNG(123)
243+
a1 = randn(rng, elt, (2, 2))
244+
a = a1 randn(rng, elt, (3, 3))
245+
b = a1 randn(rng, elt, (3, 3))
246+
@test isapprox(a, b; atol = norm(a - b) * (1 + 2eps(real(elt))))
247+
@test !isapprox(a, b; atol = norm(a - b) * (1 - 2eps(real(elt))))
248+
@test isapprox(
249+
a, b;
250+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt)))
251+
)
252+
@test !isapprox(
253+
a, b;
254+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt)))
255+
)
256+
@test isapprox(
257+
a, b; atol = norm(a - b) * (1 + 2eps(real(elt))),
258+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt)))
259+
)
260+
@test isapprox(
261+
a, b; atol = norm(a - b) * (1 + 2eps(real(elt))),
262+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt)))
263+
)
264+
@test isapprox(
265+
a, b; atol = norm(a - b) * (1 - 2eps(real(elt))),
266+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 + 2eps(real(elt)))
267+
)
268+
@test !isapprox(
269+
a, b; atol = norm(a - b) * (1 - 2eps(real(elt))),
270+
rtol = norm(a - b) / max(norm(a), norm(b)) * (1 - 2eps(real(elt)))
271+
)
272+
273+
a = randn(elt, (2, 2)) randn(elt, (3, 3))
274+
b = randn(elt, (2, 2)) randn(elt, (3, 3))
275+
@test_throws ArgumentError isapprox(a, b)
276+
277+
# KroneckerArrays.dist_kronecker
278+
rng = StableRNG(123)
279+
a = randn(rng, (100, 100)) randn(rng, (100, 100))
280+
b = (arg1(a) + randn(rng, size(arg1(a))) / 10)
281+
(arg2(a) + randn(rng, size(arg2(a))) / 10)
282+
@test KroneckerArrays.dist_kronecker(a, b) norm(collect(a) - collect(b)) rtol = 1.0e-2
249283
end

test/test_matrixalgebrakit.jl

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,8 @@
1-
using KroneckerArrays: , arguments
1+
using KroneckerArrays: , arg1, arg2
22
using LinearAlgebra: Hermitian, I, diag, hermitianpart, norm
3-
using MatrixAlgebraKit:
4-
eig_full,
5-
eig_trunc,
6-
eig_vals,
7-
eigh_full,
8-
eigh_trunc,
9-
eigh_vals,
10-
left_null,
11-
left_orth,
12-
left_polar,
13-
lq_compact,
14-
lq_full,
15-
qr_compact,
16-
qr_full,
17-
right_null,
18-
right_orth,
19-
right_polar,
20-
svd_compact,
21-
svd_full,
22-
svd_trunc,
3+
using MatrixAlgebraKit: eig_full, eig_trunc, eig_vals, eigh_full, eigh_trunc,
4+
eigh_vals, left_null, left_orth, left_polar, lq_compact, lq_full, qr_compact,
5+
qr_full, right_null, right_orth, right_polar, svd_compact, svd_full, svd_trunc,
236
svd_vals
247
using Test: @test, @test_throws, @testset
258
using TestExtras: @constinferred
@@ -31,18 +14,26 @@ herm(a) = parent(hermitianpart(a))
3114

3215
a = randn(elt, 2, 2) randn(elt, 3, 3)
3316
d, v = eig_full(a)
34-
@test a * v v * d
17+
av = a * v
18+
vd = v * d
19+
@test arg1(av) arg1(vd)
20+
@test arg2(av) arg2(vd)
3521

3622
a = randn(elt, 2, 2) randn(elt, 3, 3)
3723
@test_throws ArgumentError eig_trunc(a)
3824

3925
a = randn(elt, 2, 2) randn(elt, 3, 3)
4026
d = eig_vals(a)
41-
@test d diag(eig_full(a)[1])
27+
d′ = diag(eig_full(a)[1])
28+
@test arg1(d) arg1(d′)
29+
@test arg2(d) arg2(d′)
4230

4331
a = herm(randn(elt, 2, 2)) herm(randn(elt, 3, 3))
4432
d, v = eigh_full(a)
45-
@test a * v v * d
33+
av = a * v
34+
vd = v * d
35+
@test arg1(av) arg1(vd)
36+
@test arg2(av) arg2(vd)
4637
@test eltype(d) === real(elt)
4738
@test eltype(v) === elt
4839

@@ -56,22 +47,30 @@ herm(a) = parent(hermitianpart(a))
5647

5748
a = randn(elt, 2, 2) randn(elt, 3, 3)
5849
u, c = qr_compact(a)
59-
@test u * c a
50+
uc = u * c
51+
@test arg1(uc) arg1(a)
52+
@test arg2(uc) arg2(a)
6053
@test collect(u'u) I
6154

6255
a = randn(elt, 2, 2) randn(elt, 3, 3)
6356
u, c = qr_full(a)
64-
@test u * c a
57+
uc = u * c
58+
@test arg1(uc) arg1(a)
59+
@test arg2(uc) arg2(a)
6560
@test collect(u'u) I
6661

6762
a = randn(elt, 2, 2) randn(elt, 3, 3)
6863
c, u = lq_compact(a)
69-
@test c * u a
64+
cu = c * u
65+
@test arg1(cu) arg1(a)
66+
@test arg2(cu) arg2(a)
7067
@test collect(u * u') I
7168

7269
a = randn(elt, 2, 2) randn(elt, 3, 3)
7370
c, u = lq_full(a)
74-
@test c * u a
71+
cu = c * u
72+
@test arg1(cu) arg1(a)
73+
@test arg2(cu) arg2(a)
7574
@test collect(u * u') I
7675

7776
a = randn(elt, 3, 2) randn(elt, 4, 3)
@@ -84,27 +83,37 @@ herm(a) = parent(hermitianpart(a))
8483

8584
a = randn(elt, 2, 2) randn(elt, 3, 3)
8685
u, c = left_orth(a)
87-
@test u * c a
86+
uc = u * c
87+
@test arg1(uc) arg1(a)
88+
@test arg2(uc) arg2(a)
8889
@test collect(u'u) I
8990

9091
a = randn(elt, 2, 2) randn(elt, 3, 3)
9192
c, u = right_orth(a)
92-
@test c * u a
93+
cu = c * u
94+
@test arg1(cu) arg1(a)
95+
@test arg2(cu) arg2(a)
9396
@test collect(u * u') I
9497

9598
a = randn(elt, 2, 2) randn(elt, 3, 3)
9699
u, c = left_polar(a)
97-
@test u * c a
100+
uc = u * c
101+
@test arg1(uc) arg1(a)
102+
@test arg2(uc) arg2(a)
98103
@test collect(u'u) I
99104

100105
a = randn(elt, 2, 2) randn(elt, 3, 3)
101106
c, u = right_polar(a)
102-
@test c * u a
107+
cu = c * u
108+
@test arg1(cu) arg1(a)
109+
@test arg2(cu) arg2(a)
103110
@test collect(u * u') I
104111

105112
a = randn(elt, 2, 2) randn(elt, 3, 3)
106113
u, s, v = svd_compact(a)
107-
@test u * s * v a
114+
usv = u * s * v
115+
@test arg1(usv) arg1(a)
116+
@test arg2(usv) arg2(a)
108117
@test eltype(u) === elt
109118
@test eltype(s) === real(elt)
110119
@test eltype(v) === elt
@@ -113,7 +122,9 @@ herm(a) = parent(hermitianpart(a))
113122

114123
a = randn(elt, 2, 2) randn(elt, 3, 3)
115124
u, s, v = svd_full(a)
116-
@test u * s * v a
125+
usv = u * s * v
126+
@test arg1(usv) arg1(a)
127+
@test arg2(usv) arg2(a)
117128
@test eltype(u) === elt
118129
@test eltype(s) === real(elt)
119130
@test eltype(v) === elt
@@ -125,5 +136,7 @@ herm(a) = parent(hermitianpart(a))
125136

126137
a = randn(elt, 2, 2) randn(elt, 3, 3)
127138
s = svd_vals(a)
128-
@test s diag(svd_compact(a)[2])
139+
s′ = diag(svd_compact(a)[2])
140+
@test arg1(s) arg1(s′)
141+
@test arg2(s) arg2(s′)
129142
end

0 commit comments

Comments
 (0)