diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 2bad7d92..63149408 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -31,6 +31,7 @@ import ConstructionBase using ConstructionBase: constructorof using IntervalSets +import StaticArrays using StaticArrays: StaticArray, StaticVector, StaticMatrix, SArray, SVector, SMatrix, SOneTo diff --git a/src/combinators/implicitlymapped.jl b/src/combinators/implicitlymapped.jl index 964ea466..3966b10a 100644 --- a/src/combinators/implicitlymapped.jl +++ b/src/combinators/implicitlymapped.jl @@ -179,13 +179,13 @@ struct TakeAny{T<:IntegerLike} n::T end -_takeany_range(f::TakeAny, idxs) = first(idxs):first(idxs)+dynamic(f.n)-1 +_takeany_range(f::TakeAny, idxs) = first(idxs):(first(idxs)+dynamic(f.n)-1) @inline _takeany_range(f::TakeAny, ::OneTo) = OneTo(dynamic(f.n)) @inline _takeany_range(::TakeAny{<:Static.StaticInteger{N}}, ::OneTo) where {N} = SOneTo(N) @inline _takeany_range(::TakeAny{<:Static.StaticInteger{N}}, ::SOneTo) where {N} = SOneTo(N) -@inline (f::TakeAny)(xs::Tuple) = xs[begin:begin+f.n-1] +@inline (f::TakeAny)(xs::Tuple) = xs[begin:(begin+f.n-1)] @inline (f::TakeAny)(xs::AbstractVector) = xs[_takeany_range(f, eachindex(xs))] function (f::TakeAny)(xs) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index e6397c3f..62b4336f 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -11,14 +11,37 @@ the product determines the dimensionality of the resulting support. Note that power measures are only well-defined for integer powers. The nth power of a measure μ can be written μ^n. + +See also [`pwr_base`](@ref), [`pwr_axes`](@ref) and [`pwr_size`](@ref). """ struct PowerMeasure{M,A} <: AbstractProductMeasure parent::M axes::A end -maybestatic_length(μ::PowerMeasure) = prod(maybestatic_size(μ)) -maybestatic_size(μ::PowerMeasure) = map(maybestatic_length, μ.axes) +maybestatic_length(μ::PowerMeasure) = size2length(maybestatic_size(μ)) +maybestatic_size(μ::PowerMeasure) = axes2size(μ.axes) + +""" + MeasureBase.pwr_base(μ::PowerMeasure) + +Returns `ν` for `μ = ν^axs` +""" +@inline pwr_base(μ::PowerMeasure) = μ.parent + +""" + MeasureBase.pwr_axes(μ::PowerMeasure) + +Returns `axs` for `μ = ν^axs`, `axs` being a tuple of integer ranges. +""" +@inline pwr_axes(μ::PowerMeasure) = μ.axes + +""" + MeasureBase.pwr_size(μ::PowerMeasure) + +Returns `sz` for `μ = ν^sz`, `sz` being a tuple of integers. +""" +@inline pwr_size(μ::PowerMeasure) = axes2size(μ.axes) function Pretty.tile(μ::PowerMeasure) sz = length.(μ.axes) @@ -30,7 +53,7 @@ end # ToDo: Make rand return static arrays for statically-sized power measures. function _cartidxs(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} - CartesianIndices(map(_dynamic, axs)) + CartesianIndices(map(asnonstatic, axs)) end function Base.rand( @@ -38,22 +61,21 @@ function Base.rand( ::Type{T}, d::PowerMeasure{M}, ) where {T,M<:AbstractMeasure} - map(_cartidxs(d.axes)) do _ - rand(rng, T, d.parent) + axs, base_d = pwr_axes(d), pwr_base(d) + map(_cartidxs(axs)) do _ + rand(rng, T, base_d) end end function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T} - map(_cartidxs(d.axes)) do _ - rand(rng, d.parent) + axs, base_d = pwr_axes(d), pwr_base(d) + map(_cartidxs(axs)) do _ + rand(rng, base_d) end end -@inline _pm_axes(sz::Tuple{Vararg{IntegerLike,N}}) where {N} = map(one_to, sz) -@inline _pm_axes(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} = axs - @inline function powermeasure(x::T, sz::Tuple{Vararg{Any,N}}) where {T,N} - PowerMeasure(x, _pm_axes(sz)) + PowerMeasure(x, asaxes(sz)) end marginals(d::PowerMeasure) = fill_with(d.parent, d.axes) @@ -80,13 +102,22 @@ end for func in [:logdensityof, :logdensity_def] @eval @inline function $func(d::PowerMeasure{M}, x) where {M} - parent = d.parent - sum(x) do xj - $func(parent, xj) + parent_m = d.parent + sz_parent = axes2size(d.axes) + sz_x = maybestatic_size(x) + if sz_parent != sz_x + throw(ArgumentError("Size of variate doesn't match size of power measure")) + end + R = infer_logdensity_type($func, parent_m, eltype(x)) + if isempty(x) + return zero(R)::R + else + # Need to convert since sum can turn static into dynamic values: + return convert(R, sum(Base.Fix1($func, parent_m), x))::R end end - @eval @inline function $func(d::PowerMeasure{M,Tuple{Static.SOneTo{N}}}, x) where {M,N} + @eval @inline function $func(d::PowerMeasure{<:Any,Tuple{<:StaticOneToLike}}, x) parent = d.parent sum(1:N) do j @inbounds $func(parent, x[j]) @@ -94,9 +125,9 @@ for func in [:logdensityof, :logdensity_def] end @eval @inline function $func( - d::PowerMeasure{M,NTuple{N,Static.SOneTo{0}}}, + ::PowerMeasure{<:Any,<:Tuple{Vararg{StaticOneToLike{0}}}}, x, - ) where {M,N} + ) static(0.0) end end @@ -117,15 +148,11 @@ end end end -@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes)) - -@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N} - static(0) -end +@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * size2length(axes2size(μ.axes)) @propagate_inbounds function checked_arg(μ::PowerMeasure, x::AbstractArray{<:Any}) @boundscheck begin - sz_μ = map(length, μ.axes) + sz_μ = pwr_size(μ) sz_x = size(x) if sz_μ != sz_x throw(ArgumentError("Size of variate doesn't match size of power measure")) @@ -144,7 +171,7 @@ logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0) # To avoid ambiguities function logdensity_def( - ::PowerMeasure{P,Tuple{Vararg{Static.SOneTo{0},N}}}, + ::PowerMeasure{P,<:Tuple{Vararg{StaticOneToLike{0},N}}}, x, ) where {P<:PrimitiveMeasure,N} static(0.0) diff --git a/src/density-core.jl b/src/density-core.jl index 6ac3d01e..f3b2db2b 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -149,13 +149,13 @@ end ℓ = logdensity_def(μs[$M], νs[$N], x) end - for i in 1:M-1 + for i in 1:(M-1) push!(q.args, :(Δℓ = logdensity_def(μs[$i], x))) # push!(q.args, :(println("Adding", Δℓ))) push!(q.args, :(ℓ += Δℓ)) end - for j in 1:N-1 + for j in 1:(N-1) push!(q.args, :(Δℓ = logdensity_def(νs[$j], x))) # push!(q.args, :(println("Subtracting", Δℓ))) push!(q.args, :(ℓ -= Δℓ)) diff --git a/src/interface.jl b/src/interface.jl index 18080ac7..4890ddd6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -110,7 +110,7 @@ function test_smf(μ, n = 100) @testset "smf($μ)" begin # Get `n` sorted uniforms in O(n) time p = rand(n) - p .+= 0:n-1 + p .+= 0:(n-1) p .*= inv(n) F(x) = smf(μ, x) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 833f280e..4b957651 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -57,7 +57,7 @@ end # Helpers for product transforms and similar: struct _TransportToStd{NU<:StdMeasure} <: Function end -_TransportToStd{NU}(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x) +(::_TransportToStd{NU})(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x) struct _TransportFromStd{MU<:StdMeasure} <: Function end _TransportFromStd{MU}(ν, x) where {MU} = transport_to(ν, MU()^getdof(ν))(x) @@ -67,7 +67,7 @@ function _tuple_transport_def( μs::Tuple, xs::Tuple, ) where {NU<:StdMeasure} - reshape(vcat(map(_TransportToStd{NU}, μs, xs)...), ν.axes) + reshape(vcat(map(_TransportToStd{NU}(), μs, xs)...), ν.axes) end function transport_def( @@ -93,7 +93,7 @@ end function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike) N = map(getdof, μs) offs = _offset_cumsum(startidx, N...) - map((o, n) -> o:o+n-1, offs, N) + map((o, n) -> o:(o+n-1), offs, N) end function _tuple_transport_def( diff --git a/src/static.jl b/src/static.jl index da471b62..12db9585 100644 --- a/src/static.jl +++ b/src/static.jl @@ -1,3 +1,27 @@ +# A lots of this is about bridging Static and StaticArrays, both have their +# own SUnitRange and SOneTo. Also provides tools to control static vs dynamic +# array, size and axes handling. + +""" + MeasureBase.StaticUnitRange + +The MeasureBase default type for static unit ranges. +""" +const StaticUnitRange = @static if isdefined(StaticArrays, :SUnitRange) + # Unclear if StaticArrays.SUnitRange is part of StaticArrays stable API. + # Some packages use it, but let's be careful in case it disappears. + StaticArrays.SUnitRange +else + Static.SUnitRange +end + +""" + MeasureBase.StaticOneTo + +The MeasureBase default type for static one-based unit ranges. +""" +const StaticOneTo{T} = StaticArrays.SOneTo{T} + """ MeasureBase.IntegerLike @@ -5,6 +29,67 @@ Equivalent to `Union{Integer,Static.StaticInteger}`. """ const IntegerLike = Union{Integer,Static.StaticInteger} +""" + MeasureBase.SizeLike + +Something that can represent the size of a collection. +""" +const SizeLike = Union{Tuple{},Tuple{Vararg{IntegerLike}},StaticArrays.Size} + +""" + MeasureBase.StaticSizeLike + +Something that can represent the size of a statically sized collection. +""" +const StaticSizeLike = Union{Tuple{Vararg{StaticInteger}},StaticArrays.Size} + +""" + MeasureBase.AxesLike + +Something that can represent axes of a collection. +""" +const AxesLike = Union{Tuple{},Tuple{Vararg{AbstractVector{<:IntegerLike}}}} + +""" + MeasureBase.StaticAxesLike + +Something that can represent axes of a statically sized collection. +""" +@static if isdefined(StaticArrays, :SUnitRange) + const StaticAxesLike = Union{ + Tuple{Vararg{Union{StaticArrays.SOneTo,StaticArrays.SUnitRange,Static.SUnitRange}}}, + } +else + const StaticAxesLike = + Union{Tuple{Vararg{Union{StaticArrays.SOneTo,Static.SUnitRange}}}} +end + +""" + const OneToLike + +Alias for unit ranges that start at one. +""" +const OneToLike = Union{Base.OneTo,StaticArrays.SOneTo,Static.SOneTo} + +""" + const StaticOneToLike{N} + +A static unit range from one to N. +""" +const StaticOneToLike{N} = Union{StaticArrays.SOneTo{N},Static.SOneTo{N}} + +""" + const StaticUnitRangeLike + +A static unit range. +""" +@static if isdefined(StaticArrays, :SUnitRange) + const StaticUnitRangeLike = + Union{StaticArrays.SOneTo,StaticArrays.SUnitRange,Static.SUnitRange} +else + const StaticUnitRangeLike = Union{StaticArrays.SOneTo,Static.SUnitRange} +end + """ MeasureBase.one_to(n::IntegerLike) @@ -16,48 +101,265 @@ on the type of `n`. @inline one_to(n::Integer) = Base.OneTo(n) @inline one_to(::Static.StaticInteger{N}) where {N} = Static.SOneTo{N}() -_dynamic(x::Number) = dynamic(x) -_dynamic(::Static.SOneTo{N}) where {N} = Base.OneTo(N) -_dynamic(r::AbstractUnitRange) = minimum(r):maximum(r) +""" + MeasureBase.asnonstatic(x) + +Return a non-static equivalent of `x`. + +Defaults to `Static.dynamic(x)`. +""" +@inline asnonstatic(x::Number) = dynamic(x) +@inline asnonstatic(::Tuple{}) = () +@static if isdefined(StaticArrays, :SUnitRange) + @inline asnonstatic(r::StaticArrays.SUnitRange) = r[begin]:r[end] +end +@inline asnonstatic(r::AbstractUnitRange) = asnonstatic(r[begin]):asnonstatic(r[end]) +@inline asnonstatic(r::Base.OneTo) = Base.OneTo(asnonstatic(r.stop)) +@inline asnonstatic(::StaticOneToLike{N}) where {N} = Base.OneTo(N) +@inline asnonstatic(x::SizeLike) = map(asnonstatic, x) +@inline asnonstatic(::StaticArrays.Size{TPL}) where {TPL} = TPL +@inline asnonstatic(x::AxesLike) = map(asnonstatic, x) """ MeasureBase.fill_with(x, sz::NTuple{N,<:IntegerLike}) where N Creates an array of size `sz` filled with `x`. -Returns an instance of `FillArrays.Fill`. +The result will typically be either a `FillArrays.Fill` or a static array, """ function fill_with end -@inline function fill_with(x::T, sz::Tuple{Vararg{IntegerLike,N}}) where {T,N} - fill_with(x, map(one_to, sz)) +@inline fill_with(x::T, n::IntegerLike) where {T} = fill_with(x, (n,)) + +@inline fill_with(x::T, ::Tuple{}) where {T} = FillArrays.Fill(x) + +@inline fill_with(x, sz::SizeLike) = fill_with(x, size2axes(sz)) + +@inline function fill_with(x::T, sz::StaticSizeLike) where {T} + fill(x, staticarray_type(T, canonical_size(sz))) end -@inline function fill_with(x::T, axs::Tuple{Vararg{AbstractUnitRange,N}}) where {T,N} - # While `FillArrays.Fill` (mostly?) works with axes that are static unit - # ranges, some operations that automatic differentiation requires do fail - # on such instances of `Fill` (e.g. `reshape` from dynamic to static size). - # So need to use standard ranges for the axes for now: - dyn_axs = map(_dynamic, axs) +@inline function fill_with(x, axs::AxesLike) + dyn_axs = map(asnonstatic, axs) FillArrays.Fill(x, dyn_axs) end +# While `FillArrays.Fill` (mostly?) works with axes that are static unit +# ranges, some operations that automatic differentiation requires do fail +# on such instances of `Fill` (e.g. `reshape` from dynamic to static size). +# So need to build a filled static array: +@inline function fill_with(x::T, axs::Tuple{Vararg{StaticOneToLike}}) where {T} + sz = axes2size(axs) + fill(x, staticarray_type(T, sz)) +end + +""" + MeasureBase.staticarray_type(T, sz::StaticArrays.Size) + +Returns the type of a static array with element type `T` and size `sz`. +""" +function staticarray_type end + +@inline @generated function staticarray_type( + ::Type{T}, + ::StaticArrays.Size{sz}, +) where {T,sz} + N = length(sz) + len = prod(sz) + :(SArray{Tuple{$sz...},T,$N,$len}) +end + +""" + MeasureBase.maybestatic_reshape(A, sz) + +Reshapes array `A` to sizes `sz`. + +If `A` is a static array and `sz` is static, the result is a static array. +""" +function maybestatic_reshape end + +maybestatic_reshape(A, sz) = reshape(A, canonical_size(sz)) +function maybestatic_reshape(A, sz::StaticSizeLike) + StaticArrays.SArray(reshape(A, canonical_size(sz))) +end +function maybestatic_reshape(A::StaticArray, sz::Tuple{Vararg{StaticInteger}}) + staticarray_type(eltype(A), canonical_size(sz))(Tuple(A)) +end + """ - MeasureBase.maybestatic_length(x)::IntegerLike + MeasureBase.maybestatic_length(x) Returns the length of `x` as a dynamic or static integer. """ -maybestatic_length(x) = length(x) -maybestatic_length(x::AbstractUnitRange) = length(x) -function maybestatic_length( - ::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}, -) where {A,B} - StaticInt{B - A + 1}() +@inline maybestatic_length(::Number) = static(1) +@inline maybestatic_length(::Tuple{}) = static(0) +@inline maybestatic_length(::Tuple{Vararg{Any,N}}) where {N} = static(N) +@inline maybestatic_length(nt::NamedTuple) = maybestatic_length(values(nt)) +@inline maybestatic_length(A::AbstractArray) = size2length(maybestatic_size(A)) +@static if isdefined(StaticArrays, :SUnitRange) + @inline maybestatic_length(r::StaticArrays.SUnitRange) = + maybestatic_last(r) - maybestatic_first(r) + static(1) end +@inline maybestatic_length(r::AbstractUnitRange) = + maybestatic_last(r) - maybestatic_first(r) + static(1) +@inline maybestatic_length(r::Base.OneTo) = length(r) +@inline maybestatic_length(::StaticArrays.SOneTo{N}) where {N} = static(N) +@inline maybestatic_length(::Static.SOneTo{N}) where {N} = static(N) """ - MeasureBase.maybestatic_size(x)::Tuple{Vararg{IntegerLike}} + + MeasureBase.maybestatic_size(x) Returns the size of `x` as a tuple of dynamic or static integers. """ -maybestatic_size(x) = size(x) +@inline maybestatic_size(::Number) = () +@inline maybestatic_size(::Tuple{}) = + throw(ArgumentError("Cannot determine (maybe-static) size of empty tuple")) +@inline maybestatic_size(::Tuple{Vararg{Any,N}}) where {N} = StaticArrays.Size{(N,)}() +@inline maybestatic_size(nt::NamedTuple) = maybestatic_size(values(nt)) +@inline maybestatic_size(A::AbstractArray) = axes2size(maybestatic_axes(A)) +@inline maybestatic_size(A::StaticArray) = StaticArrays.Size(A) + +""" + MeasureBase.maybestatic_axes(x)::Tuple{Vararg{IntegerLike}} + +Returns the size of `x` as a tuple of dynamic or static integers. +""" +@inline maybestatic_axes(::Number) = () + +@inline maybestatic_axes(::Tuple{}) = (StaticOneTo(0),) +@inline maybestatic_axes(::Tuple{Vararg{Any,N}}) where {N} = (StaticOneTo(N),) +@inline maybestatic_axes(nt::NamedTuple) = maybestatic_axes(values(nt)) +@inline maybestatic_axes(::StaticOneToLike{N}) where {N} = (StaticOneTo(N),) +@static if isdefined(StaticArrays, :SUnitRange) + @inline maybestatic_axes(r::StaticArrays.SUnitRange) = axes(r) +end +@inline maybestatic_axes(r::Static.OptionallyStaticUnitRange) = canonical_axes(axes(r)) +@inline maybestatic_axes(r::AbstractUnitRange) = axes(r) +@inline maybestatic_axes(A::AbstractArray) = axes(A) +@inline maybestatic_axes(A::StaticArray) = axes(A) + +""" + MeasureBase.axes2size(x::Tuple) + MeasureBase.axes2size(x::StaticArrays.Size) + +Get a length from a size (tuple). +""" +@inline axes2size(::Tuple{}) = () +@inline axes2size(axs::Tuple) = canonical_size(map(maybestatic_length, axs)) + +"""map(maybestatic_length, axs) + MeasureBase.size2axes(sz::Tuple) + MeasureBase.size2axes(sz::StaticArrays.Size) + +Get one-based indexing axes from a size. +""" +@inline size2axes(::Tuple{}) = () +@inline size2axes(sz::Tuple) = canonical_axes(map(one_to, sz)) +@inline size2axes(::StaticArrays.Size{TPL}) where {TPL} = map(StaticOneTo, TPL) + +""" + MeasureBase.size2length(sz::Tuple) + MeasureBase.size2length(sz::StaticArrays.Size) + +Get a length from a size (tuple). +""" +@inline size2length(::Tuple{}) = static(1) +@inline size2length(sz::Tuple) = prod(sz) +@inline size2length(::StaticArrays.Size{TPL}) where {TPL} = static(prod(TPL)) + +""" + MeasureBase.asaxes(axs::AxesLike) + MeasureBase.asaxes(sz::SizeLike) + MeasureBase.asaxes(len::IntegerLike) + +Converts axes or a size or a length of a collection to axes. + +One-based indexing will be used if the indexing offset can't be inferred from +the given dimensions. +""" +@inline asaxes(::Tuple{}) = () +@inline asaxes(axs::AxesLike) = axs +@inline asaxes(sz::SizeLike) = size2axes(sz) +@inline asaxes(len::IntegerLike) = size2axes((len,)) + +""" + MeasureBase.maybestatic_eachindex(x) + +Returns the the index range of `x` as a dynamic or static integer range +""" +maybestatic_eachindex(::Tuple{}) = StaticOneTo(0) +maybestatic_eachindex(::Tuple{Vararg{Any,N}}) where {N} = StaticOneTo(N) +maybestatic_eachindex(nt::NamedTuple) = maybestatic_eachindex(values(nt)) +maybestatic_eachindex(x::AbstractArray) = canonical_indices(eachindex(x)) + +""" + MeasureBase.maybestatic_first(A) + +Returns the first element of `A` as a dynamic or static value. +""" +maybestatic_first(tpl::Tuple) = tpl[begin] +maybestatic_first(nt::NamedTuple) = nt[begin] +maybestatic_first(A::AbstractArray) = A[begin] +maybestatic_first(::StaticArrays.Size{tpl}) where {tpl} = static(tpl[begin]) +maybestatic_first(::StaticArrays.SOneTo{N}) where {N} = static(1) +@static if isdefined(StaticArrays, :SUnitRange) + maybestatic_first(::StaticArrays.SUnitRange{B,L}) where {B,L} = static(B) +end +function maybestatic_first( + ::Static.OptionallyStaticUnitRange{<:Static.StaticInteger{from},<:Static.StaticInteger}, +) where {from} + static(from) +end + +""" + MeasureBase.maybestatic_last(A) + +Returns the last element of `A` as a dynamic or static value. +""" +maybestatic_last(tpl::Tuple) = tpl[end] +maybestatic_last(nt::NamedTuple) = nt[end] +maybestatic_last(A::AbstractArray) = A[end] +maybestatic_last(::StaticArrays.Size{tpl}) where {tpl} = static(tpl[end]) +maybestatic_last(::StaticArrays.SOneTo{N}) where {N} = static(N) +@static if isdefined(StaticArrays, :SUnitRange) + maybestatic_last(::StaticArrays.SUnitRange{B,L}) where {B,L} = static(B + L - 1) +end +function maybestatic_last( + ::Static.OptionallyStaticUnitRange{<:Any,<:Static.StaticInteger{until}}, +) where {until} + static(until) +end + +""" + MeasureBase.canonical_indices(idxs::AbstractVector{<:IntegerLike}) + +Return the canonical representation of a collection axis indices. +""" +@inline canonical_indices(idxs::AbstractVector{<:IntegerLike}) = idxs +@inline canonical_indices(idxs::AbstractArray{<:CartesianIndex}) = idxs +@inline canonical_indices( + ::Static.OptionallyStaticUnitRange{<:StaticInteger{1},<:StaticInteger{N}}, +) where {N} = StaticArrays.SOneTo{N}() +@inline canonical_indices( + ::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}, +) where {A,B} = StaticUnitRange(A, B) +@inline canonical_indices( + r::Static.OptionallyStaticUnitRange{<:StaticInteger{1},<:Integer}, +) = Base.OneTo(last(r)) + +""" + MeasureBase.canonical_size(sz::SizeLike) + +Return the canonical representation of a collection size. +""" +@inline canonical_size(sz::SizeLike) = sz +@inline canonical_size(sz::Tuple{Vararg{Static.StaticInteger}}) = + StaticArrays.Size{map(dynamic, sz)}() + +""" + MeasureBase.canonical_axes(sz::SizeLike) + +Return the canonical representation collection axes. +""" +@inline canonical_axes(axs::AxesLike) = map(canonical_indices, axs) diff --git a/src/utils.jl b/src/utils.jl index 0ec81a50..5d05d8b1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -133,6 +133,11 @@ function infer_zero(f, args...) zero(typeintersect(AbstractFloat, inferred_type)) end +function infer_logdensity_type(f::F, ::M, ::Type{T}) where {F,M,T} + inferred_type = Core.Compiler.return_type(f, Tuple{M,T}) + return inferred_type +end + @inline function allequal(f, x::AbstractArray) val = f(first(x)) @simd for xj in x diff --git a/test/static.jl b/test/static.jl index f618124b..83092ec2 100644 --- a/test/static.jl +++ b/test/static.jl @@ -1,34 +1,325 @@ using Test import MeasureBase +using MeasureBase: + StaticUnitRange, + StaticOneTo, + IntegerLike, + SizeLike, + StaticSizeLike, + AxesLike, + StaticAxesLike, + OneToLike, + StaticOneToLike, + StaticUnitRangeLike, + one_to, + asnonstatic, + fill_with, + staticarray_type, + maybestatic_reshape, + maybestatic_length, + maybestatic_size, + maybestatic_axes, + axes2size, + size2axes, + size2length, + asaxes, + maybestatic_eachindex, + maybestatic_first, + maybestatic_last, + canonical_indices, + canonical_size, + canonical_axes import Static using Static: static +import StaticArrays import FillArrays @testset "static" begin - @test 2 isa MeasureBase.IntegerLike - @test static(2) isa MeasureBase.IntegerLike - @test true isa MeasureBase.IntegerLike - @test static(true) isa MeasureBase.IntegerLike - - @test @inferred(MeasureBase.one_to(7)) isa Base.OneTo - @test @inferred(MeasureBase.one_to(7)) == 1:7 - @test @inferred(MeasureBase.one_to(static(7))) isa Static.SOneTo - @test @inferred(MeasureBase.one_to(static(7))) == static(1):static(7) - - @test @inferred(MeasureBase.fill_with(4.2, (7,))) == FillArrays.Fill(4.2, 7) - @test @inferred(MeasureBase.fill_with(4.2, (static(7),))) == FillArrays.Fill(4.2, 7) - @test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == - FillArrays.Fill(4.2, 3, 7) - @test @inferred(MeasureBase.fill_with(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == - FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == - FillArrays.Fill(4.2, (3:7, 2:5)) - - @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) isa Int - @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) == 7 - @test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) isa Static.StaticInt - @test MeasureBase.maybestatic_length(MeasureBase.one_to(static(7))) == static(7) + v = 4.2 + T = typeof(v) + + tpl = (7, 42, 5) + nt = (a = 7, b = 42, c = 5) + + i = 7 + si = static(7) + + @test i isa IntegerLike + @test si isa IntegerLike + + sz = (2, 4, 3) + sasz = StaticArrays.Size(2, 4, 3) + sisz = (static(2), static(4), static(3)) + + len = prod(sz) + slen = static(len) + + @test sz isa SizeLike + @test sasz isa SizeLike + @test sisz isa SizeLike + + @test !(sz isa StaticSizeLike) + @test sasz isa StaticSizeLike + @test sisz isa StaticSizeLike + + axs = (Base.OneTo(2), 2:5, Base.OneTo(3)) + axs1 = (Base.OneTo(2), Base.OneTo(4), Base.OneTo(3)) + saaxs = (StaticOneTo(2), StaticUnitRange(2, 5), StaticOneTo(3)) + saaxs1 = (StaticOneTo(2), StaticOneTo(4), StaticOneTo(3)) + siaxs = (Static.SOneTo(2), static(2):static(5), static(1):static(3)) + siaxs1 = (Static.SOneTo(2), static(1):static(4), static(1):static(3)) + + @test axs isa AxesLike + @test axs1 isa AxesLike + @test saaxs isa AxesLike + @test saaxs1 isa AxesLike + @test siaxs isa AxesLike + @test siaxs1 isa AxesLike + + @test !(axs isa StaticAxesLike) + @test saaxs isa StaticAxesLike + @test saaxs1 isa StaticAxesLike + @test siaxs isa StaticAxesLike + @test siaxs1 isa StaticAxesLike + + @test axs[1] isa OneToLike + @test !(axs[2] isa OneToLike) + @test axs[3] isa OneToLike + + @test saaxs[1] isa OneToLike + @test !(saaxs[2] isa OneToLike) + @test saaxs1[2] isa OneToLike + @test saaxs[3] isa OneToLike + + @test siaxs[1] isa OneToLike + @test !(siaxs[2] isa OneToLike) + @test siaxs1[2] isa OneToLike + @test siaxs[3] isa OneToLike + + @test !(axs[1] isa StaticOneToLike) + @test !(axs[2] isa StaticOneToLike) + @test !(axs[3] isa StaticOneToLike) + + @test saaxs[1] isa StaticOneToLike + @test !(saaxs[2] isa StaticOneToLike) + @test saaxs1[2] isa StaticOneToLike + @test saaxs[3] isa StaticOneToLike + + @test siaxs[1] isa StaticOneToLike + @test !(siaxs[2] isa StaticOneToLike) + @test siaxs1[2] isa StaticOneToLike + @test siaxs[3] isa StaticOneToLike + + @test !(axs[1] isa StaticUnitRangeLike) + @test !(axs[2] isa StaticUnitRangeLike) + @test !(axs[3] isa StaticUnitRangeLike) + + @test saaxs[1] isa StaticUnitRangeLike + @test saaxs[2] isa StaticUnitRangeLike + @test saaxs1[2] isa StaticUnitRangeLike + @test saaxs[3] isa StaticUnitRangeLike + + @test siaxs[1] isa StaticUnitRangeLike + @test siaxs[2] isa StaticUnitRangeLike + @test siaxs1[2] isa StaticUnitRangeLike + @test siaxs[3] isa StaticUnitRangeLike + + @test @inferred(one_to(i)) == Base.OneTo(i) + @test @inferred(one_to(si)) == StaticOneTo(i) + + @test @inferred(asnonstatic(i)) === i + @test @inferred(asnonstatic(si)) === i + @test @inferred(asnonstatic(sz)) === sz + @test @inferred(asnonstatic(sasz)) === sz + @test @inferred(asnonstatic(sisz)) === sz + @test @inferred(asnonstatic(axs)) === axs + @test @inferred(asnonstatic(saaxs)) === axs + @test @inferred(asnonstatic(saaxs1)) === (Base.OneTo(2), Base.OneTo(4), Base.OneTo(3)) + @test @inferred(asnonstatic(siaxs)) === axs + @test @inferred(asnonstatic(siaxs1)) === (Base.OneTo(2), Base.OneTo(4), Base.OneTo(3)) + + @test @inferred(fill_with(v, i)) === FillArrays.Fill(v, i) + @test @inferred(fill_with(v, si)) === StaticArrays.SVector(fill(v, i)...) + @test @inferred(fill_with(v, ())) === FillArrays.Fill(v) + + @test @inferred(fill_with(v, sz)) === FillArrays.Fill(v, sz) + @test @inferred(fill_with(v, sasz)) === StaticArrays.SArray{Tuple{sz...},T}(fill(v, sz)) + @test @inferred(fill_with(v, sisz)) === StaticArrays.SArray{Tuple{sz...},T}(fill(v, sz)) + + @test @inferred(fill_with(v, axs)) === FillArrays.Fill(v, axs) + @test @inferred(fill_with(v, saaxs)) === FillArrays.Fill(v, axs) + @test @inferred(fill_with(v, saaxs1)) === + StaticArrays.SArray{Tuple{sz...},T}(fill(v, sz)) + @test @inferred(fill_with(v, siaxs)) === FillArrays.Fill(v, axs) + @test @inferred(fill_with(v, siaxs1)) === + StaticArrays.SArray{Tuple{sz...},T}(fill(v, sz)) + + @test @inferred(staticarray_type(T, sasz)) <: StaticArrays.SArray{Tuple{2,4,3},T} + + A = rand(T, len) + FA = FillArrays.Fill(v, len) + SA = StaticArrays.SVector(A...) + + # Array with CartesianIndices + ciA = view(rand(5, 6, 6), 3:4, 2:5, 3:5) + ciidxs = eachindex(ciA) + + rshpA = reshape(A, sz) + rshpFA = FillArrays.Fill(v, sz) + rshpSA = StaticArrays.SArray{Tuple{sz...},T}(A) + + @test @inferred(maybestatic_reshape(A, sz)) == rshpA + @test typeof(maybestatic_reshape(A, sz)) == typeof(rshpA) + @test @inferred(maybestatic_reshape(A, sasz)) == rshpA + @test maybestatic_reshape(A, sasz) isa StaticArrays.SArray + @test @inferred(maybestatic_reshape(A, sisz)) == rshpA + @test maybestatic_reshape(A, sisz) isa StaticArrays.SArray + + @test @inferred(maybestatic_reshape(FA, sz)) == rshpFA + @test typeof(maybestatic_reshape(FA, sz)) == typeof(rshpFA) + @test @inferred(maybestatic_reshape(FA, sasz)) == rshpFA + @test maybestatic_reshape(FA, sasz) isa StaticArrays.SArray + @test @inferred(maybestatic_reshape(FA, sisz)) == rshpFA + @test maybestatic_reshape(FA, sisz) isa StaticArrays.SArray + + @test @inferred(maybestatic_reshape(SA, sz)) == rshpA + @test maybestatic_reshape(SA, sz) isa Base.ReshapedArray{T,3,<:StaticArrays.SVector} + @test @inferred(maybestatic_reshape(SA, sasz)) === rshpSA + @test @inferred(maybestatic_reshape(SA, sisz)) === rshpSA + + @test @inferred(maybestatic_length(5)) === static(1) + @test @inferred(maybestatic_length(())) === static(0) + @test @inferred(maybestatic_length((sz))) === static(3) + @test @inferred(maybestatic_length((a = 2, b = 4, c = 3))) === static(3) + @test @inferred(maybestatic_length(Base.OneTo(4))) === 4 + @test @inferred(maybestatic_length(StaticArrays.SOneTo(4))) === static(4) + @test @inferred(maybestatic_length(Static.SOneTo(4))) === static(4) + @test @inferred(maybestatic_length(static(2):static(5))) === static(4) + @test @inferred(maybestatic_length(rshpA)) === length(rshpA) + @test @inferred(maybestatic_length(rshpFA)) === length(rshpA) + @test @inferred(maybestatic_length(rshpSA)) === static(length(rshpA)) + + @test @inferred(maybestatic_size(5)) === () + @test_throws ArgumentError maybestatic_size(()) + @test @inferred(maybestatic_size((sz))) === StaticArrays.Size(3) + @test @inferred(maybestatic_size((a = 2, b = 4, c = 3))) === StaticArrays.Size(3) + @test @inferred(maybestatic_size(Base.OneTo(4))) === (4,) + @test @inferred(maybestatic_size(StaticArrays.SOneTo(4))) === StaticArrays.Size(4) + @test @inferred(maybestatic_size(StaticUnitRange(2, 5))) === StaticArrays.Size(4) + @test @inferred(maybestatic_size(Static.SOneTo(4))) === StaticArrays.Size(4) + @test @inferred(maybestatic_size(static(2):static(5))) === StaticArrays.Size(4) + @test @inferred(maybestatic_size(rshpA)) === size(rshpA) + @test @inferred(maybestatic_size(rshpFA)) === size(rshpA) + @test @inferred(maybestatic_size(rshpSA)) === StaticArrays.Size(size(rshpA)...) + + @test @inferred(maybestatic_axes(5)) === () + @test @inferred(maybestatic_axes(())) === (StaticOneTo(0),) + @test @inferred(maybestatic_axes((sz))) === (StaticOneTo(3),) + @test @inferred(maybestatic_axes((a = 2, b = 4, c = 3))) === (StaticOneTo(3),) + @test @inferred(maybestatic_axes(Base.OneTo(4))) === (Base.OneTo(4),) + @test @inferred(maybestatic_axes(StaticArrays.SOneTo(4))) === (StaticOneTo(4),) + @test @inferred(maybestatic_axes(Static.SOneTo(4))) === (StaticOneTo(4),) + @test @inferred(maybestatic_axes(static(2):static(5))) === (StaticOneTo(4),) + @test @inferred(maybestatic_axes(rshpA)) === axes(rshpA) + @test @inferred(maybestatic_axes(rshpFA)) === axes(rshpA) + @test @inferred(maybestatic_axes(rshpSA)) === saaxs1 + + @test @inferred(axes2size(())) === () + @test @inferred(axes2size(axs)) === sz + @test @inferred(axes2size(saaxs)) === sasz + @test @inferred(axes2size(saaxs1)) === sasz + @test @inferred(axes2size(siaxs)) === sasz + @test @inferred(axes2size(siaxs1)) === sasz + + @test @inferred(size2axes(())) === () + @test @inferred(size2axes(sz)) === axs1 + @test @inferred(size2axes(sasz)) === saaxs1 + @test @inferred(size2axes(sisz)) === saaxs1 + + @test @inferred(size2length(())) === static(1) + @test @inferred(size2length(sz)) === len + @test @inferred(size2length(sasz)) === slen + @test @inferred(size2length(sisz)) === slen + + @test @inferred(asaxes(())) === () + @test @inferred(asaxes(len)) === (Base.OneTo(len),) + @test @inferred(asaxes(slen)) === (StaticOneTo(len),) + @test @inferred(asaxes(sz)) === axs1 + @test @inferred(asaxes(sasz)) === saaxs1 + @test @inferred(asaxes(sisz)) === saaxs1 + @test @inferred(asaxes(axs)) === axs + @test @inferred(asaxes(axs1)) === axs1 + @test @inferred(asaxes(saaxs)) === saaxs + @test @inferred(asaxes(saaxs1)) === saaxs1 + @test @inferred(asaxes(siaxs)) === siaxs + @test @inferred(asaxes(siaxs1)) === siaxs1 + + @test @inferred(maybestatic_eachindex(())) === StaticOneTo(0) + @test @inferred(maybestatic_eachindex(tpl)) === StaticOneTo(3) + @test @inferred(maybestatic_eachindex(nt)) === StaticOneTo(3) + @test @inferred(maybestatic_eachindex(axs[1])) === Base.OneTo(length(axs[1])) + @test @inferred(maybestatic_eachindex(axs[2])) === Base.OneTo(length(axs[2])) + @test @inferred(maybestatic_eachindex(saaxs[1])) === StaticOneTo(length(axs[1])) + @test @inferred(maybestatic_eachindex(saaxs[2])) === StaticOneTo(length(axs[2])) + @test @inferred(maybestatic_eachindex(siaxs[1])) === StaticOneTo(length(axs[1])) + @test @inferred(maybestatic_eachindex(siaxs[2])) === StaticOneTo(length(axs[2])) + @test @inferred(maybestatic_eachindex(A)) === Base.OneTo(24) + @test @inferred(maybestatic_eachindex(ciA)) === eachindex(ciA) + @test @inferred(maybestatic_eachindex(FA)) === Base.OneTo(24) + @test @inferred(maybestatic_eachindex(SA)) === StaticOneTo(24) + + @test_throws BoundsError maybestatic_first(()) + @test @inferred(maybestatic_first(tpl)) === first(tpl) + @test @inferred(maybestatic_first(nt)) === first(nt) + @test @inferred(maybestatic_first(sz)) === first(sz) + @test @inferred(maybestatic_first(sasz)) === static(first(sz)) + @test @inferred(maybestatic_first(sisz)) === static(first(sz)) + @test @inferred(maybestatic_first(axs[1])) === first(axs[1]) + @test @inferred(maybestatic_first(axs[2])) === first(axs[2]) + @test @inferred(maybestatic_first(saaxs[1])) === static(first(axs[1])) + @test @inferred(maybestatic_first(saaxs[2])) === static(first(axs[2])) + @test @inferred(maybestatic_first(siaxs[1])) === static(first(axs[1])) + @test @inferred(maybestatic_first(siaxs[2])) === static(first(axs[2])) + @test @inferred(maybestatic_first(A)) === first(A) + @test @inferred(maybestatic_first(ciA)) === first(ciA) + @test @inferred(maybestatic_first(FA)) === first(FA) + @test @inferred(maybestatic_first(SA)) === first(SA) + + @test_throws BoundsError maybestatic_last(()) + @test @inferred(maybestatic_last(tpl)) === last(tpl) + @test @inferred(maybestatic_last(nt)) === last(nt) + @test @inferred(maybestatic_last(sz)) === last(sz) + @test @inferred(maybestatic_last(sasz)) === static(last(sz)) + @test @inferred(maybestatic_last(sisz)) === static(last(sz)) + @test @inferred(maybestatic_last(axs[1])) === last(axs[1]) + @test @inferred(maybestatic_last(axs[2])) === last(axs[2]) + @test @inferred(maybestatic_last(saaxs[1])) === static(last(axs[1])) + @test @inferred(maybestatic_last(saaxs[2])) === static(last(axs[2])) + @test @inferred(maybestatic_last(siaxs[1])) === static(last(axs[1])) + @test @inferred(maybestatic_last(siaxs[2])) === static(last(axs[2])) + @test @inferred(maybestatic_last(A)) === last(A) + @test @inferred(maybestatic_last(ciA)) === last(ciA) + @test @inferred(maybestatic_last(FA)) === last(FA) + @test @inferred(maybestatic_last(SA)) === last(SA) + + @test @inferred(canonical_indices(axs[1])) === axs[1] + @test @inferred(canonical_indices(axs[2])) === axs[2] + @test @inferred(canonical_indices(saaxs[1])) === saaxs[1] + @test @inferred(canonical_indices(saaxs[2])) === saaxs[2] + @test @inferred(canonical_indices(siaxs[1])) === saaxs[1] + @test @inferred(canonical_indices(siaxs[2])) === saaxs[2] + @test @inferred(canonical_indices(ciidxs)) === ciidxs + + @test @inferred(canonical_size(sz)) === sz + @test @inferred(canonical_size(sasz)) === sasz + @test @inferred(canonical_size(sisz)) === sasz + + @test @inferred(canonical_axes(axs)) === axs + @test @inferred(canonical_axes(axs1)) === axs1 + @test @inferred(canonical_axes(saaxs)) === saaxs + @test @inferred(canonical_axes(saaxs1)) === saaxs1 + @test @inferred(canonical_axes(siaxs)) === saaxs + @test @inferred(canonical_axes(siaxs1)) === saaxs1 end diff --git a/test/test_basics.jl b/test/test_basics.jl index 7ac29dc1..bd5a409c 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -120,8 +120,9 @@ end end @testset "powers" begin - @test logdensityof(Lebesgue()^3, 2) == logdensityof(Lebesgue()^(3,), 2) - @test logdensityof(Lebesgue()^3, 2) == logdensityof(Lebesgue()^(3, 1), (2, 0)) + @test logdensityof(Lebesgue()^3, [2, 2, 2]) == logdensityof(Lebesgue()^(3,), fill(2, 3)) + @test logdensityof(Lebesgue()^3, fill(2, 3)) == + logdensityof(Lebesgue()^(3, 1), fill(2, 3, 1)) end NormalMeasure() = ∫exp(x -> -0.5x^2, Lebesgue(ℝ))