Closed
Description
Drawing a variable from a Dirichlet distribution introduces a type instability, slowing down sampling dramatically (around 25x on my laptop):
```julia
using Turing
using InteractiveUtils  # provides @code_warntype outside the REPL

@model MarginalizedGMM(x, K, ::Type{T}=Vector{Float64}) where {T} = begin
    N = length(x)
    μ = T(undef, K)
    σ = T(undef, K)
    for i in 1:K
        μ[i] ~ Normal(0, 5)
        σ[i] ~ Gamma()
    end
    w ~ Dirichlet(K, 1.0)
    # w = T([0.75, 0.25])  # way faster with this line instead of the one above
    for i in 1:N
        x[i] ~ Distributions.UnivariateGMM(μ, σ, Categorical(w))
    end
    return (μ::T, σ::T, w::T)
end

x = [randn(150) .- 2; randn(50) .+ 2]
gmm = MarginalizedGMM(x, 2)
varinfo = Turing.VarInfo(gmm)
spl = Turing.SampleFromPrior()
@code_warntype gmm.f(varinfo, spl, Turing.DefaultContext(), gmm)
```
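The report inspects the instability with `@code_warntype`. The same property can also be checked programmatically, which is handy as a regression test. The following is a minimal, Turing-free sketch of the general mechanism only: `unstable_weights` and `stable_weights` are hypothetical stand-ins (they do not reproduce the Dirichlet code path), showing how a value whose concrete type cannot be inferred from argument types alone surfaces as a non-concrete `Union`, and how `Test.@inferred` flags it.

```julia
using Test

# Hypothetical stand-in for a draw whose concrete type cannot be predicted from
# the argument types alone: it returns either a Vector{Float64} or a Vector{Real}.
unstable_weights(normalized::Bool) = normalized ? [0.75, 0.25] : Real[3, 1]

# Hypothetical stand-in that always returns a Vector{Float64},
# analogous to hard-coding `w = T([0.75, 0.25])` in the model.
stable_weights(::Bool) = [0.75, 0.25]

# Inference works from argument types only, so it has to account for both branches.
@show Base.return_types(unstable_weights, (Bool,))  # Union of the two vector types
@show Base.return_types(stable_weights, (Bool,))    # Vector{Float64}

# @inferred throws if the runtime type differs from the inferred type,
# so it can guard against this kind of instability in a test suite.
@test @inferred(stable_weights(true)) == [0.75, 0.25]
@test_throws ErrorException @inferred(unstable_weights(true))
```

Code that consumes a non-concretely typed value (here, building `Categorical(w)` and the mixture inside the loop) typically falls back to dynamic dispatch, which is consistent with the slowdown described above.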