Description
Current behavior
The MEADS adaptation routine appears to be incomplete. Currently, cross-chain statistics are computed at each iteration and used to update the kernel parameters for all chains at once. This omits some aspects of Algorithm 3 in Hoffman & Sountsov (2022). If this was intentional, I would be interested to know why.
```python
meads = blackjax.meads_adaptation(logdensity_fn, num_chains)
```
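For comparison, the current scheme (one set of parameters shared by every chain, computed from statistics pooled over all chains) can be sketched roughly as follows. This is a simplified NumPy illustration, not BlackJAX's code. The helper names `max_eigenvalue` and `pooled_parameters` are hypothetical, the sketch uses an exact eigendecomposition of the sample covariance where a real implementation may use a cheaper estimator, and it leaves out the damping parameter.

```python
import numpy as np


def max_eigenvalue(samples):
    """Largest eigenvalue of the empirical covariance of `samples` (chains x dims)."""
    centered = samples - samples.mean(axis=0)
    cov = centered.T @ centered / samples.shape[0]
    return np.linalg.eigvalsh(cov)[-1]  # eigvalsh returns eigenvalues in ascending order


def pooled_parameters(positions, grads, step_size_multiplier=0.5):
    """One parameter set shared by every chain, computed from all chains at once.

    positions, grads: arrays of shape (num_chains, dim) holding each chain's
    current state and the gradient of the log density at that state.
    """
    scale = positions.std(axis=0)  # diagonal preconditioner, pooled over all chains
    scaled_grads = grads * scale   # gradients expressed in preconditioned coordinates
    step_size = step_size_multiplier / np.sqrt(max_eigenvalue(scaled_grads))
    return scale, step_size


# Toy target: N(0, diag(sigma^2)), whose log-density gradient is -x / sigma^2.
rng = np.random.default_rng(0)
sigma = np.array([1.0, 2.0, 4.0])
positions = rng.normal(size=(64, 3)) * sigma
grads = -positions / sigma**2
scale, step_size = pooled_parameters(positions, grads)
```

Note that every chain's own state contributes to the statistics that then set the kernel used to move it.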
Desired behavior
Hoffman & Sountsov (2022) describe an algorithm where the chains are split into folds (`num_folds`), with the option of shuffling chains between folds (`shuffle`).
I propose updating the existing implementation to add the missing options as arguments to `blackjax.meads_adaptation`, with default values taken from the paper:
```python
meads = blackjax.meads_adaptation(logdensity_fn, num_chains, num_folds=4, shuffle=True, step_size_multiplier=0.5, damping_slowdown=1.0)
```
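A minimal sketch of how the fold logic might look. It assumes that each fold's parameters are estimated only from the chains in the *other* folds, and that `shuffle` re-draws the chain-to-fold assignment. The functions `fold_assignment` and `leave_fold_out_scales` are hypothetical, only the diagonal scale is shown, and the exact schedule in Algorithm 3 (which fold is refreshed at which iteration, how often chains are reshuffled) may differ from this.

```python
import numpy as np


def fold_assignment(num_chains, num_folds, shuffle, rng):
    """Split chain indices into `num_folds` groups, optionally after a random permutation."""
    order = rng.permutation(num_chains) if shuffle else np.arange(num_chains)
    return np.array_split(order, num_folds)


def leave_fold_out_scales(positions, folds):
    """Per-chain diagonal scales where each fold only sees the other folds.

    positions: (num_chains, dim). Returns an array of the same shape whose
    row i is the preconditioner used for chain i.
    """
    scales = np.empty_like(positions)
    for k, fold in enumerate(folds):
        # Pool every chain that is NOT in fold k.
        others = np.concatenate([f for j, f in enumerate(folds) if j != k])
        scales[fold] = positions[others].std(axis=0)  # broadcast to all chains in fold k
    return scales


rng = np.random.default_rng(1)
positions = rng.normal(size=(16, 2)) * np.array([1.0, 3.0])
folds = fold_assignment(16, num_folds=4, shuffle=True, rng=rng)
scales = leave_fold_out_scales(positions, folds)
```

The same leave-fold-out pattern would apply to the step size and damping estimates. In a JAX implementation, the Python loop over folds would likely be replaced by a vectorized construct.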
The `step_size_multiplier` and `damping_slowdown` arguments are hyperparameters used when computing the step size and damping from the MEADS cross-chain statistics.
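To illustrate where the two hyperparameters could enter, here is a hedged NumPy sketch. The functional forms are assumptions for illustration only and should be checked against the paper: the step size is taken as `step_size_multiplier` divided by the square root of the largest eigenvalue of the scaled-gradient covariance, and the damping is floored by `damping_slowdown / ((iteration + 1) * step_size)`. The function name `meads_step_and_damping` is hypothetical.

```python
import numpy as np


def max_eigenvalue(samples):
    """Largest eigenvalue of the empirical covariance of `samples` (chains x dims)."""
    centered = samples - samples.mean(axis=0)
    cov = centered.T @ centered / samples.shape[0]
    return np.linalg.eigvalsh(cov)[-1]


def meads_step_and_damping(positions, grads, iteration,
                           step_size_multiplier=0.5, damping_slowdown=1.0):
    """Hypothetical placement of the two hyperparameters (assumed forms)."""
    scale = positions.std(axis=0)
    normalized = (positions - positions.mean(axis=0)) / scale
    # Assumed form: step size shrinks with the largest curvature in the scaled gradients.
    step_size = step_size_multiplier / np.sqrt(max_eigenvalue(grads * scale))
    # Assumed form: damping tracks the widest direction of the normalized positions,
    # floored by a schedule that decays with the iteration; damping_slowdown rescales it.
    damping = max(
        1.0 / np.sqrt(max_eigenvalue(normalized)),
        damping_slowdown / ((iteration + 1) * step_size),
    )
    return step_size, damping


rng = np.random.default_rng(2)
sigma = np.array([1.0, 2.0, 4.0])
positions = rng.normal(size=(64, 3)) * sigma
grads = -positions / sigma**2  # gradient of the log density of N(0, diag(sigma^2))
step_size, damping = meads_step_and_damping(positions, grads, iteration=10)
```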