@penelopeysm
Member

@penelopeysm penelopeysm commented Oct 23, 2025

This PR adds a new struct ParamsWithStats and functions to_chains and from_chains, which are mainly meant for developers of packages that share an interface with DynamicPPL.

The main purpose of these functions is to abstract away the inner details of chain construction so that they don't have to be duplicated everywhere. For example, there are at least four different places that feature the 'split-up-dicts-of-varnames' game for MCMCChains:

(1) AbstractMCMC.bundle_samples https://github.com/TuringLang/Turing.jl/blob/0eb8576c2c1f659aafdc1a22fc6396e0b1588a67/src/mcmc/Inference.jl#L311-L312

(2) DynamicPPL.predict

function _predictive_samples_to_arrays(predictive_samples)
    variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()
    sample_dicts = map(predictive_samples) do sample
        varname_value_pairs = sample.varname_and_values
        varnames = map(first, varname_value_pairs)
        values = map(last, varname_value_pairs)
        for varname in varnames
            push!(variable_names_set, varname)
        end
        return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
    end
    variable_names = collect(variable_names_set)
    variable_values = [
        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
        key in variable_names
    ]
    return variable_names, variable_values
end

(3) This DynamicPPL test utility

function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int)
    # Sample from the prior
    varinfos = [VarInfo(rng, model) for _ in 1:n_iters]
    # Extract all varnames found in any dictionary. Doing it this way guards
    # against the possibility of having different varnames in different
    # dictionaries, e.g. for models that have dynamic variables / array sizes
    varnames = OrderedSet{VarName}()
    # Convert each varinfo into an OrderedDict of vns => params.
    # We have to use varname_and_value_leaves so that each parameter is a scalar
    dicts = map(varinfos) do t
        vals = DynamicPPL.values_as(t, OrderedDict)
        iters = map(AbstractPPL.varname_and_value_leaves, keys(vals), values(vals))
        tuples = mapreduce(collect, vcat, iters)
        # The following loop is a replacement for:
        #     push!(varnames, map(first, tuples)...)
        # which causes a stack overflow if `map(first, tuples)` is too large.
        # Unfortunately there isn't a union() function for OrderedSet.
        for vn in map(first, tuples)
            push!(varnames, vn)
        end
        OrderedDict(tuples)
    end
    # Convert back to list
    varnames = collect(varnames)
    # Construct matrix of values
    vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
    # Construct dict of varnames -> symbol
    vn_to_sym_dict = Dict(zip(varnames, map(Symbol, varnames)))
    # Construct and return the Chains object
    return Chains(vals, varnames; info=(; varname_to_symbol=vn_to_sym_dict))
end

(4) Pathfinder.pathfinder https://github.com/mlcolab/Pathfinder.jl/blob/6389f125197110ff35ccddc10ed682e4b9ff8c12/ext/PathfinderTuringExt.jl#L49
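
For readers who haven't seen it, the pattern that all four places repeat can be sketched in plain Julia roughly as follows. This is only an illustration: `samples_to_matrix` is a hypothetical helper, it uses `Symbol` keys where the real code uses `VarName`, and it is not how `to_chains` is implemented in this PR.

```julia
# Hypothetical helper, not part of DynamicPPL or MCMCChains.
# Each sample is a vector of name => value pairs; the set of names may
# differ between samples (e.g. models with dynamic array sizes).
function samples_to_matrix(samples)
    colnames = Symbol[]      # union of names, in order of first appearance
    seen = Set{Symbol}()
    for sample in samples, (name, _) in sample
        if !(name in seen)
            push!(seen, name)
            push!(colnames, name)
        end
    end
    dicts = [Dict(sample) for sample in samples]
    # Rows are samples, columns are names; absent entries become `missing`.
    mat = [get(d, name, missing) for d in dicts, name in colnames]
    return colnames, mat
end

samples = [
    [:mu => 0.1, Symbol("x[1]") => 1.0],
    [:mu => 0.2, Symbol("x[1]") => 1.5, Symbol("x[2]") => 2.5],
]
colnames, mat = samples_to_matrix(samples)
```

The first sample has no `x[2]`, so its entry in the third column is `missing`. This is the same guard against differing varnames that the snippets above implement by hand.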

Another benefit is that certain details, like the varname_to_symbol Dict that is stored with the chain, are implemented at the same level at which they are used.


The eagle-eyed will notice that ParamsWithStats is effectively the same as Turing.Inference.Transition, just without the logp terms explicitly bundled in.

Furthermore, to_chains in the MCMCChainsExt is almost completely the same as bundle_samples in Turing (although perhaps implemented in a slightly simpler way).

I did it this way because I want Turing to be able to make use of this function. In an original draft I had to_chains take an array of VarInfo, and then perform the reevaluation. However, this makes it quite complicated to use this in the MCMC sampling bits of Turing.
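
To make the shape described above concrete, here is a minimal sketch of a "parameters plus stats" container and how one sample might be flattened before chain construction. `SketchParamsWithStats` and `flatten_sample` are hypothetical names for illustration only. The actual ParamsWithStats in this PR may have different fields, type constraints, and constructors.

```julia
# Hypothetical stand-in; NOT the definition used in DynamicPPL.
struct SketchParamsWithStats{P<:AbstractDict,S<:NamedTuple}
    params::P   # parameter name => value
    stats::S    # extra per-sample statistics, e.g. log-densities or step sizes
end

# Flatten one sample into name => value pairs: parameters first, then stats.
function flatten_sample(s::SketchParamsWithStats)
    param_pairs = [Symbol(k) => v for (k, v) in s.params]
    stat_pairs = [k => v for (k, v) in pairs(s.stats)]
    return vcat(param_pairs, stat_pairs)
end

s = SketchParamsWithStats(Dict(:mu => 0.3), (lp = -1.2, step_size = 0.05))
flat = flatten_sample(s)
```

Because the container already holds plain values, a function that consumes a vector of such objects does not need to reevaluate the model, which is the property that makes it easier to call from the MCMC sampling code.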

@github-actions
Contributor

github-actions bot commented Oct 23, 2025

Benchmark Report for Commit 2fecc74

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────┬─────────────────┐
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │ t(eval)/t(ref) │ t(grad)/t(eval) │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼────────────────┼─────────────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │            6.6 │             1.7 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │          741.8 │            44.6 │
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │          427.4 │            54.3 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │          791.2 │            36.1 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │         7067.5 │            25.5 │
│           Smorgasbord │   201 │ reversediff │             typed │   true │          760.5 │            54.9 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │          760.3 │             6.0 │
│           Smorgasbord │   201 │      enzyme │             typed │   true │          919.2 │             3.8 │
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │         3969.9 │             5.8 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │         1022.2 │             8.9 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │        43894.2 │             5.5 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │         9053.2 │             9.8 │
│               Dynamic │    10 │    mooncake │             typed │   true │          120.8 │            12.2 │
│              Submodel │     1 │    mooncake │             typed │   true │            8.7 │             6.7 │
│                   LDA │    12 │ reversediff │             typed │   true │         1021.3 │             2.0 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴────────────────┴─────────────────┘

@codecov

codecov bot commented Oct 23, 2025

Codecov Report

❌ Patch coverage is 96.29630% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.38%. Comparing base (9a2607b) to head (2fecc74).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/to_chains.jl | 92.59% | 2 Missing ⚠️ |
| ext/DynamicPPLMCMCChainsExt.jl | 98.14% | 1 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1087      +/-   ##
==========================================
+ Coverage   81.06%   81.38%   +0.32%     
==========================================
  Files          40       41       +1     
  Lines        3749     3798      +49     
==========================================
+ Hits         3039     3091      +52     
+ Misses        710      707       -3     

@penelopeysm penelopeysm force-pushed the py/varinfos_to_chains branch from 96df1a4 to 70c3dd9 Compare October 23, 2025 18:36
@penelopeysm penelopeysm marked this pull request as ready for review October 23, 2025 18:40
@penelopeysm penelopeysm force-pushed the py/varinfos_to_chains branch from 70c3dd9 to 7049125 Compare October 23, 2025 18:47
@github-actions
Contributor

DynamicPPL.jl documentation for PR #1087 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1087/

@penelopeysm penelopeysm changed the title Add varinfos_to_chains function Add to_chains function Oct 23, 2025
@penelopeysm penelopeysm force-pushed the py/varinfos_to_chains branch from 7b337bd to 833cbbf Compare October 23, 2025 19:38
@penelopeysm
Member Author

CI failures are because of #1081.

@sethaxen Tagging you too since this came up via Pathfinder!

@penelopeysm penelopeysm requested a review from sunxd3 October 23, 2025 19:58
@sethaxen
Member

> @sethaxen Tagging you too since this came up via Pathfinder!

Seems to work like a charm for Pathfinder! mlcolab/Pathfinder.jl#274

@sethaxen
Member

I wonder: for other packages implementing new chain_types, is there an inverse of to_chains that could be used, e.g., for predict or pointwise_loglikelihoods? Or would something like from_chains, which takes an object of type chain_type and a Model and returns an array of ParamsWithStats, also be useful to have?

@penelopeysm
Member Author

penelopeysm commented Oct 24, 2025

> an inverse of to_chains

So there are a couple of potential stages: Chain ----> Dict{VarName,Any} --[model]--> VarInfo.
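
As a rough illustration of the first stage only (chain-like data to per-sample dictionaries), here is a sketch in plain Julia. `matrix_to_dicts` is a hypothetical helper, the chain is stood in for by a names vector and a samples-by-names matrix, and `Symbol` keys replace `VarName`. The second stage, which needs the model to build a VarInfo, is not shown.

```julia
# Hypothetical helper: turn a samples-by-names matrix back into one Dict per
# sample, dropping entries that are `missing` for that sample.
function matrix_to_dicts(colnames, mat)
    return [
        Dict(colnames[j] => mat[i, j] for j in axes(mat, 2) if !ismissing(mat[i, j]))
        for i in axes(mat, 1)
    ]
end

colnames = [:mu, Symbol("x[1]"), Symbol("x[2]")]
mat = [0.1 1.0 missing; 0.2 1.5 2.5]
dicts = matrix_to_dicts(colnames, mat)
```

Dropping `missing` entries means each dictionary contains only the variables that were actually present in that iteration, which matters for models whose set of variables changes between samples.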

@penelopeysm penelopeysm changed the title Add to_chains function Add to_chains and from_chains function Oct 24, 2025