Skip to content

Fix ESS-target comparison in adaptive persistent sampling solver#811

Open
thomasckng wants to merge 2 commits intoblackjax-devs:mainfrom
thomasckng:fix-adaptive-smc
Open

Fix ESS-target comparison in adaptive persistent sampling solver#811
thomasckng wants to merge 2 commits intoblackjax-devs:mainfrom
thomasckng:fix-adaptive-smc

Conversation

@thomasckng
Copy link

The root-finding function in adaptive_persistent_sampling.calculate_lambda compared ESS and target in mismatched spaces. compute_persistent_ess returns ESS in linear space, but target_val is computed as jnp.log(n_particles * target_ess) (log space). The original code subtracted the log-space target directly from the linear-space ESS:

return ess_val - target_val

This meant the solver was finding a root of ESS - log(N * target) instead of log(ESS) - log(N * target), causing incorrect tempering schedules. The fix wraps ess_val in jnp.log() so both sides are in log space:

return jnp.log(ess_val) - target_val

This is consistent with how adaptive_tempered_smc handles the same computation via log_ess() in blackjax.smc.ess.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant