Skip to content

Fix inner kernel tuning initialization for compatibility with persistent SMC#810

Open
thomasckng wants to merge 2 commits intoblackjax-devs:mainfrom
thomasckng:fix-inner-kernel-tuning-with-persistent-smc
Open

Fix inner kernel tuning initialization for compatibility with persistent SMC#810
thomasckng wants to merge 2 commits intoblackjax-devs:mainfrom
thomasckng:fix-inner-kernel-tuning-with-persistent-smc

Conversation

@thomasckng
Copy link

Summary

Fixes the as_top_level_api function in inner_kernel_tuning.py to properly initialize the SMC algorithm with all required parameters before accessing its .init method. This ensures compatibility with the persistent SMC algorithm (#799).

Changes

  • Modified init_fn in as_top_level_api to instantiate the smc_algorithm with all parameters (logprior_fn, loglikelihood_fn, mcmc_step_fn, mcmc_init_fn, mcmc_parameters, resampling_fn, num_mcmc_steps, and extra parameters) before calling .init
  • This mirrors the pattern already used in the build_kernel function's step_fn

Motivation

The previous implementation passed smc_algorithm.init directly without initializing the algorithm object first. This worked for some SMC variants but breaks with persistent SMC, which requires the algorithm to be fully instantiated with its parameters before the init function can be accessed.

Related Issues

Test

  • Existing tests in tests/smc/test_inner_kernel_tuning.py cover this functionality

@junpenglao
Copy link
Member

Thank! could you fix the test?

@thomasckng
Copy link
Author

@junpenglao Thank you for the review. The current test failures are caused by JAX v0.9.0+ issuing a DeprecationWarning for the jax_pmap_shmap_merge setting, and because pytest.ini is configured with filterwarnings = error, this results in all tests failing.

Since addressing this warning (e.g., by adjusting the test configuration) is outside the scope of this change, I haven’t included a fix for it here. Please let me know whether you’d prefer to handle that fix in another PR, or if you’d like me to add it to this PR.

@junpenglao
Copy link
Member

Ah i see, let me fix it upstream.

@junpenglao
Copy link
Member

update: saw that optax has removed it google-deepmind/optax@df03db8, i will just wait for them to cut a new release

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants