Move `predict` from Turing #716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
@@ -1203,6 +1203,39 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
    end
end

"""
    predict([rng::AbstractRNG,] model::Model, chain; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
"""
function predict(model::Model, chain; include_all=false)
    return predict(Random.default_rng(), model, chain; include_all)
end

function predict(
    rng::Random.AbstractRNG,
    model::Model,
    varinfos::AbstractArray{<:AbstractVarInfo};
    include_all=false,
)
    predictive_samples = similar(varinfos, OrderedDict{Symbol,Any})

    for i in eachindex(varinfos)
        model(rng, varinfos[i], SampleFromPrior())
        vals = values_as_in_model(model, varinfos[i])
        iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
        params = mapreduce(collect, vcat, iters)
        predictive_samples[i] = OrderedDict(
            :values => params, :logp => getlogp(varinfos[i])
        )
    end
    return predictive_samples
end

"""
    generated_quantities(model::Model, parameters::NamedTuple)
    generated_quantities(model::Model, values, keys)
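For readers skimming the diff, here is a minimal usage sketch of the new API (not part of the PR). It assumes the `Chains`-accepting `predict` method added elsewhere in this PR; the `demo` model and the draw values are hypothetical:

```julia
using DynamicPPL, Distributions, MCMCChains, Random

@model function demo(x, y)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], 0.1)
    end
end

# Hypothetical posterior draws of β: 100 iterations × 1 parameter × 1 chain.
chain = MCMCChains.Chains(randn(100, 1, 1), [:β])

# Mark `y` as missing so `predict` draws it from the posterior predictive.
m_test = demo([1.0, 2.0], fill(missing, 2))
predictions = DynamicPPL.predict(Random.default_rng(), m_test, chain)

# With `include_all=true` the returned chain also keeps the fixed variables (here β).
predictions_all = DynamicPPL.predict(m_test, chain; include_all=true)
```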
@@ -2,6 +2,7 @@
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"

@@ -32,6 +33,7 @@ AbstractMCMC = "5"
 AbstractPPL = "0.8.4, 0.9"
 Accessors = "0.1"
 Bijectors = "0.13.9, 0.14"
+AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
 Combinatorics = "1"
 Compat = "4.3.0"
 Distributions = "0.25"
@@ -7,3 +7,170 @@
    @test size(chain_generated) == (1000, 1)
    @test mean(chain_generated) ≈ 0 atol = 0.1
end

@testset "predict" begin
    DynamicPPL.Random.seed!(100)

    @model function linear_reg(x, y, σ=0.1)
        β ~ Normal(0, 1)
        for i in eachindex(y)
            y[i] ~ Normal(β * x[i], σ)
        end
    end

    @model function linear_reg_vec(x, y, σ=0.1)
        β ~ Normal(0, 1)
        return y ~ MvNormal(β .* x, σ^2 * I)
    end

    f(x) = 2 * x + 0.1 * randn()

    Δ = 0.1
    xs_train = 0:Δ:10
    ys_train = f.(xs_train)
    xs_test = [10 + Δ, 10 + 2 * Δ]
    ys_test = f.(xs_test)

    # Infer
    m_lin_reg = linear_reg(xs_train, ys_train)
    chain_lin_reg = sample(
        DynamicPPL.LogDensityFunction(m_lin_reg),
        AdvancedHMC.NUTS(0.65),
        1000;
        chain_type=MCMCChains.Chains,
        param_names=[:β],
        discard_initial=100,
        n_adapt=100,
    )

    # Predict on the last two indices
    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))

    # A test like this depends on the variance of the posterior;
    # it only makes sense if the posterior variance is about 0.002.
    @test sum(abs2, ys_test - ys_pred) ≤ 0.1

    # Ensure that `rng` is respected
    predictions1 = let rng = MersenneTwister(42)
        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
    end
    predictions2 = let rng = MersenneTwister(42)
        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
    end
    @test all(Array(predictions1) .== Array(predictions2))

    # Predict on the last two indices for the vectorized model
    m_lin_reg_test = linear_reg_vec(xs_test, missing)
    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
    ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

    @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1

    # Multiple chains
    chain_lin_reg = sample(
        DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
        AdvancedHMC.NUTS(0.65),
        MCMCThreads(),
        1000,
        2;
        chain_type=MCMCChains.Chains,
        param_names=[:β],
        discard_initial=100,
        n_adapt=100,
    )
    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    @test size(chain_lin_reg, 3) == size(predictions, 3)

    for chain_idx in MCMCChains.chains(chain_lin_reg)
        ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
        @test sum(abs2, ys_test - ys_pred) ≤ 0.1
    end

    # Predict on the last two indices for the vectorized model
    m_lin_reg_test = linear_reg_vec(xs_test, missing)
    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    for chain_idx in MCMCChains.chains(chain_lin_reg)
        ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
        @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
    end

    # https://github.com/TuringLang/Turing.jl/issues/1352
    @model function simple_linear1(x, y)
        intercept ~ Normal(0, 1)
        coef ~ MvNormal(zeros(2), I)
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear2(x, y)
        intercept ~ Normal(0, 1)
        coef ~ filldist(Normal(0, 1), 2)
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear3(x, y)
        intercept ~ Normal(0, 1)
        coef = Vector(undef, 2)
        for i in axes(coef, 1)
            coef[i] ~ Normal(0, 1)
        end
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear4(x, y)
        intercept ~ Normal(0, 1)
        coef1 ~ Normal(0, 1)
        coef2 ~ Normal(0, 1)
        coef = [coef1, coef2]
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    x = randn(2, 100)
    y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

    param_names = Dict(
        simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear4 => [:intercept, :coef1, :coef2, :error],
    )
    @testset "$model" for model in
                          [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
        m = model(x, y)
        chain = sample(
            DynamicPPL.LogDensityFunction(m),
            AdvancedHMC.NUTS(0.65),
            400;
            initial_params=rand(4),
            chain_type=MCMCChains.Chains,
            param_names=param_names[model],
            discard_initial=100,
            n_adapt=100,
        )
        chain_predict = DynamicPPL.predict(model(x, missing), chain)
        mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
        @test mean(abs2, mean_prediction - y) ≤ 1e-3
    end
end
Same here: no need to use AdvancedHMC (or any of the other packages), just construct the `Chains` by hand. This also doesn't actually show that you need to import MCMCChains for this to work, which might be a good idea.
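For concreteness, a sketch of that suggestion, assuming the standard `MCMCChains.Chains(values, names)` constructor; the draws below are fabricated around the true slope just to exercise `predict` without any sampler package:

```julia
using DynamicPPL, Distributions, MCMCChains

@model function linear_reg(x, y, σ=0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

# Hand-constructed "posterior": 1000 draws of β concentrated near the true
# slope 2.0, so the test needs no AdvancedHMC (or any other sampler package).
vals = reshape(2.0 .+ 0.01 .* randn(1000), 1000, 1, 1)
chain = MCMCChains.Chains(vals, [:β])

m_test = linear_reg([10.1, 10.2], fill(missing, 2))
predictions = DynamicPPL.predict(m_test, chain)
```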