Skip to content

Commit 01f7178

Browse files
committed
Support cacheddata on BroadcastArray and ApplyArray
1 parent 6905338 commit 01f7178

File tree

6 files changed

+255
-12
lines changed

6 files changed

+255
-12
lines changed

src/cache.jl

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ cacheddata(A::Transpose) = transpose(cacheddata(parent(A)))
6262

6363
maybe_cacheddata(A::AbstractCachedArray) = cacheddata(A)
6464
maybe_cacheddata(A::SubArray{<:Any,N,<:AbstractCachedArray}) where N = cacheddata(A)
65+
maybe_cacheddata(A::Adjoint) = adjoint(maybe_cacheddata(parent(A)))
66+
maybe_cacheddata(A::Transpose) = transpose(maybe_cacheddata(parent(A)))
6567
maybe_cacheddata(A) = A # no-op
6668

6769
convert(::Type{AbstractArray{T}}, S::CachedArray{T}) where T = S
@@ -228,9 +230,13 @@ resizedata!(B::CachedArray, mn...) = resizedata!(MemoryLayout(B.data), MemoryLay
228230
resizedata!(B::AbstractCachedArray, mn...) = resizedata!(MemoryLayout(B.data), UnknownLayout(), B, mn...)
229231
resizedata!(A, mn...) = A # don't do anything
230232
function resizedata!(A::AdjOrTrans, m, n)
231-
m 0 || resizedata!(parent(A), n)
233+
resizedata!(parent(A), n, m)
232234
A
233235
end
236+
function resizedata!(A::AdjOrTrans{<:Any, <:AbstractVector}, m, n)
237+
resizedata!(parent(A), n)
238+
A
239+
end
234240

235241
function cache_filldata!(B, inds...)
236242
B.data[inds...] .= view(B.array,inds...)
@@ -611,4 +617,85 @@ CachedArray(data::AbstractMatrix, array::AbstractQ) = CachedArray(data, array, s
611617
CachedArray(data::AbstractMatrix{T}, array::AbstractQ{T}, datasize::NTuple{2,Int}) where T =
612618
CachedMatrix{T,typeof(data),typeof(array)}(data, array, datasize)
613619

614-
length(A::CachedMatrix{<:T,<:AbstractMatrix{T},<:AbstractQ{T}}) where T = prod(size(A.array))
620+
length(A::CachedMatrix{<:T,<:AbstractMatrix{T},<:AbstractQ{T}}) where T = prod(size(A.array))
621+
622+
##
623+
# Tuples of potential cached arrays
624+
# Assumes that the Tuples contain arrays of the same shape
625+
# Some of the complication here is in making sure that we can handle things like mixing vectors and matrices that have one column
626+
##
627+
@inline _tuple_wrap(x::Int) = (x,) # Some CachedArrays mistakenly use Int instead of Tuple for vector sizes (e.g. https://github.com/JuliaApproximation/SemiclassicalOrthogonalPolynomials.jl/blob/a29007f2815134180b8433fdab46a23acfcdcd01/src/neg1c.jl#L19 and https://github.com/DanielVandH/InfiniteRandomArrays.jl/blob/a859d565f5bd8278d3f3499cd26b62916b4867e0/src/vector.jl#L18, to name a couple)
628+
@inline _tuple_wrap(x::Tuple) = x
629+
@inline _datasize(x) = size(x)
630+
@inline _datasize(::Number) = (1,) # Numbers have size (), but we treat them as size (1,) for resizing purposes
631+
@inline _datasize(x::AbstractCachedArray) = x.datasize
632+
@inline _datasize(x::SubArray) = _datasize(parent(x))
633+
@inline _datasize(x::AdjOrTrans) = reverse(_datasize(parent(x)))
634+
@inline _datasize(x::AdjOrTrans{<:Any, <:AbstractVector}) = (sz = only(_datasize(parent(x))); (min(1, sz), sz))
635+
@inline _datasizes(::Tuple{}) = ()
636+
@inline _datasizes(t::Tuple) = (_datasize(first(t)), _datasizes(Base.tail(t))...) # equivalent to _datasize.(t)
637+
638+
@inline _has_vector(::Tuple{}) = false # Numbers have size ()
639+
@inline _has_vector(::Tuple{Any}) = true
640+
@inline _has_vector(::Tuple{Any, Vararg}) = false
641+
@inline has_vector(::Tuple{}) = false
642+
@inline has_vector(t::Tuple) = _has_vector(first(t)) || has_vector(Base.tail(t))
643+
@inline to_vector_size(sz::Tuple{Any}) = sz # Already 1-tuple
644+
@inline to_vector_size(sz::Tuple{Any, Vararg}) = (prod(sz),) # Multi-tuple -> 1-tuple
645+
646+
function max_datasize(sizes::Tuple)
647+
if has_vector(sizes)
648+
normalized = map(to_vector_size, sizes)
649+
return reduce(@inline((a, b) -> max.(a, b)), normalized)
650+
else
651+
return reduce(@inline((a, b) -> max.(a, b)), sizes)
652+
end
653+
end
654+
655+
@inline _expand_size(sz::Tuple{Any}, ::Tuple{Any}) = sz
656+
@inline _expand_size(sz::Tuple{Any}, ::Tuple{Any, Vararg}) = (only(sz), 1)
657+
@inline _expand_size(sz::Tuple{Any, Vararg}, ::Tuple{Any, Vararg}) = sz
658+
659+
@inline _effective_ndims(sz::Int) = 1
660+
@inline function _effective_ndims(sz::Tuple)
661+
length(sz) == 1 && return 1
662+
length(sz) == 2 && minimum(sz) <= 1 && return 1
663+
return length(sz)
664+
end
665+
666+
@inline _all_same_ndims(::Tuple{}) = true
667+
@inline _all_same_ndims(t::Tuple{Any}) = true
668+
@inline function _all_same_ndims(t::Tuple)
669+
first_sz = _arraysize(first(t))
670+
first_ndim = _effective_ndims(first_sz)
671+
_check_same_ndims_args(first_ndim, Base.tail(t))
672+
end
673+
674+
@inline _check_same_ndims_args(::Int, ::Tuple{}) = true
675+
@inline function _check_same_ndims_args(expected_ndim::Int, t::Tuple)
676+
sz = _arraysize(first(t))
677+
_effective_ndims(sz) == expected_ndim && _check_same_ndims_args(expected_ndim, Base.tail(t))
678+
end
679+
680+
@inline _arraysize(x::Number) = (1,)
681+
@inline _arraysize(x) = size(x)
682+
683+
function conforming_resize!(args::Tuple)
684+
isempty(args) && return args
685+
if !_all_same_ndims(args)
686+
throw(ArgumentError("Cannot conform arrays with incompatible dimensions: $(map(_arraysize, args))"))
687+
end
688+
sizes = _datasizes(args)
689+
sz = max_datasize(sizes)
690+
_resize_each!(args, sizes, sz)
691+
return args
692+
end
693+
694+
@inline _resize_each!(::Tuple{}, ::Tuple{}, sz) = nothing
695+
@inline function _resize_each!(args::Tuple, sizes::Tuple, sz)
696+
arg = first(args)
697+
orig_sz = first(sizes)
698+
target_sz = _expand_size(sz, orig_sz)
699+
resizedata!(arg, target_sz...)
700+
_resize_each!(Base.tail(args), Base.tail(sizes), sz)
701+
end

src/lazyapplying.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,11 @@ AbstractArray{T,N}(A::ApplyArray{<:Any,N}) where {T,N} = ApplyArray{T,N}(A.f, ma
228228

229229
@inline applied_axes(f, args...) = map(oneto, applied_size(f, args...))
230230

231-
231+
function cacheddata(A::ApplyArray{T, N, F}) where {T, N, F}
232+
args = arguments(A)
233+
conforming_resize!(args)
234+
ApplyArray{T, N}(A.f, map(maybe_cacheddata, args)...)
235+
end
232236

233237
# immutable arrays don't need to copy.
234238
# Some special cases like vcat overload setindex! and therefore
@@ -329,9 +333,6 @@ function show(io::IO, A::Applied)
329333
print(io, ')')
330334
end
331335

332-
# BroadcastStyle(::Type{<:LinearAlgebra.QRCompactWYQ}) = DefaultArrayStyle{2}()
333-
# BroadcastStyle(::Type{<:LinearAlgebra.AdjointQ}) = DefaultArrayStyle{2}()
334-
335336
applybroadcaststyle(::Type{<:AbstractArray{<:Any,N}}, _2) where N = DefaultArrayStyle{N}()
336337
applybroadcaststyle(::Type{<:AbstractArray{<:Any,N}}, ::AbstractLazyLayout) where N = LazyArrayStyle{N}()
337338
applybroadcaststyle(::Type{<:ApplyArray{<:Any,N,<:Any,Args}}, ::AbstractLazyLayout) where {N,Args<:Tuple} = result_style(LazyArrayStyle{N}(), tuple_type_broadcastlayout(Args))

src/lazybroadcasting.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ sub_materialize(::BroadcastLayout, A) = converteltype(eltype(A), sub_materialize
122122

123123
copy(bc::Broadcasted{<:AbstractLazyArrayStyle}) = BroadcastArray(bc)
124124

125+
function cacheddata(A::BroadcastArray{T, N, F}) where {T, N, F}
126+
args = arguments(A)
127+
conforming_resize!(args)
128+
BroadcastArray{T, N}(A.f, map(maybe_cacheddata, args)...)
129+
end
130+
125131
# BroadcastArray are immutable
126132
copy(bc::BroadcastArray) = bc
127133
map(::typeof(copy), bc::BroadcastArray) = bc

src/padded.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,6 @@ _vcat_resizedata!(_, B, m...) = B # by default we can't resize
149149

150150
resizedata!(B::Vcat, m...) = _vcat_resizedata!(MemoryLayout(B), B, m...)
151151

152-
cacheddata(B::Vcat) = Vcat(map(maybe_cacheddata, arguments(B))...)
153-
154152
function ==(A::CachedVector{<:Any,<:Any,<:Zeros}, B::CachedVector{<:Any,<:Any,<:Zeros})
155153
length(A) == length(B) || return false
156154
n = max(A.datasize[1], B.datasize[1])

test/cachetests.jl

Lines changed: 149 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,11 @@ using Infinities
552552
@test maybe_cacheddata(B) === cacheddata(B)
553553
C = [1, 2, 3]
554554
@test maybe_cacheddata(C) === C
555+
556+
v = cache(1:10)'
557+
@test maybe_cacheddata(v) === cacheddata(parent(v))'
558+
v = transpose(cache(1:10))
559+
@test maybe_cacheddata(v) === transpose(cacheddata(parent(v)))
555560
end
556561

557562
@testset "Missing BroadcastStyles/MemoryLayouts/cacheddata with CachedArrayStyles" begin
@@ -587,6 +592,149 @@ using Infinities
587592
@test BroadcastStyle(typeof(Hcat(d, (1:10)))) == CachedArrayStyle{2}()
588593
@test BroadcastStyle(typeof(Hcat((1:10), d))) == CachedArrayStyle{2}()
589594
end
590-
end
595+
596+
@testset "Enforce same-size arguments for cacheddata" begin
597+
@testset "max_datasize" begin
598+
@test LazyArrays._datasize(1:10) == (10,)
599+
x = cache(1:10)
600+
@test LazyArrays._datasize(x) == (0,)
601+
resizedata!(x, 3)
602+
@test LazyArrays._datasize(x) == (3,)
603+
@test LazyArrays._datasize(view(x, 3:5)) == (3,)
604+
@test LazyArrays._datasize(transpose(x)) == (1, 3)
605+
606+
arr = (1:10, cache(1:10))
607+
@test LazyArrays.max_datasize(LazyArrays._datasizes(arr)) == (10,)
608+
609+
arr = (cache(1:10), cache(1:10))
610+
resizedata!(arr[1], 3)
611+
@test LazyArrays.max_datasize(LazyArrays._datasizes(arr)) == (3,)
612+
resizedata!(arr[2], 7)
613+
@test LazyArrays.max_datasize(LazyArrays._datasizes(arr)) == (7,)
614+
615+
arr = (rand(10, 5), LazyArrays.CachedArray(rand(10, 5)))
616+
@test LazyArrays.max_datasize(LazyArrays._datasizes(arr)) == (10, 5)
617+
618+
arr = (LazyArrays.CachedArray(rand(10, 5)), LazyArrays.CachedArray(rand(10, 5)))
619+
@test LazyArrays.max_datasize(LazyArrays._datasizes(arr)) == (0, 0)
620+
resizedata!(arr[1], 3, 4)
621+
@test LazyArrays.max_datasize(LazyArrays._datasizes(arr)) == (3, 4)
622+
resizedata!(arr[2], 2, 5);
623+
@test LazyArrays.max_datasize(LazyArrays._datasizes(arr)) == (3, 5)
624+
625+
arr = (1, [1, 2])
626+
@test LazyArrays._datasize(1) == (1, )
627+
@test LazyArrays.max_datasize(LazyArrays._datasizes(arr)) == (2,)
628+
end
629+
630+
@testset "conforming_resize!" begin
631+
args = (cache(1:10), cache(1:10));
632+
LazyArrays.conforming_resize!(args);
633+
@test LazyArrays._datasizes(args) == ((0,), (0,));
634+
LazyArrays.resizedata!(args[2], 4);
635+
LazyArrays.conforming_resize!(args);
636+
@test LazyArrays._datasizes(args) == ((4,), (4,));
637+
638+
args = (1:10, cache(1:10));
639+
LazyArrays.conforming_resize!(args);
640+
@test LazyArrays._datasizes(args) == ((10,), (10,));
641+
642+
args = (cache(1:10)', cache(1:10)');
643+
LazyArrays.conforming_resize!(args);
644+
@test LazyArrays._datasizes(args) == ((0, 0), (0, 0));
645+
LazyArrays.resizedata!(args[1], 1, 3);
646+
LazyArrays.conforming_resize!(args);
647+
@test LazyArrays._datasizes(args) == ((1, 3), (1, 3));
648+
649+
args = (cache(1:10)', LazyArrays.CachedArray(rand(1, 10)));
650+
@test LazyArrays._datasizes(args) == ((0, 0), (0, 0));
651+
LazyArrays.conforming_resize!(args);
652+
@test LazyArrays._datasizes(args) == ((0, 0), (0, 0));
653+
LazyArrays.resizedata!(args[1], 1, 4);
654+
LazyArrays.conforming_resize!(args);
655+
@test LazyArrays._datasizes(args) == ((1, 4), (1, 4));
656+
LazyArrays.resizedata!(args[2], 1, 6)
657+
LazyArrays.conforming_resize!(args);
658+
@test LazyArrays._datasizes(args) == ((1, 6), (1, 6));
659+
660+
args = (cache(1:10), LazyArrays.CachedArray(rand(1, 10))');
661+
LazyArrays.resizedata!(args[1], 4);
662+
LazyArrays.conforming_resize!(args);
663+
@test LazyArrays._datasizes(args) == ((4,), (4, 1));
664+
665+
args = (cache(1:10), LazyArrays.CachedArray(rand(1, 10))', 1:10);
666+
LazyArrays.conforming_resize!(args);
667+
@test LazyArrays._datasizes(args) == ((10,), (10, 1), (10,));
668+
669+
args = (1, cache(1:2));
670+
LazyArrays.conforming_resize!(args);
671+
@test LazyArrays._datasizes(args) == ((1,), (1,)) # because scalars are treated as size (1,)
672+
673+
args = (rand(10, 10), LazyArrays.CachedArray(rand(10, 10)));
674+
LazyArrays.conforming_resize!(args);
675+
@test LazyArrays._datasizes(args) == ((10, 10), (10, 10));
676+
677+
args = (LazyArrays.CachedArray(rand(10, 10)), LazyArrays.CachedArray(rand(10, 10)));
678+
LazyArrays.conforming_resize!(args);
679+
@test LazyArrays._datasizes(args) == ((0, 0), (0, 0));
680+
LazyArrays.resizedata!(args[1], 3, 4);
681+
LazyArrays.conforming_resize!(args);
682+
@test LazyArrays._datasizes(args) == ((3, 4), (3, 4));
683+
LazyArrays.resizedata!(args[2], 5, 6);
684+
LazyArrays.conforming_resize!(args);
685+
@test LazyArrays._datasizes(args) == ((5, 6), (5, 6));
686+
687+
@testset "conforming_resize! dimension mismatch" begin
688+
args = (cache(1:10), cache(1:10))
689+
@test LazyArrays.conforming_resize!(args) === args
690+
691+
args = (cache(1:10), LazyArrays.CachedArray(rand(10, 5)))
692+
@test_throws ArgumentError LazyArrays.conforming_resize!(args)
693+
694+
args = (LazyArrays.CachedArray(rand(3, 4)), LazyArrays.CachedArray(rand(5, 2)))
695+
@test LazyArrays.conforming_resize!(args) === args
696+
697+
args = (1, cache(1:10))
698+
@test LazyArrays.conforming_resize!(args) === args
699+
700+
args = (cache(1:10), LazyArrays.CachedArray(rand(2, 3)), reshape(cache(1:8), 2, 2, 2))
701+
@test_throws ArgumentError LazyArrays.conforming_resize!(args)
702+
703+
args = (reshape(cache(1:8), 2, 2, 2), reshape(cache(1:27), 3, 3, 3))
704+
@test LazyArrays.conforming_resize!(args) === args
705+
706+
args = ()
707+
@test LazyArrays.conforming_resize!(args) === args
708+
709+
args = (cache(1:10),)
710+
@test LazyArrays.conforming_resize!(args) === args
711+
end
712+
end
713+
end
714+
715+
@testset "cacheddata for ApplyArray and BroadcastArray" begin
716+
x = ApplyArray(+, 1:10, cache(11:20));
717+
@test cacheddata(x) == ApplyArray(+, 1:10, 11:20)
718+
@test Base.Broadcast.BroadcastStyle(typeof(cacheddata(x))) == LazyArrays.LazyArrayStyle{1}()
719+
720+
x = BroadcastVector(*, 1:10, cache(1:10));
721+
@test cacheddata(x) == BroadcastVector(*, 1:10, 1:10)
722+
@test Base.Broadcast.BroadcastStyle(typeof(cacheddata(x))) == LazyArrays.LazyArrayStyle{1}()
723+
724+
x = ApplyArray(+, cache(1:10), cache(11:20));
725+
@test cacheddata(x) == ApplyArray(+, 1:0, 1:0)
726+
@test Base.Broadcast.BroadcastStyle(typeof(cacheddata(x))) == LazyArrays.LazyArrayStyle{1}()
727+
LazyArrays.resizedata!(x.args[1], 3)
728+
@test cacheddata(x) == ApplyArray(+, 1:3, 11:13)
729+
@test Base.Broadcast.BroadcastStyle(typeof(cacheddata(x))) == LazyArrays.LazyArrayStyle{1}()
730+
731+
x = BroadcastVector(*, cache(1:10), cache(11:20));
732+
@test cacheddata(x) == BroadcastVector(*, 1:0, 1:0)
733+
@test Base.Broadcast.BroadcastStyle(typeof(cacheddata(x))) == LazyArrays.LazyArrayStyle{1}()
734+
LazyArrays.resizedata!(x.args[1], 4)
735+
@test cacheddata(x) == BroadcastVector(*, 1:4, 11:14)
736+
@test Base.Broadcast.BroadcastStyle(typeof(cacheddata(x))) == LazyArrays.LazyArrayStyle{1}()
737+
end
738+
end
591739

592740
end # module

test/concattests.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,11 +224,14 @@ import Base.Broadcast: BroadcastStyle
224224
end
225225

226226
@testset "cacheddata" begin
227-
v = Vcat(1, cache(1:2))
228-
@test @inferred(cacheddata(v)) == [1]
229-
resizedata!(v, 2)
227+
v = Vcat(1, cache(1:2));
230228
@test @inferred(cacheddata(v)) == [1, 1]
229+
resizedata!(v, 3)
230+
@test @inferred(cacheddata(v)) == [1, 1, 2]
231231
@test cacheddata(v) isa Vcat
232+
233+
v = Vcat((1:10)', cache(11:20)');
234+
@test @inferred(cacheddata(v)) == Vcat((1:10)', (11:20)')
232235
end
233236

234237
p = Vcat([1,2], Zeros(4));

0 commit comments

Comments
 (0)