Skip to content

Extend Static/StaticArrays tooling and adapt PowerMeasure #160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MeasureBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import ConstructionBase
using ConstructionBase: constructorof
using IntervalSets

import StaticArrays
using StaticArrays:
StaticArray, StaticVector, StaticMatrix, SArray, SVector, SMatrix, SOneTo

Expand Down
4 changes: 2 additions & 2 deletions src/combinators/implicitlymapped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,13 @@ struct TakeAny{T<:IntegerLike}
n::T
end

_takeany_range(f::TakeAny, idxs) = first(idxs):first(idxs)+dynamic(f.n)-1
_takeany_range(f::TakeAny, idxs) = first(idxs):(first(idxs)+dynamic(f.n)-1)
@inline _takeany_range(f::TakeAny, ::OneTo) = OneTo(dynamic(f.n))

@inline _takeany_range(::TakeAny{<:Static.StaticInteger{N}}, ::OneTo) where {N} = SOneTo(N)
@inline _takeany_range(::TakeAny{<:Static.StaticInteger{N}}, ::SOneTo) where {N} = SOneTo(N)

@inline (f::TakeAny)(xs::Tuple) = xs[begin:begin+f.n-1]
@inline (f::TakeAny)(xs::Tuple) = xs[begin:(begin+f.n-1)]
@inline (f::TakeAny)(xs::AbstractVector) = xs[_takeany_range(f, eachindex(xs))]

function (f::TakeAny)(xs)
Expand Down
75 changes: 51 additions & 24 deletions src/combinators/power.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,37 @@ the product determines the dimensionality of the resulting support.
Note that power measures are only well-defined for integer powers.

The nth power of a measure μ can be written μ^n.

See also [`pwr_base`](@ref), [`pwr_axes`](@ref) and [`pwr_size`](@ref).
"""
struct PowerMeasure{M,A} <: AbstractProductMeasure
parent::M
axes::A
end

maybestatic_length(μ::PowerMeasure) = prod(maybestatic_size(μ))
maybestatic_size(μ::PowerMeasure) = map(maybestatic_length, μ.axes)
maybestatic_length(μ::PowerMeasure) = size2length(maybestatic_size(μ))
maybestatic_size(μ::PowerMeasure) = axes2size(μ.axes)

"""
MeasureBase.pwr_base(μ::PowerMeasure)

Returns `ν` for `μ = ν^axs`
"""
@inline pwr_base(μ::PowerMeasure) = μ.parent

"""
MeasureBase.pwr_axes(μ::PowerMeasure)

Returns `axs` for `μ = ν^axs`, `axs` being a tuple of integer ranges.
"""
@inline pwr_axes(μ::PowerMeasure) = μ.axes

"""
MeasureBase.pwr_size(μ::PowerMeasure)

Returns `sz` for `μ = ν^sz`, `sz` being a tuple of integers.
"""
@inline pwr_size(μ::PowerMeasure) = axes2size(μ.axes)

function Pretty.tile(μ::PowerMeasure)
sz = length.(μ.axes)
Expand All @@ -30,30 +53,29 @@ end
# ToDo: Make rand return static arrays for statically-sized power measures.

function _cartidxs(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N}
CartesianIndices(map(_dynamic, axs))
CartesianIndices(map(asnonstatic, axs))
end

function Base.rand(
rng::AbstractRNG,
::Type{T},
d::PowerMeasure{M},
) where {T,M<:AbstractMeasure}
map(_cartidxs(d.axes)) do _
rand(rng, T, d.parent)
axs, base_d = pwr_axes(d), pwr_base(d)
map(_cartidxs(axs)) do _
rand(rng, T, base_d)
end
end

function Base.rand(rng::AbstractRNG, ::Type{T}, d::PowerMeasure) where {T}
map(_cartidxs(d.axes)) do _
rand(rng, d.parent)
axs, base_d = pwr_axes(d), pwr_base(d)
map(_cartidxs(axs)) do _
rand(rng, base_d)
end
end

@inline _pm_axes(sz::Tuple{Vararg{IntegerLike,N}}) where {N} = map(one_to, sz)
@inline _pm_axes(axs::Tuple{Vararg{AbstractUnitRange,N}}) where {N} = axs

@inline function powermeasure(x::T, sz::Tuple{Vararg{Any,N}}) where {T,N}
PowerMeasure(x, _pm_axes(sz))
PowerMeasure(x, asaxes(sz))
end

marginals(d::PowerMeasure) = fill_with(d.parent, d.axes)
Expand All @@ -80,23 +102,32 @@ end

for func in [:logdensityof, :logdensity_def]
@eval @inline function $func(d::PowerMeasure{M}, x) where {M}
parent = d.parent
sum(x) do xj
$func(parent, xj)
parent_m = d.parent
sz_parent = axes2size(d.axes)
sz_x = maybestatic_size(x)
if sz_parent != sz_x
throw(ArgumentError("Size of variate doesn't match size of power measure"))
end
R = infer_logdensity_type($func, parent_m, eltype(x))
if isempty(x)
return zero(R)::R
else
# Need to convert since sum can turn static into dynamic values:
return convert(R, sum(Base.Fix1($func, parent_m), x))::R
end
end

@eval @inline function $func(d::PowerMeasure{M,Tuple{Static.SOneTo{N}}}, x) where {M,N}
@eval @inline function $func(d::PowerMeasure{<:Any,Tuple{<:StaticOneToLike}}, x)
parent = d.parent
sum(1:N) do j
@inbounds $func(parent, x[j])
end
end

@eval @inline function $func(
d::PowerMeasure{M,NTuple{N,Static.SOneTo{0}}},
::PowerMeasure{<:Any,<:Tuple{Vararg{StaticOneToLike{0}}}},
x,
) where {M,N}
)
static(0.0)
end
end
Expand All @@ -117,15 +148,11 @@ end
end
end

@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes))

@inline function getdof(::PowerMeasure{<:Any,NTuple{N,Static.SOneTo{0}}}) where {N}
static(0)
end
@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * size2length(axes2size(μ.axes))

@propagate_inbounds function checked_arg(μ::PowerMeasure, x::AbstractArray{<:Any})
@boundscheck begin
sz_μ = map(length, μ.axes)
sz_μ = pwr_size(μ)
sz_x = size(x)
if sz_μ != sz_x
throw(ArgumentError("Size of variate doesn't match size of power measure"))
Expand All @@ -144,7 +171,7 @@ logdensity_def(::PowerMeasure{P}, x) where {P<:PrimitiveMeasure} = static(0.0)

# To avoid ambiguities
function logdensity_def(
::PowerMeasure{P,Tuple{Vararg{Static.SOneTo{0},N}}},
::PowerMeasure{P,<:Tuple{Vararg{StaticOneToLike{0},N}}},
x,
) where {P<:PrimitiveMeasure,N}
static(0.0)
Expand Down
4 changes: 2 additions & 2 deletions src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,13 @@ end
ℓ = logdensity_def(μs[$M], νs[$N], x)
end

for i in 1:M-1
for i in 1:(M-1)
push!(q.args, :(Δℓ = logdensity_def(μs[$i], x)))
# push!(q.args, :(println("Adding", Δℓ)))
push!(q.args, :(ℓ += Δℓ))
end

for j in 1:N-1
for j in 1:(N-1)
push!(q.args, :(Δℓ = logdensity_def(νs[$j], x)))
# push!(q.args, :(println("Subtracting", Δℓ)))
push!(q.args, :(ℓ -= Δℓ))
Expand Down
2 changes: 1 addition & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ function test_smf(μ, n = 100)
@testset "smf($μ)" begin
# Get `n` sorted uniforms in O(n) time
p = rand(n)
p .+= 0:n-1
p .+= 0:(n-1)
p .*= inv(n)

F(x) = smf(μ, x)
Expand Down
6 changes: 3 additions & 3 deletions src/standard/stdmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ end
# Helpers for product transforms and similar:

struct _TransportToStd{NU<:StdMeasure} <: Function end
_TransportToStd{NU}(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x)
(::_TransportToStd{NU})(μ, x) where {NU} = transport_to(NU()^getdof(μ), μ)(x)

struct _TransportFromStd{MU<:StdMeasure} <: Function end
_TransportFromStd{MU}(ν, x) where {MU} = transport_to(ν, MU()^getdof(ν))(x)
Expand All @@ -67,7 +67,7 @@ function _tuple_transport_def(
μs::Tuple,
xs::Tuple,
) where {NU<:StdMeasure}
reshape(vcat(map(_TransportToStd{NU}, μs, xs)...), ν.axes)
reshape(vcat(map(_TransportToStd{NU}(), μs, xs)...), ν.axes)
end

function transport_def(
Expand All @@ -93,7 +93,7 @@ end
function _stdvar_viewranges(μs::Tuple, startidx::IntegerLike)
N = map(getdof, μs)
offs = _offset_cumsum(startidx, N...)
map((o, n) -> o:o+n-1, offs, N)
map((o, n) -> o:(o+n-1), offs, N)
end

function _tuple_transport_def(
Expand Down
Loading
Loading