
Current MEADS implementation is incomplete #781

@alexlyttle

Description


Current behavior

The MEADS adaptation routine appears to be incomplete. Currently, cross-chain statistics are computed over all chains at every iteration and used to update the kernel parameters shared by every chain. This omits parts of Algorithm 3 from Hoffman & Sountsov (2022). Perhaps this was intentional, in which case I would be interested to know why.

```python
import blackjax
meads = blackjax.meads_adaptation(logdensity_fn, num_chains)
```

Desired behavior

Hoffman & Sountsov (2022) describe an algorithm in which the chains are split into $K$ folds. Cross-chain statistics are computed within each fold and used to update the neighbouring fold at each iteration, skipping the fold whose index equals the current iteration modulo $K$. The algorithm also shuffles all chains every $K$ steps. It appears the original author implemented the algorithm in this notebook. I have recently experimented with modifying the BlackJAX MEADS implementation to reflect this, for a project testing new MCMC adaptation algorithms.
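A minimal NumPy sketch of the fold schedule, under my own assumptions: `num_chains` is divisible by $K$, fold $k$ is updated with statistics from fold $(k-1) \bmod K$, and the shuffle happens whenever the iteration is a positive multiple of $K$. The helper names (`split_into_folds`, `fold_update_plan`, `maybe_shuffle`) are hypothetical and not part of BlackJAX, and a per-dimension standard deviation stands in for the full MEADS statistics.

```python
import numpy as np


def split_into_folds(positions, num_folds):
    """Reshape (num_chains, dim) -> (num_folds, chains_per_fold, dim)."""
    num_chains = positions.shape[0]
    if num_chains % num_folds != 0:
        raise ValueError("num_chains must be divisible by num_folds")
    return positions.reshape(num_folds, num_chains // num_folds, *positions.shape[1:])


def fold_update_plan(iteration, num_folds):
    """List (target_fold, source_fold) pairs for one iteration.

    Assumed convention: fold k is updated with statistics computed on fold
    (k - 1) mod K, and fold (iteration mod K) is left untouched.
    """
    skipped = iteration % num_folds
    return [(k, (k - 1) % num_folds) for k in range(num_folds) if k != skipped]


def maybe_shuffle(positions, iteration, num_folds, rng):
    """Randomly permute chains across folds every num_folds iterations."""
    if iteration > 0 and iteration % num_folds == 0:
        return positions[rng.permutation(positions.shape[0])]
    return positions


rng = np.random.default_rng(0)
positions = rng.normal(size=(8, 3))  # 8 chains, 3 dimensions
num_folds = 4
for t in range(num_folds + 1):
    positions = maybe_shuffle(positions, t, num_folds, rng)
    folds = split_into_folds(positions, num_folds)
    for target, source in fold_update_plan(t, num_folds):
        # Statistics from the source fold would set the kernel parameters
        # for the target fold; a per-dimension std stands in for them here.
        scale = folds[source].std(axis=0)
        # ... hypothetical GHMC step on folds[target] using `scale` ...
```

Because each fold only ever uses statistics from a different fold, the parameters applied to a chain do not depend on that chain's own current state, which is what the fold structure is meant to guarantee.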

I propose updating the existing implementation to include the $K$-fold splitting and shuffling described in the paper. This would add a few parameters to `blackjax.meads_adaptation`, whose default values could be taken from the paper.

```python
meads = blackjax.meads_adaptation(logdensity_fn, num_chains, num_folds=4, shuffle=True, step_size_multiplier=0.5, damping_slowdown=1.0)
```

The `step_size_multiplier` and `damping_slowdown` arguments are hyper-parameters used when turning the cross-chain statistics into the step size and damping of the kernel.
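A NumPy sketch of how the two hyper-parameters might enter the per-fold parameter computation. It is modelled on my reading of the current BlackJAX MEADS code (step size from an estimate of the largest eigenvalue of the scaled gradients, damping floored by $1/((t+1)\epsilon)$) and uses flat `(chains, dim)` arrays instead of pytrees. Where `damping_slowdown` enters is an assumption, chosen so that `1.0` reproduces the existing schedule; the function names are hypothetical.

```python
import numpy as np


def max_eigenvalue_estimate(x):
    """Estimate tr(S^2) / tr(S) for S = E[x x^T], one row of x per chain.

    This lies between the mean and the largest eigenvalue of S and approaches
    the largest eigenvalue when one direction dominates. tr(S) comes from
    squared norms; tr(S^2) from squared inner products of *distinct* chains.
    """
    n = x.shape[0]
    gram = x @ x.T
    diag = np.diag(gram)
    trace_s = diag.sum() / n
    trace_s_sq = (np.sum(gram**2) - np.sum(diag**2)) / (n * (n - 1))
    return trace_s_sq / trace_s


def meads_parameters(positions, grads, iteration,
                     step_size_multiplier=0.5, damping_slowdown=1.0):
    """Compute GHMC parameters from one fold's positions and log-density gradients.

    positions, grads: arrays of shape (chains_per_fold, dim).
    """
    mean = positions.mean(axis=0)
    sd = positions.std(axis=0)  # per-dimension scale, used as preconditioner
    normalized = (positions - mean) / sd
    scaled_grads = grads * sd
    step_size = min(
        step_size_multiplier / np.sqrt(max_eigenvalue_estimate(scaled_grads)), 1.0
    )
    # Assumption: damping_slowdown rescales the 1 / ((t + 1) * step_size) floor,
    # so damping_slowdown=1.0 keeps the current schedule.
    damping = max(
        1.0 / np.sqrt(max_eigenvalue_estimate(normalized)),
        1.0 / (damping_slowdown * (iteration + 1) * step_size),
    )
    alpha = 1.0 - np.exp(-2.0 * step_size * damping)  # momentum refreshment
    delta = alpha / 2  # slice-variable translation per step
    return step_size, sd, alpha, delta
```

For a standard normal target (`grads = -positions`), the eigenvalue estimate is close to 1, so the returned step size should be roughly `step_size_multiplier`. A larger `damping_slowdown` lowers the damping floor and therefore the momentum refreshment early in adaptation.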
