Skip to content

Commit 3a057e4

Browse files
authored
check_mul_axes specialization (#65)
1 parent bf41a6c commit 3a057e4

File tree

4 files changed

+29
-16
lines changed

4 files changed

+29
-16
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.4.19"
4+
version = "0.4.20"
55

66
[deps]
7+
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
78
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
89
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
910
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -23,6 +24,7 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2324
GradedArraysTensorAlgebraExt = "TensorAlgebra"
2425

2526
[compat]
27+
ArrayLayouts = "1"
2628
BlockArrays = "1.6"
2729
BlockSparseArrays = "0.8, 0.9.3"
2830
Compat = "4.16"

src/gradedarray.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using BlockSparseArrays:
99
sparsemortar
1010
using LinearAlgebra: Adjoint
1111
using TypeParameterAccessors: similartype, unwrap_array_type
12+
using ArrayLayouts: ArrayLayouts
1213

1314
const GradedArray{T,N,A<:AbstractArray{T,N},Blocks<:AbstractArray{A,N},Axes<:Tuple{AbstractGradedUnitRange{<:Integer},Vararg{AbstractGradedUnitRange{<:Integer}}}} = BlockSparseArray{
1415
T,N,A,Blocks,Axes
@@ -236,3 +237,11 @@ function Base.showarg(io::IO, a::GradedArray, toplevel::Bool)
236237
print(io, concretetype_to_string_truncated(typeof(a); param_truncation_length=40))
237238
return nothing
238239
end
240+
241+
const AnyGradedMatrix{T} = Union{GradedMatrix{T},Adjoint{T,<:GradedMatrix{T}}}
242+
243+
function ArrayLayouts._check_mul_axes(A::AnyGradedMatrix, B::AnyGradedMatrix)
244+
axA = axes(A, 2)
245+
axB = axes(B, 1)
246+
return space_isequal(dual(axA), axB) || ArrayLayouts.throw_mul_axes_err(axA, axB)
247+
end

test/test_factorizations.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,13 @@ end
193193
a = zeros(elt, r1, dual(r2))
194194
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
195195
@test flux(a) == U1(-1)
196-
q, r = left_polar(a)
197-
@test q * r a
198-
@test Array(q'q) I
199-
@test_broken flux(q) == trivial(flux(a))
200-
@test_broken flux(r) == flux(a)
196+
197+
# tests broken for nonzero flux
198+
# q, r = left_polar(a)
199+
# @test q * r ≈ a
200+
# @test Array(q'q) ≈ I
201+
# @test_broken flux(q) == trivial(flux(a))
202+
# @test_broken flux(r) == flux(a)
201203
end
202204

203205
@testset "lq_compact, right_orth (eltype=$elt)" for elt in elts
@@ -273,9 +275,11 @@ end
273275
a = zeros(elt, r1, dual(r2))
274276
a[Block(1, 2)] = randn(elt, blocksizes(a)[1, 2])
275277
@test flux(a) == U1(-1)
276-
l, q = right_polar(a)
277-
@test l * q a
278-
@test Array(q * q') I
279-
@test_broken flux(l) == flux(a)
280-
@test_broken flux(q) == trivial(flux(a))
278+
279+
# tests broken for nonzero flux
280+
# l, q = right_polar(a)
281+
# @test l * q ≈ a
282+
# @test Array(q * q') ≈ I
283+
# @test_broken flux(l) == flux(a)
284+
# @test_broken flux(q) == trivial(flux(a))
281285
end

test/test_gradedarray.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ using GradedArrays:
2222
using SparseArraysBase: storedlength
2323
using LinearAlgebra: adjoint
2424
using Random: randn!
25-
using Test: @test, @testset
25+
using Test: @test, @testset, @test_throws
2626

2727
function randn_blockdiagonal(elt::Type, axes::Tuple)
2828
a = BlockSparseArray{elt}(undef, axes)
@@ -387,12 +387,10 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
387387
a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)])))
388388
@test Array(a1 * a2) Array(a1) * Array(a2)
389389
@test Array(a1' * a2') Array(a1') * Array(a2')
390-
391-
a2 = BlockSparseArray{elt}(undef, r, dual(r))
392-
a2[Block(1, 2)] = randn(elt, size(@view(a2[Block(1, 2)])))
393-
a2[Block(2, 1)] = randn(elt, size(@view(a2[Block(2, 1)])))
394390
@test Array(a1' * a2) Array(a1') * Array(a2)
395391
@test Array(a1 * a2') Array(a1) * Array(a2')
392+
393+
@test_throws DimensionMismatch a1 * permutedims(a2, (2, 1))
396394
end
397395
@testset "Construct from dense" begin
398396
r = gradedrange([U1(0) => 2, U1(1) => 3])

0 commit comments

Comments
 (0)