Skip to content

Commit 2c2d41e

Browse files
authored
Update to BlockSparseArrays v0.9 (#36)
1 parent 59c6c60 commit 2c2d41e

File tree

10 files changed

+90
-57
lines changed

10 files changed

+90
-57
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.26"
4+
version = "0.1.27"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -16,19 +16,22 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1616
[weakdeps]
1717
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1818
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
19+
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
1920

2021
[extensions]
2122
KroneckerArraysBlockSparseArraysExt = ["BlockArrays", "BlockSparseArrays"]
23+
KroneckerArraysTensorProductsExt = "TensorProducts"
2224

2325
[compat]
2426
Adapt = "4.3"
2527
BlockArrays = "1.6"
26-
BlockSparseArrays = "0.8.1"
28+
BlockSparseArrays = "0.9"
2729
DerivableInterfaces = "0.5"
2830
DiagonalArrays = "0.3.5"
2931
FillArrays = "1.13"
3032
GPUArraysCore = "0.2"
3133
LinearAlgebra = "1.10"
3234
MapBroadcast = "0.1.9"
3335
MatrixAlgebraKit = "0.2"
36+
TensorProducts = "0.1.7"
3437
julia = "1.10"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module KroneckerArraysTensorProductsExt
2+
3+
using KroneckerArrays: CartesianProductOneTo, ×, arg1, arg2, cartesianrange, unproduct
4+
using TensorProducts: TensorProducts, tensor_product
5+
function TensorProducts.tensor_product(a1::CartesianProductOneTo, a2::CartesianProductOneTo)
6+
prod = tensor_product(arg1(a1), arg1(a2)) × tensor_product(arg2(a1), arg2(a2))
7+
range = tensor_product(unproduct(a1), unproduct(a2))
8+
return cartesianrange(prod, range)
9+
end
10+
11+
end

src/cartesianproduct.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ unproduct(r::CartesianProductVector) = getfield(r, :values)
6262
Base.length(a::CartesianProductVector) = length(unproduct(a))
6363
Base.size(a::CartesianProductVector) = (length(a),)
6464
function Base.axes(r::CartesianProductVector)
65-
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
65+
prod = cartesianproduct(r)
66+
prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod)))
67+
return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),)
6668
end
6769
function Base.copy(a::CartesianProductVector)
6870
return CartesianProductVector(copy(cartesianproduct(a)), copy(unproduct(a)))

src/fillarrays/kroneckerarray.jl

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,11 @@ function _convert(::Type{AbstractMatrix{T}}, a::RectDiagonal) where {T}
4141
RectDiagonal(convert(AbstractVector{T}, _diagview(a)), axes(a))
4242
end
4343

44-
# Like `similar` but preserves `Eye`.
45-
function _similar(a::AbstractArray, elt::Type, ax::Tuple)
46-
return similar(a, elt, ax)
44+
# Like `similar` but preserves `Eye`, `Ones`, etc.
45+
using FillArrays: Ones
46+
function _similar(arrayt::Type{<:Ones}, axs::Tuple)
47+
return Ones{eltype(arrayt)}(axs)
4748
end
48-
function _similar(A::Type{<:AbstractArray}, ax::Tuple)
49-
return similar(A, ax)
50-
end
51-
function _similar(a::AbstractArray, ax::Tuple)
52-
return _similar(a, eltype(a), ax)
53-
end
54-
function _similar(a::AbstractArray, elt::Type)
55-
return _similar(a, elt, axes(a))
56-
end
57-
function _similar(a::AbstractArray)
58-
return _similar(a, eltype(a), axes(a))
59-
end
60-
61-
# Like `similar` but preserves `Eye`.
6249
function _similar(a::Eye, elt::Type, axs::NTuple{2,AbstractUnitRange})
6350
return Eye{elt}(axs)
6451
end
@@ -77,19 +64,6 @@ end
7764
# Like `copy` but preserves `Eye`.
7865
_copy(a::Eye) = a
7966

80-
using DerivableInterfaces: DerivableInterfaces, zero!
81-
function DerivableInterfaces.zero!(a::EyeKronecker)
82-
zero!(a.b)
83-
return a
84-
end
85-
function DerivableInterfaces.zero!(a::KroneckerEye)
86-
zero!(a.a)
87-
return a
88-
end
89-
function DerivableInterfaces.zero!(a::EyeEye)
90-
return throw(ArgumentError("Can't zero out `Eye ⊗ Eye`."))
91-
end
92-
9367
using Base.Broadcast:
9468
AbstractArrayStyle, AbstractArrayStyle, BroadcastStyle, Broadcasted, broadcasted
9569

src/kroneckerarray.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,19 @@ function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where
5454
end
5555

5656
# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.
57-
function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}})
57+
function _similar(a::AbstractArray, elt::Type, axs::Tuple)
5858
return similar(a, elt, axs)
5959
end
60-
function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple{Vararg{AbstractUnitRange}})
60+
function _similar(a::AbstractArray, ax::Tuple)
61+
return _similar(a, eltype(a), ax)
62+
end
63+
function _similar(a::AbstractArray, elt::Type)
64+
return _similar(a, elt, axes(a))
65+
end
66+
function _similar(a::AbstractArray)
67+
return _similar(a, eltype(a), axes(a))
68+
end
69+
function _similar(arrayt::Type{<:AbstractArray}, axs::Tuple)
6170
return similar(arrayt, axs)
6271
end
6372

@@ -130,6 +139,16 @@ Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
130139

131140
Base.zero(a::KroneckerArray) = zero(arg1(a)) zero(arg2(a))
132141

142+
using DerivableInterfaces: DerivableInterfaces, zero!
143+
function DerivableInterfaces.zero!(a::KroneckerArray)
144+
ismut1 = ismutable(arg1(a))
145+
ismut2 = ismutable(arg2(a))
146+
(ismut1 || ismut2) || throw(ArgumentError("Can't zero out immutable KroneckerArray."))
147+
ismut1 && zero!(arg1(a))
148+
ismut2 && zero!(arg2(a))
149+
return a
150+
end
151+
133152
function Base.Array{T,N}(a::KroneckerArray{S,N}) where {T,S,N}
134153
return convert(Array{T,N}, collect(a))
135154
end
@@ -372,13 +391,15 @@ _eltype(x) = eltype(x)
372391
_eltype(x::Broadcasted) = Base.promote_op(x.f, _eltype.(x.args)...)
373392

374393
using Base.Broadcast: broadcasted
375-
struct KroneckerBroadcasted{A<:Broadcasted,B<:Broadcasted}
394+
struct KroneckerBroadcasted{A,B}
376395
a::A
377396
b::B
378397
end
379398
arg1(a::KroneckerBroadcasted) = a.a
380399
arg2(a::KroneckerBroadcasted) = a.b
381400
(a::Broadcasted, b::Broadcasted) = KroneckerBroadcasted(a, b)
401+
(a::Broadcasted, b) = KroneckerBroadcasted(a, b)
402+
(a, b::Broadcasted) = KroneckerBroadcasted(a, b)
382403
Broadcast.materialize(a::KroneckerBroadcasted) = copy(a)
383404
Broadcast.materialize!(dest, a::KroneckerBroadcasted) = copyto!(dest, a)
384405
Broadcast.broadcastable(a::KroneckerBroadcasted) = a

src/linearalgebra.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,3 @@ function LinearAlgebra.lq(a::KroneckerArray)
179179
Fb = lq(a.b)
180180
return KroneckerLQ(Fa.L Fb.L, Fa.Q Fb.Q)
181181
end
182-
183-
using DerivableInterfaces: DerivableInterfaces, zero!
184-
function DerivableInterfaces.zero!(a::KroneckerArray)
185-
zero!(a.a)
186-
zero!(a.b)
187-
return a
188-
end

test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
1414
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1515
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1616
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
17+
TensorProducts = "decf83d6-1968-43f4-96dc-fdb3fe15fc6d"
1718
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1819
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
1920

2021
[compat]
2122
Adapt = "4"
2223
Aqua = "0.8"
2324
BlockArrays = "1.6"
24-
BlockSparseArrays = "0.8.1"
25+
BlockSparseArrays = "0.9"
2526
DerivableInterfaces = "0.5"
2627
DiagonalArrays = "0.3.7"
2728
FillArrays = "1"
@@ -33,5 +34,6 @@ MatrixAlgebraKit = "0.2"
3334
SafeTestsets = "0.1"
3435
StableRNGs = "1.0"
3536
Suppressor = "0.2"
37+
TensorProducts = "0.1.7"
3638
Test = "1.10"
3739
TestExtras = "0.3"

test/test_basics.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using KroneckerArrays:
99
KroneckerArray,
1010
KroneckerStyle,
1111
CartesianProductUnitRange,
12+
CartesianProductVector,
1213
,
1314
×,
1415
arg1,
@@ -45,6 +46,14 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
4546
@test r[2 × 2] == 5
4647
@test r[2 × 3] == 6
4748

49+
# CartesianProductUnitRange axes
50+
r = cartesianrange((2:3) × (3:4), 2:5)
51+
@test axes(r) (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),)
52+
53+
# CartesianProductVector axes
54+
r = CartesianProductVector(([2, 4]) × ([3, 5]), [3, 5, 7, 9])
55+
@test axes(r) (CartesianProductUnitRange(Base.OneTo(2) × Base.OneTo(2), Base.OneTo(4)),)
56+
4857
r = @constinferred(cartesianrange(2 × 3, 2:7))
4958
@test r === cartesianrange(Base.OneTo(2) × Base.OneTo(3), 2:7)
5059
@test cartesianproduct(r) === Base.OneTo(2) × Base.OneTo(3)

test/test_blocksparsearrays.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ arrayts = (Array, JLArray)
2323
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
2424
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
2525
)
26-
a = dev(blocksparse(d, r, r))
26+
a = dev(blocksparse(d, (r, r)))
2727
@test sprint(show, a) isa String
2828
@test sprint(show, MIME("text/plain"), a) isa String
2929
@test blocktype(a) === valtype(d)
@@ -45,7 +45,7 @@ arrayts = (Array, JLArray)
4545
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
4646
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
4747
)
48-
a = dev(blocksparse(d, r, r))
48+
a = dev(blocksparse(d, (r, r)))
4949
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
5050
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
5151
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@@ -68,7 +68,7 @@ arrayts = (Array, JLArray)
6868
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
6969
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
7070
)
71-
a = dev(blocksparse(d, r, r))
71+
a = dev(blocksparse(d, (r, r)))
7272
i1 = Block(1)[(1:2) × (1:2)]
7373
i2 = Block(2)[(2:3) × (2:3)]
7474
I = mortar([i1, i2])
@@ -83,7 +83,7 @@ arrayts = (Array, JLArray)
8383
Block(1, 1) => dev(randn(elt, 2, 2) randn(elt, 2, 2)),
8484
Block(2, 2) => dev(randn(elt, 3, 3) randn(elt, 3, 3)),
8585
)
86-
a = dev(blocksparse(d, r, r))
86+
a = dev(blocksparse(d, (r, r)))
8787
i1 = Block(1)[(1:2) × (1:2)]
8888
i2 = Block(2)[(2:3) × (2:3)]
8989
I = [i1, i2]
@@ -130,9 +130,12 @@ arrayts = (Array, JLArray)
130130
@test_broken svd_compact(a)
131131
end
132132

133+
b = a[Block.(1:2), Block(2)]
134+
@test b[Block(1)] == a[Block(1, 2)]
135+
@test b[Block(2)] == a[Block(2, 2)]
136+
133137
# Broken operations
134138
@test_broken exp(a)
135-
@test_broken a[Block.(1:2), Block(2)]
136139
end
137140

138141
@testset "BlockSparseArraysExt, EyeKronecker blocks (arraytype=$arrayt, eltype=$elt)" for arrayt in
@@ -145,7 +148,7 @@ end
145148
Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)),
146149
Block(2, 2) => Eye{elt}(3, 3) dev(randn(elt, 3, 3)),
147150
)
148-
a = @constinferred dev(blocksparse(d, r, r))
151+
a = @constinferred dev(blocksparse(d, (r, r)))
149152
@test sprint(show, a) == sprint(show, Array(a))
150153
@test sprint(show, MIME("text/plain"), a) isa String
151154
@test @constinferred(blocktype(a)) === valtype(d)
@@ -167,7 +170,7 @@ end
167170
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
168171
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
169172
)
170-
a = dev(blocksparse(d, r, r))
173+
a = dev(blocksparse(d, (r, r)))
171174
@test a[Block(2, 2)[(2:3) × (2:3), (2:3) × (2:3)]] ==
172175
a[Block(2, 2)][(2:3) × (2:3), (2:3) × (2:3)]
173176
@test a[Block(2, 2)[(:) × (2:3), (:) × (2:3)]] == a[Block(2, 2)][(:) × (2:3), (:) × (2:3)]
@@ -194,7 +197,7 @@ end
194197
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
195198
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
196199
)
197-
a = dev(blocksparse(d, r, r))
200+
a = dev(blocksparse(d, (r, r)))
198201
i1 = Block(1)[(1:2) × (1:2)]
199202
i2 = Block(2)[(2:3) × (2:3)]
200203
I = mortar([i1, i2])
@@ -209,7 +212,7 @@ end
209212
Block(1, 1) => dev(Eye{elt}(2, 2) randn(elt, 2, 2)),
210213
Block(2, 2) => dev(Eye{elt}(3, 3) randn(elt, 3, 3)),
211214
)
212-
a = dev(blocksparse(d, r, r))
215+
a = dev(blocksparse(d, (r, r)))
213216
i1 = Block(1)[(1:2) × (1:2)]
214217
i2 = Block(2)[(2:3) × (2:3)]
215218
I = [i1, i2]
@@ -272,7 +275,9 @@ end
272275
end
273276

274277
# Broken operations
275-
@test_broken a[Block.(1:2), Block(2)]
278+
b = a[Block.(1:2), Block(2)]
279+
@test b[Block(1)] == a[Block(1, 2)]
280+
@test b[Block(2)] == a[Block(2, 2)]
276281

277282
# svd_trunc
278283
dev = adapt(arrayt)
@@ -282,7 +287,7 @@ end
282287
Block(1, 1) => Eye{elt}(2, 2) randn(rng, elt, 2, 2),
283288
Block(2, 2) => Eye{elt}(3, 3) randn(rng, elt, 3, 3),
284289
)
285-
a = @constinferred dev(blocksparse(d, r, r))
290+
a = @constinferred dev(blocksparse(d, (r, r)))
286291
if arrayt === Array
287292
u, s, v = svd_trunc(a; trunc=(; maxrank=6))
288293
u′, s′, v′ = svd_trunc(Matrix(a); trunc=(; maxrank=5))
@@ -293,10 +298,10 @@ end
293298

294299
@testset "Block deficient" begin
295300
da = Dict(Block(1, 1) => Eye{elt}(2, 2) dev(randn(elt, 2, 2)))
296-
a = @constinferred dev(blocksparse(da, r, r))
301+
a = @constinferred dev(blocksparse(da, (r, r)))
297302

298303
db = Dict(Block(2, 2) => Eye{elt}(3, 3) dev(randn(elt, 3, 3)))
299-
b = @constinferred dev(blocksparse(db, r, r))
304+
b = @constinferred dev(blocksparse(db, (r, r)))
300305

301306
@test Array(a + b) Array(a) + Array(b)
302307
@test Array(2a) 2Array(a)

test/test_tensorproducts.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using KroneckerArrays: ×, arg1, arg2, cartesianrange, unproduct
2+
using TensorProducts: tensor_product
3+
using Test: @test, @testset
4+
5+
@testset "KroneckerArraysTensorProductsExt" begin
6+
r1 = cartesianrange(2, 3)
7+
r2 = cartesianrange(4, 5)
8+
r = tensor_product(r1, r2)
9+
@test r cartesianrange(8, 15)
10+
@test arg1(r) Base.OneTo(8)
11+
@test arg2(r) Base.OneTo(15)
12+
@test unproduct(r) Base.OneTo(120)
13+
end

0 commit comments

Comments
 (0)