1+ """
2+ abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end
3+
4+ Abstract supertype for arrays that have a kronecker product structure,
5+ i.e. that can be written as `AB = A ⊗ B`.
6+ """
7+ abstract type AbstractKroneckerArray{T, N} <: AbstractArray{T, N} end
8+
9+ const AbstractKroneckerVector{T} = AbstractKroneckerArray{T, 1 }
10+ const AbstractKroneckerMatrix{T} = AbstractKroneckerArray{T, 2 }
11+
12+ @doc """
13+ arg1(AB::AbstractKroneckerArray{T, N})
14+
15+ Extract the first factor (`A`) of the Kronecker array `AB = A ⊗ B`.
16+ """ arg1
17+
18+ @doc """
19+ arg2(AB::AbstractKroneckerArray{T, N})
20+
21+ Extract the second factor (`B`) of the Kronecker array `AB = A ⊗ B`.
22+ """ arg2
23+
24+ arg1type (x:: AbstractKroneckerArray ) = arg1type (typeof (x))
25+ arg1type (:: Type{<:AbstractKroneckerArray} ) = error (" `AbstractKroneckerArray` subtypes have to implement `arg1type`." )
26+ arg2type (x:: AbstractKroneckerArray ) = arg2type (typeof (x))
27+ arg2type (:: Type{<:AbstractKroneckerArray} ) = error (" `AbstractKroneckerArray` subtypes have to implement `arg2type`." )
28+
29+ arguments (a:: AbstractKroneckerArray ) = (arg1 (a), arg2 (a))
30+ arguments (a:: AbstractKroneckerArray , n:: Int ) = arguments (a)[n]
31+ argument_types (a:: AbstractKroneckerArray ) = argument_types (typeof (a))
32+
133function unwrap_array (a:: AbstractArray )
234 p = parent (a)
335 p ≡ a && return a
@@ -26,7 +58,7 @@ function _convert(A::Type{<:Diagonal}, a::AbstractMatrix)
2658end
2759
2860struct KroneckerArray{T, N, A1 <: AbstractArray{T, N} , A2 <: AbstractArray{T, N} } < :
29- AbstractArray {T, N}
61+ AbstractKroneckerArray {T, N}
3062 arg1:: A1
3163 arg2:: A2
3264end
@@ -48,6 +80,10 @@ const KroneckerVector{T, A1 <: AbstractVector{T}, A2 <: AbstractVector{T}} = Kro
4880
4981@inline arg1 (a:: KroneckerArray ) = getfield (a, :arg1 )
5082@inline arg2 (a:: KroneckerArray ) = getfield (a, :arg2 )
83+ arg1type (:: Type{KroneckerArray{T, N, A1, A2}} ) where {T, N, A1, A2} = A1
84+ arg2type (:: Type{KroneckerArray{T, N, A1, A2}} ) where {T, N, A1, A2} = A2
85+
86+ argument_types (:: Type{<:KroneckerArray{<:Any, <:Any, A1, A2}} ) where {A1, A2} = (A1, A2)
5187
5288function mutate_active_args! (f!, f, dest, src)
5389 (isactive (arg1 (dest)) || isactive (arg2 (dest))) ||
@@ -66,7 +102,7 @@ function mutate_active_args!(f!, f, dest, src)
66102end
67103
68104using Adapt: Adapt, adapt
69- function Adapt. adapt_structure (to, a:: KroneckerArray )
105+ function Adapt. adapt_structure (to, a:: AbstractKroneckerArray )
70106 # TODO : Is this a good definition? It is similar to
71107 # the definition of `similar`.
72108 return if isactive (arg1 (a)) == isactive (arg2 (a))
@@ -78,18 +114,22 @@ function Adapt.adapt_structure(to, a::KroneckerArray)
78114 end
79115end
80116
81- function Base. copy (a:: KroneckerArray )
82- return copy (arg1 (a)) ⊗ copy (arg2 (a))
117+ Base. copy (a:: AbstractKroneckerArray ) = copy (arg1 (a)) ⊗ copy (arg2 (a))
118+ function Base. copy! (dest:: AbstractKroneckerArray , src:: AbstractKroneckerArray )
119+ return mutate_active_args! (copy!, copy, dest, src)
83120end
84121
122+ # TODO : copyto! is typically reserved for contiguous copies (i.e. also for copying from a
123+ # vector into an array), it might be better to not define that here.
85124function Base. copyto! (dest:: KroneckerArray{<:Any, N} , src:: KroneckerArray{<:Any, N} ) where {N}
86125 return mutate_active_args! (copyto!, copy, dest, src)
87126end
88127
89128function Base. convert (
90- :: Type{KroneckerArray{T, N, A1, A2}} , a:: KroneckerArray
91- ) where {T, N, A1, A2}
92- return _convert (A1, arg1 (a)) ⊗ _convert (A2, arg2 (a))
129+ :: Type{KroneckerArray{T, N, A1, A2}} , a:: AbstractKroneckerArray
130+ ):: KroneckerArray{T, N, A1, A2} where {T, N, A1, A2}
131+ typeof (a) === KroneckerArray{T, N, A1, A2} && return a
132+ return KroneckerArray (_convert (A1, arg1 (a)), _convert (A2, arg2 (a)))
93133end
94134
95135# Promote the element type if needed.
98138maybe_promot_eltype (a, elt) = eltype (a) <: elt ? a : elt .(a)
99139
100140function Base. similar (
101- a:: KroneckerArray ,
141+ a:: AbstractKroneckerArray ,
102142 elt:: Type ,
103143 axs:: Tuple {
104144 CartesianProductUnitRange{<: Integer }, Vararg{CartesianProductUnitRange{<: Integer }},
@@ -115,7 +155,7 @@ function Base.similar(
115155 maybe_promot_eltype (arg1 (a), elt) ⊗ similar (arg2 (a), elt, arg2 .(axs))
116156 end
117157end
118- function Base. similar (a:: KroneckerArray , elt:: Type )
158+ function Base. similar (a:: AbstractKroneckerArray , elt:: Type )
119159 # TODO : Is this a good definition?
120160 return if isactive (arg1 (a)) == isactive (arg2 (a))
121161 similar (arg1 (a), elt) ⊗ similar (arg2 (a), elt)
@@ -125,7 +165,7 @@ function Base.similar(a::KroneckerArray, elt::Type)
125165 maybe_promot_eltype (arg1 (a), elt) ⊗ similar (arg2 (a), elt)
126166 end
127167end
128- function Base. similar (a:: KroneckerArray )
168+ function Base. similar (a:: AbstractKroneckerArray )
129169 # TODO : Is this a good definition?
130170 return if isactive (arg1 (a)) == isactive (arg2 (a))
131171 similar (arg1 (a)) ⊗ similar (arg2 (a))
@@ -147,16 +187,18 @@ function Base.similar(
147187end
148188
149189function Base. similar (
150- arrayt :: Type{<:KroneckerArray{<:Any, <:Any, A1, A2} } ,
190+ :: Type{ArrayT } ,
151191 axs:: Tuple {
152192 CartesianProductUnitRange{<: Integer }, Vararg{CartesianProductUnitRange{<: Integer }},
153193 },
154- ) where {A1, A2}
194+ ) where {ArrayT <: AbstractKroneckerArray }
195+ A1, A2 = arg1type (ArrayT), arg2type (ArrayT)
155196 return similar (A1, map (arg1, axs)) ⊗ similar (A2, map (arg2, axs))
156197end
157198function Base. similar (
158- :: Type{<:KroneckerArray{<:Any, <:Any, A1, A2}} , sz:: Tuple{Int, Vararg{Int}}
159- ) where {A1, A2}
199+ :: Type{ArrayT} , sz:: Tuple{Int, Vararg{Int}}
200+ ) where {ArrayT <: AbstractKroneckerArray }
201+ A1, A2 = arg1type (ArrayT), arg2type (ArrayT)
160202 return similar (promote_type (A1, A2), sz)
161203end
162204
@@ -169,15 +211,15 @@ function Base.similar(
169211 return similar (arrayt, map (arg1, axs)) ⊗ similar (arrayt, map (arg2, axs))
170212end
171213
172- function Base. permutedims (a:: KroneckerArray , perm)
214+ function Base. permutedims (a:: AbstractKroneckerArray , perm)
173215 return permutedims (arg1 (a), perm) ⊗ permutedims (arg2 (a), perm)
174216end
175217using DerivableInterfaces: DerivableInterfaces, permuteddims
176- function DerivableInterfaces. permuteddims (a:: KroneckerArray , perm)
218+ function DerivableInterfaces. permuteddims (a:: AbstractKroneckerArray , perm)
177219 return permuteddims (arg1 (a), perm) ⊗ permuteddims (arg2 (a), perm)
178220end
179221
180- function Base. permutedims! (dest:: KroneckerArray , src:: KroneckerArray , perm)
222+ function Base. permutedims! (dest:: AbstractKroneckerArray , src:: AbstractKroneckerArray , perm)
181223 return mutate_active_args! (
182224 (dest, src) -> permutedims! (dest, src, perm), Base. Fix2 (permutedims, perm), dest, src
183225 )
@@ -208,9 +250,10 @@ kron_nd(a1::AbstractMatrix, a2::AbstractMatrix) = kron(a1, a2)
208250kron_nd (a1:: AbstractVector , a2:: AbstractVector ) = kron (a1, a2)
209251
210252# Eagerly collect arguments to make more general on GPU.
211- Base. collect (a:: KroneckerArray ) = kron_nd (collect (arg1 (a)), collect (arg2 (a)))
253+ Base. collect (a:: AbstractKroneckerArray ) = kron_nd (collect (arg1 (a)), collect (arg2 (a)))
254+ Base. collect (T:: Type , a:: AbstractKroneckerArray ) = kron_nd (collect (T, arg1 (a)), collect (T, arg2 (a)))
212255
213- function Base. zero (a:: KroneckerArray )
256+ function Base. zero (a:: AbstractKroneckerArray )
214257 return if isactive (arg1 (a)) == isactive (arg2 (a))
215258 # TODO : Maybe this should zero both arguments?
216259 # This is how `a * false` would behave.
@@ -223,35 +266,28 @@ function Base.zero(a::KroneckerArray)
223266end
224267
225268using DerivableInterfaces: DerivableInterfaces, zero!
226- function DerivableInterfaces. zero! (a:: KroneckerArray )
269+ function DerivableInterfaces. zero! (a:: AbstractKroneckerArray )
227270 (isactive (arg1 (a)) || isactive (arg2 (a))) ||
228271 error (" Can't mutate immutable KroneckerArray." )
229272 isactive (arg1 (a)) && zero! (arg1 (a))
230273 isactive (arg2 (a)) && zero! (arg2 (a))
231274 return a
232275end
233276
234- function Base. Array {T, N} (a:: KroneckerArray {S, N} ) where {T, S, N}
235- return convert (Array{T, N}, collect (a))
277+ function Base. Array {T, N} (a:: AbstractKroneckerArray {S, N} ) where {T, S, N}
278+ return convert (Array{T, N}, collect (T, a))
236279end
237280
238- function Base. size (a:: KroneckerArray )
239- return ntuple (dim -> size (arg1 (a), dim) * size (arg2 (a), dim), ndims (a))
240- end
281+ Base. size (a:: AbstractKroneckerArray ) = size (arg1 (a)) .* size (arg2 (a))
241282
242- function Base. axes (a:: KroneckerArray )
283+ function Base. axes (a:: AbstractKroneckerArray )
243284 return ntuple (ndims (a)) do dim
244285 return CartesianProductUnitRange (
245286 axes (arg1 (a), dim) × axes (arg2 (a), dim), Base. OneTo (size (a, dim))
246287 )
247288 end
248289end
249290
250- arguments (a:: KroneckerArray ) = (arg1 (a), arg2 (a))
251- arguments (a:: KroneckerArray , n:: Int ) = arguments (a)[n]
252- argument_types (a:: KroneckerArray ) = argument_types (typeof (a))
253- argument_types (:: Type{<:KroneckerArray{<:Any, <:Any, A1, A2}} ) where {A1, A2} = (A1, A2)
254-
255291function Base. print_array (io:: IO , a:: KroneckerArray )
256292 Base. print_array (io, arg1 (a))
257293 println (io, " \n ⊗" )
@@ -285,45 +321,48 @@ end
285321
286322# Indexing logic.
287323function Base. to_indices (
288- a:: KroneckerArray , inds, I:: Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
324+ a:: AbstractKroneckerArray , inds, I:: Tuple{Union{CartesianPair, CartesianProduct}, Vararg}
289325 )
290326 I1 = to_indices (arg1 (a), arg1 .(inds), arg1 .(I))
291327 I2 = to_indices (arg2 (a), arg2 .(inds), arg2 .(I))
292328 return I1 .× I2
293329end
294330
295331function Base. getindex (
296- a:: KroneckerArray {<:Any, N} , I:: Vararg{Union{CartesianPair, CartesianProduct}, N}
332+ a:: AbstractKroneckerArray {<:Any, N} , I:: Vararg{Union{CartesianPair, CartesianProduct}, N}
297333 ) where {N}
298334 I′ = to_indices (a, I)
299335 return arg1 (a)[arg1 .(I′)... ] ⊗ arg2 (a)[arg2 .(I′)... ]
300336end
301337# Fix ambigiuity error.
302- Base. getindex (a:: KroneckerArray {<:Any, 0} ) = arg1 (a)[] * arg2 (a)[]
338+ Base. getindex (a:: AbstractKroneckerArray {<:Any, 0} ) = arg1 (a)[] * arg2 (a)[]
303339
304340arg1 (:: Colon ) = (:)
305341arg2 (:: Colon ) = (:)
306342arg1 (:: Base.Slice ) = (:)
307343arg2 (:: Base.Slice ) = (:)
308344function Base. view (
309- a:: KroneckerArray {<:Any, N} ,
345+ a:: AbstractKroneckerArray {<:Any, N} ,
310346 I:: Vararg{Union{CartesianProduct, CartesianProductUnitRange, Base.Slice, Colon}, N} ,
311347 ) where {N}
312348 return view (arg1 (a), arg1 .(I)... ) ⊗ view (arg2 (a), arg2 .(I)... )
313349end
314- function Base. view (a:: KroneckerArray {<:Any, N} , I:: Vararg{CartesianPair, N} ) where {N}
350+ function Base. view (a:: AbstractKroneckerArray {<:Any, N} , I:: Vararg{CartesianPair, N} ) where {N}
315351 return view (arg1 (a), arg1 .(I)... ) ⊗ view (arg2 (a), arg2 .(I)... )
316352end
317353# Fix ambigiuity error.
318- Base. view (a:: KroneckerArray {<:Any, 0} ) = view (arg1 (a)) ⊗ view (arg2 (a))
354+ Base. view (a:: AbstractKroneckerArray {<:Any, 0} ) = view (arg1 (a)) ⊗ view (arg2 (a))
319355
320- function Base.:(== )(a:: KroneckerArray , b:: KroneckerArray )
356+ function Base.:(== )(a:: AbstractKroneckerArray , b:: AbstractKroneckerArray )
321357 return arg1 (a) == arg1 (b) && arg2 (a) == arg2 (b)
322358end
323- function Base. isapprox (a:: KroneckerArray , b:: KroneckerArray ; kwargs... )
359+
360+ # TODO : this definition doesn't fully retain the original meaning:
361+ # ‖a - b‖ < atol could be true even if the following check isn't
362+ function Base. isapprox (a:: AbstractKroneckerArray , b:: AbstractKroneckerArray ; kwargs... )
324363 return isapprox (arg1 (a), arg1 (b); kwargs... ) && isapprox (arg2 (a), arg2 (b); kwargs... )
325364end
326- function Base. iszero (a:: KroneckerArray )
365+ function Base. iszero (a:: AbstractKroneckerArray )
327366 return iszero (arg1 (a)) || iszero (arg2 (a))
328367end
329368function Base. isreal (a:: KroneckerArray )
@@ -335,17 +374,17 @@ function DiagonalArrays.diagonal(a::KroneckerArray)
335374 return diagonal (arg1 (a)) ⊗ diagonal (arg2 (a))
336375end
337376
338- Base. real (a:: KroneckerArray {<:Real} ) = a
339- function Base. real (a:: KroneckerArray )
377+ Base. real (a:: AbstractKroneckerArray {<:Real} ) = a
378+ function Base. real (a:: AbstractKroneckerArray )
340379 if iszero (imag (arg1 (a))) || iszero (imag (arg2 (a)))
341380 return real (arg1 (a)) ⊗ real (arg2 (a))
342381 elseif iszero (real (arg1 (a))) || iszero (real (arg2 (a)))
343382 return - (imag (arg1 (a)) ⊗ imag (arg2 (a)))
344383 end
345384 return real (arg1 (a)) ⊗ real (arg2 (a)) - imag (arg1 (a)) ⊗ imag (arg2 (a))
346385end
347- Base. imag (a:: KroneckerArray {<:Real} ) = zero (a)
348- function Base. imag (a:: KroneckerArray )
386+ Base. imag (a:: AbstractKroneckerArray {<:Real} ) = zero (a)
387+ function Base. imag (a:: AbstractKroneckerArray )
349388 if iszero (imag (arg1 (a))) || iszero (real (arg2 (a)))
350389 return real (arg1 (a)) ⊗ imag (arg2 (a))
351390 elseif iszero (real (arg1 (a))) || iszero (imag (arg2 (a)))
@@ -356,14 +395,14 @@ end
356395
357396for f in [:transpose , :adjoint , :inv ]
358397 @eval begin
359- function Base. $f (a:: KroneckerArray )
398+ function Base. $f (a:: AbstractKroneckerArray )
360399 return $ f (arg1 (a)) ⊗ $ f (arg2 (a))
361400 end
362401 end
363402end
364403
365404function Base. reshape (
366- a:: KroneckerArray , ax:: Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
405+ a:: AbstractKroneckerArray , ax:: Tuple{CartesianProductUnitRange, Vararg{CartesianProductUnitRange}}
367406 )
368407 return reshape (arg1 (a), map (arg1, ax)) ⊗ reshape (arg2 (a), map (arg2, ax))
369408end
383422function KroneckerStyle {N, A1, A2} (v:: Val{M} ) where {N, A1, A2, M}
384423 return KroneckerStyle {M, typeof(A1)(v), typeof(A2)(v)} ()
385424end
386- function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any, N, A1, A2}} ) where {N, A1, A2 }
387- return KroneckerStyle {N } (BroadcastStyle (A1) , BroadcastStyle (A2 ))
425+ function Base. BroadcastStyle (:: Type{T} ) where {T <: AbstractKroneckerArray }
426+ return KroneckerStyle {ndims(T) } (BroadcastStyle (arg1type (T)) , BroadcastStyle (arg2type (T) ))
388427end
389428function Base. BroadcastStyle (style1:: KroneckerStyle{N} , style2:: KroneckerStyle{N} ) where {N}
390429 style_a = BroadcastStyle (arg1 (style1), arg1 (style2))
@@ -403,10 +442,10 @@ function Base.similar(
403442 return a ⊗ b
404443end
405444
406- function Base. map (f, a1:: KroneckerArray , a_rest:: KroneckerArray ... )
445+ function Base. map (f, a1:: AbstractKroneckerArray , a_rest:: AbstractKroneckerArray ... )
407446 return Broadcast. broadcast_preserving_zero_d (f, a1, a_rest... )
408447end
409- function Base. map! (f, dest:: KroneckerArray , a1:: KroneckerArray , a_rest:: KroneckerArray ... )
448+ function Base. map! (f, dest:: AbstractKroneckerArray , a1:: AbstractKroneckerArray , a_rest:: AbstractKroneckerArray ... )
410449 dest .= f .(a1, a_rest... )
411450 return dest
412451end
438477function Base. copy (a:: Summed{<:KroneckerStyle} )
439478 return copy (KroneckerBroadcast (a))
440479end
441- function Base. copyto! (dest:: KroneckerArray , a:: Summed{<:KroneckerStyle} )
480+ function Base. copyto! (dest:: AbstractKroneckerArray , a:: Summed{<:KroneckerStyle} )
442481 return copyto! (dest, KroneckerBroadcast (a))
443482end
444483
0 commit comments