Skip to content

Commit 9f6111a

Browse files
authored
Introduce AbstractKroneckerArray (#54)
1 parent 06d9939 commit 9f6111a

File tree

5 files changed

+159
-183
lines changed

5 files changed

+159
-183
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.7"
4+
version = "0.2.8"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/KroneckerArraysTensorAlgebraExt/KroneckerArraysTensorAlgebraExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module KroneckerArraysTensorAlgebraExt
22

3-
using KroneckerArrays: KroneckerArrays, KroneckerArray, , arg1, arg2
3+
using KroneckerArrays: KroneckerArrays, AbstractKroneckerArray, , arg1, arg2
44
using TensorAlgebra:
55
TensorAlgebra, AbstractBlockPermutation, FusionStyle, matricize, unmatricize
66

@@ -10,7 +10,7 @@ struct KroneckerFusion{A <: FusionStyle, B <: FusionStyle} <: FusionStyle
1010
end
1111
KroneckerArrays.arg1(style::KroneckerFusion) = style.a
1212
KroneckerArrays.arg2(style::KroneckerFusion) = style.b
13-
function TensorAlgebra.FusionStyle(a::KroneckerArray)
13+
function TensorAlgebra.FusionStyle(a::AbstractKroneckerArray)
1414
return KroneckerFusion(FusionStyle(arg1(a)), FusionStyle(arg2(a)))
1515
end
1616
function matricize_kronecker(

src/kroneckerarray.jl

Lines changed: 90 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,35 @@
1+
"""
2+
abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end
3+
4+
Abstract supertype for arrays that have a kronecker product structure,
5+
i.e. that can be written as `AB = A ⊗ B`.
6+
"""
7+
abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end
8+
9+
const AbstractKroneckerVector{T} = AbstractKroneckerArray{T, 1}
10+
const AbstractKroneckerMatrix{T} = AbstractKroneckerArray{T, 2}
11+
12+
@doc """
13+
arg1(AB::AbstractKroneckerArray{T, N})
14+
15+
Extract the first factor (`A`) of the Kronecker array `AB = A ⊗ B`.
16+
""" arg1
17+
18+
@doc """
19+
arg2(AB::AbstractKroneckerArray{T, N})
20+
21+
Extract the second factor (`B`) of the Kronecker array `AB = A ⊗ B`.
22+
""" arg2
23+
24+
arg1type(x::AbstractKroneckerArray) = arg1type(typeof(x))
25+
arg1type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` subtypes have to implement `arg1type`.")
26+
arg2type(x::AbstractKroneckerArray) = arg2type(typeof(x))
27+
arg2type(::Type{<:AbstractKroneckerArray}) = error("`AbstractKroneckerArray` subtypes have to implement `arg2type`.")
28+
29+
arguments(a::AbstractKroneckerArray) = (arg1(a), arg2(a))
30+
arguments(a::AbstractKroneckerArray, n::Int) = arguments(a)[n]
31+
argument_types(a::AbstractKroneckerArray) = argument_types(typeof(a))
32+
133
function unwrap_array(a::AbstractArray)
234
p = parent(a)
335
p a && return a
@@ -26,7 +58,7 @@ function _convert(A::Type{<:Diagonal}, a::AbstractMatrix)
2658
end
2759

2860
struct KroneckerArray{T, N, A1 <: AbstractArray{T, N}, A2 <: AbstractArray{T, N}} <:
29-
AbstractArray{T, N}
61+
AbstractKroneckerArray{T, N}
3062
arg1::A1
3163
arg2::A2
3264
end
@@ -48,6 +80,10 @@ const KroneckerVector{T, A1 <: AbstractVector{T}, A2 <: AbstractVector{T}} = Kro
4880

4981
@inline arg1(a::KroneckerArray) = getfield(a, :arg1)
5082
@inline arg2(a::KroneckerArray) = getfield(a, :arg2)
83+
arg1type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A1
84+
arg2type(::Type{KroneckerArray{T, N, A1, A2}}) where {T, N, A1, A2} = A2
85+
86+
argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2)
5187

5288
function mutate_active_args!(f!, f, dest, src)
5389
(isactive(arg1(dest)) || isactive(arg2(dest))) ||
@@ -66,7 +102,7 @@ function mutate_active_args!(f!, f, dest, src)
66102
end
67103

68104
using Adapt: Adapt, adapt
69-
function Adapt.adapt_structure(to, a::KroneckerArray)
105+
function Adapt.adapt_structure(to, a::AbstractKroneckerArray)
70106
# TODO: Is this a good definition? It is similar to
71107
# the definition of `similar`.
72108
return if isactive(arg1(a)) == isactive(arg2(a))
@@ -78,18 +114,22 @@ function Adapt.adapt_structure(to, a::KroneckerArray)
78114
end
79115
end
80116

81-
function Base.copy(a::KroneckerArray)
82-
return copy(arg1(a)) copy(arg2(a))
117+
Base.copy(a::AbstractKroneckerArray) = copy(arg1(a)) copy(arg2(a))
118+
function Base.copy!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray)
119+
return mutate_active_args!(copy!, copy, dest, src)
83120
end
84121

122+
# TODO: copyto! is typically reserved for contiguous copies (i.e. also for copying from a
123+
# vector into an array), it might be better to not define that here.
85124
function Base.copyto!(dest::KroneckerArray{<:Any, N}, src::KroneckerArray{<:Any, N}) where {N}
86125
return mutate_active_args!(copyto!, copy, dest, src)
87126
end
88127

89128
function Base.convert(
90-
::Type{KroneckerArray{T, N, A1, A2}}, a::KroneckerArray
91-
) where {T, N, A1, A2}
92-
return _convert(A1, arg1(a)) _convert(A2, arg2(a))
129+
::Type{KroneckerArray{T, N, A1, A2}}, a::AbstractKroneckerArray
130+
)::KroneckerArray{T, N, A1, A2} where {T, N, A1, A2}
131+
typeof(a) === KroneckerArray{T, N, A1, A2} && return a
132+
return KroneckerArray(_convert(A1, arg1(a)), _convert(A2, arg2(a)))
93133
end
94134

95135
# Promote the element type if needed.
@@ -98,7 +138,7 @@ end
98138
maybe_promot_eltype(a, elt) = eltype(a) <: elt ? a : elt.(a)
99139

100140
function Base.similar(
101-
a::KroneckerArray,
141+
a::AbstractKroneckerArray,
102142
elt::Type,
103143
axs::Tuple{
104144
CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}},
@@ -115,7 +155,7 @@ function Base.similar(
115155
maybe_promot_eltype(arg1(a), elt) similar(arg2(a), elt, arg2.(axs))
116156
end
117157
end
118-
function Base.similar(a::KroneckerArray, elt::Type)
158+
function Base.similar(a::AbstractKroneckerArray, elt::Type)
119159
# TODO: Is this a good definition?
120160
return if isactive(arg1(a)) == isactive(arg2(a))
121161
similar(arg1(a), elt) similar(arg2(a), elt)
@@ -125,7 +165,7 @@ function Base.similar(a::KroneckerArray, elt::Type)
125165
maybe_promot_eltype(arg1(a), elt) similar(arg2(a), elt)
126166
end
127167
end
128-
function Base.similar(a::KroneckerArray)
168+
function Base.similar(a::AbstractKroneckerArray)
129169
# TODO: Is this a good definition?
130170
return if isactive(arg1(a)) == isactive(arg2(a))
131171
similar(arg1(a)) similar(arg2(a))
@@ -147,16 +187,18 @@ function Base.similar(
147187
end
148188

149189
function Base.similar(
150-
arrayt::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}},
190+
::Type{ArrayT},
151191
axs::Tuple{
152192
CartesianProductUnitRange{<:Integer}, Vararg{CartesianProductUnitRange{<:Integer}},
153193
},
154-
) where {A1, A2}
194+
) where {ArrayT <: AbstractKroneckerArray}
195+
A1, A2 = arg1type(ArrayT), arg2type(ArrayT)
155196
return similar(A1, map(arg1, axs)) similar(A2, map(arg2, axs))
156197
end
157198
function Base.similar(
158-
::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}, sz::Tuple{Int, Vararg{Int}}
159-
) where {A1, A2}
199+
::Type{ArrayT}, sz::Tuple{Int, Vararg{Int}}
200+
) where {ArrayT <: AbstractKroneckerArray}
201+
A1, A2 = arg1type(ArrayT), arg2type(ArrayT)
160202
return similar(promote_type(A1, A2), sz)
161203
end
162204

@@ -169,15 +211,15 @@ function Base.similar(
169211
return similar(arrayt, map(arg1, axs)) similar(arrayt, map(arg2, axs))
170212
end
171213

172-
function Base.permutedims(a::KroneckerArray, perm)
214+
function Base.permutedims(a::AbstractKroneckerArray, perm)
173215
return permutedims(arg1(a), perm) permutedims(arg2(a), perm)
174216
end
175217
using DerivableInterfaces: DerivableInterfaces, permuteddims
176-
function DerivableInterfaces.permuteddims(a::KroneckerArray, perm)
218+
function DerivableInterfaces.permuteddims(a::AbstractKroneckerArray, perm)
177219
return permuteddims(arg1(a), perm) permuteddims(arg2(a), perm)
178220
end
179221

180-
function Base.permutedims!(dest::KroneckerArray, src::KroneckerArray, perm)
222+
function Base.permutedims!(dest::AbstractKroneckerArray, src::AbstractKroneckerArray, perm)
181223
return mutate_active_args!(
182224
(dest, src) -> permutedims!(dest, src, perm), Base.Fix2(permutedims, perm), dest, src
183225
)
@@ -208,9 +250,10 @@ kron_nd(a1::AbstractMatrix, a2::AbstractMatrix) = kron(a1, a2)
208250
kron_nd(a1::AbstractVector, a2::AbstractVector) = kron(a1, a2)
209251

210252
# Eagerly collect arguments to make more general on GPU.
211-
Base.collect(a::KroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
253+
Base.collect(a::AbstractKroneckerArray) = kron_nd(collect(arg1(a)), collect(arg2(a)))
254+
Base.collect(T::Type, a::AbstractKroneckerArray) = kron_nd(collect(T, arg1(a)), collect(T, arg2(a)))
212255

213-
function Base.zero(a::KroneckerArray)
256+
function Base.zero(a::AbstractKroneckerArray)
214257
return if isactive(arg1(a)) == isactive(arg2(a))
215258
# TODO: Maybe this should zero both arguments?
216259
# This is how `a * false` would behave.
@@ -223,35 +266,28 @@ function Base.zero(a::KroneckerArray)
223266
end
224267

225268
using DerivableInterfaces: DerivableInterfaces, zero!
226-
function DerivableInterfaces.zero!(a::KroneckerArray)
269+
function DerivableInterfaces.zero!(a::AbstractKroneckerArray)
227270
(isactive(arg1(a)) || isactive(arg2(a))) ||
228271
error("Can't mutate immutable KroneckerArray.")
229272
isactive(arg1(a)) && zero!(arg1(a))
230273
isactive(arg2(a)) && zero!(arg2(a))
231274
return a
232275
end
233276

234-
function Base.Array{T, N}(a::KroneckerArray{S, N}) where {T, S, N}
235-
return convert(Array{T, N}, collect(a))
277+
function Base.Array{T, N}(a::AbstractKroneckerArray{S, N}) where {T, S, N}
278+
return convert(Array{T, N}, collect(T, a))
236279
end
237280

238-
function Base.size(a::KroneckerArray)
239-
return ntuple(dim -> size(arg1(a), dim) * size(arg2(a), dim), ndims(a))
240-
end
281+
Base.size(a::AbstractKroneckerArray) = size(arg1(a)) .* size(arg2(a))
241282

242-
function Base.axes(a::KroneckerArray)
283+
function Base.axes(a::AbstractKroneckerArray)
243284
return ntuple(ndims(a)) do dim
244285
return CartesianProductUnitRange(
245286
axes(arg1(a), dim) × axes(arg2(a), dim), Base.OneTo(size(a, dim))
246287
)
247288
end
248289
end
249290

250-
arguments(a::KroneckerArray) = (arg1(a), arg2(a))
251-
arguments(a::KroneckerArray, n::Int) = arguments(a)[n]
252-
argument_types(a::KroneckerArray) = argument_types(typeof(a))
253-
argument_types(::Type{<:KroneckerArray{<:Any, <:Any, A1, A2}}) where {A1, A2} = (A1, A2)
254-
255291
function Base.print_array(io::IO, a::KroneckerArray)
256292
Base.print_array(io, arg1(a))
257293
println(io, "\n")
@@ -285,45 +321,48 @@ end
285321

286322
# Indexing logic.
287323
function Base.to_indices(
288-
a::KroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
324+
a::AbstractKroneckerArray, inds, I::Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
289325
)
290326
I1 = to_indices(arg1(a), arg1.(inds), arg1.(I))
291327
I2 = to_indices(arg2(a), arg2.(inds), arg2.(I))
292328
return I1 I2
293329
end
294330

295331
function Base.getindex(
296-
a::KroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N}
332+
a::AbstractKroneckerArray{<:Any, N}, I::Vararg{Union{CartesianPair, CartesianProduct}, N}
297333
) where {N}
298334
I′ = to_indices(a, I)
299335
return arg1(a)[arg1.(I′)...] arg2(a)[arg2.(I′)...]
300336
end
301337
# Fix ambigiuity error.
302-
Base.getindex(a::KroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[]
338+
Base.getindex(a::AbstractKroneckerArray{<:Any, 0}) = arg1(a)[] * arg2(a)[]
303339

304340
arg1(::Colon) = (:)
305341
arg2(::Colon) = (:)
306342
arg1(::Base.Slice) = (:)
307343
arg2(::Base.Slice) = (:)
308344
function Base.view(
309-
a::KroneckerArray{<:Any, N},
345+
a::AbstractKroneckerArray{<:Any, N},
310346
I::Vararg{Union{CartesianProduct, CartesianProductUnitRange, Base.Slice, Colon}, N},
311347
) where {N}
312348
return view(arg1(a), arg1.(I)...) view(arg2(a), arg2.(I)...)
313349
end
314-
function Base.view(a::KroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N}
350+
function Base.view(a::AbstractKroneckerArray{<:Any, N}, I::Vararg{CartesianPair, N}) where {N}
315351
return view(arg1(a), arg1.(I)...) view(arg2(a), arg2.(I)...)
316352
end
317353
# Fix ambigiuity error.
318-
Base.view(a::KroneckerArray{<:Any, 0}) = view(arg1(a)) view(arg2(a))
354+
Base.view(a::AbstractKroneckerArray{<:Any, 0}) = view(arg1(a)) view(arg2(a))
319355

320-
function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
356+
function Base.:(==)(a::AbstractKroneckerArray, b::AbstractKroneckerArray)
321357
return arg1(a) == arg1(b) && arg2(a) == arg2(b)
322358
end
323-
function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...)
359+
360+
# TODO: this definition doesn't fully retain the original meaning:
361+
# ‖a - b‖ < atol could be true even if the following check isn't
362+
function Base.isapprox(a::AbstractKroneckerArray, b::AbstractKroneckerArray; kwargs...)
324363
return isapprox(arg1(a), arg1(b); kwargs...) && isapprox(arg2(a), arg2(b); kwargs...)
325364
end
326-
function Base.iszero(a::KroneckerArray)
365+
function Base.iszero(a::AbstractKroneckerArray)
327366
return iszero(arg1(a)) || iszero(arg2(a))
328367
end
329368
function Base.isreal(a::KroneckerArray)
@@ -335,17 +374,17 @@ function DiagonalArrays.diagonal(a::KroneckerArray)
335374
return diagonal(arg1(a)) diagonal(arg2(a))
336375
end
337376

338-
Base.real(a::KroneckerArray{<:Real}) = a
339-
function Base.real(a::KroneckerArray)
377+
Base.real(a::AbstractKroneckerArray{<:Real}) = a
378+
function Base.real(a::AbstractKroneckerArray)
340379
if iszero(imag(arg1(a))) || iszero(imag(arg2(a)))
341380
return real(arg1(a)) real(arg2(a))
342381
elseif iszero(real(arg1(a))) || iszero(real(arg2(a)))
343382
return -(imag(arg1(a)) imag(arg2(a)))
344383
end
345384
return real(arg1(a)) real(arg2(a)) - imag(arg1(a)) imag(arg2(a))
346385
end
347-
Base.imag(a::KroneckerArray{<:Real}) = zero(a)
348-
function Base.imag(a::KroneckerArray)
386+
Base.imag(a::AbstractKroneckerArray{<:Real}) = zero(a)
387+
function Base.imag(a::AbstractKroneckerArray)
349388
if iszero(imag(arg1(a))) || iszero(real(arg2(a)))
350389
return real(arg1(a)) imag(arg2(a))
351390
elseif iszero(real(arg1(a))) || iszero(imag(arg2(a)))
@@ -356,14 +395,14 @@ end
356395

357396
for f in [:transpose, :adjoint, :inv]
358397
@eval begin
359-
function Base.$f(a::KroneckerArray)
398+
function Base.$f(a::AbstractKroneckerArray)
360399
return $f(arg1(a)) $f(arg2(a))
361400
end
362401
end
363402
end
364403

365404
function Base.reshape(
366-
a::KroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
405+
a::AbstractKroneckerArray, ax::Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
367406
)
368407
return reshape(arg1(a), map(arg1, ax)) reshape(arg2(a), map(arg2, ax))
369408
end
@@ -383,8 +422,8 @@ end
383422
function KroneckerStyle{N, A1, A2}(v::Val{M}) where {N, A1, A2, M}
384423
return KroneckerStyle{M, typeof(A1)(v), typeof(A2)(v)}()
385424
end
386-
function Base.BroadcastStyle(::Type{<:KroneckerArray{<:Any, N, A1, A2}}) where {N, A1, A2}
387-
return KroneckerStyle{N}(BroadcastStyle(A1), BroadcastStyle(A2))
425+
function Base.BroadcastStyle(::Type{T}) where {T <: AbstractKroneckerArray}
426+
return KroneckerStyle{ndims(T)}(BroadcastStyle(arg1type(T)), BroadcastStyle(arg2type(T)))
388427
end
389428
function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N}) where {N}
390429
style_a = BroadcastStyle(arg1(style1), arg1(style2))
@@ -403,10 +442,10 @@ function Base.similar(
403442
return a b
404443
end
405444

406-
function Base.map(f, a1::KroneckerArray, a_rest::KroneckerArray...)
445+
function Base.map(f, a1::AbstractKroneckerArray, a_rest::AbstractKroneckerArray...)
407446
return Broadcast.broadcast_preserving_zero_d(f, a1, a_rest...)
408447
end
409-
function Base.map!(f, dest::KroneckerArray, a1::KroneckerArray, a_rest::KroneckerArray...)
448+
function Base.map!(f, dest::AbstractKroneckerArray, a1::AbstractKroneckerArray, a_rest::AbstractKroneckerArray...)
410449
dest .= f.(a1, a_rest...)
411450
return dest
412451
end
@@ -438,7 +477,7 @@ end
438477
function Base.copy(a::Summed{<:KroneckerStyle})
439478
return copy(KroneckerBroadcast(a))
440479
end
441-
function Base.copyto!(dest::KroneckerArray, a::Summed{<:KroneckerStyle})
480+
function Base.copyto!(dest::AbstractKroneckerArray, a::Summed{<:KroneckerStyle})
442481
return copyto!(dest, KroneckerBroadcast(a))
443482
end
444483

0 commit comments

Comments
 (0)