- 
                Notifications
    You must be signed in to change notification settings 
- Fork 2.1k
Blackjax sampler fix for breaking change / enable progress bar under parallel chain_method #7453
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| 
 It's enough if we're compatible with the latest blackjax releases. We can raise a runtime informative error if we know the installed version of blackjax is too old to work, directing users to update it | 
| 
 | 
| Thanks for the new release Looks like it needs to get into conda first | 
| Codecov ReportAll modified and coverable lines are covered by tests ✅ 
 Additional details and impacted files@@            Coverage Diff             @@
##             main    #7453      +/-   ##
==========================================
- Coverage   92.20%   92.17%   -0.03%     
==========================================
  Files         103      103              
  Lines       17301    17258      -43     
==========================================
- Hits        15952    15908      -44     
- Misses       1349     1350       +1     
 | 
…parallel chain_method (pymc-devs#7453) * remove blackjax pmap warning * use gen_scan_fn * remove labels * retrigger checks * retrigger checks
blackjax-devs/blackjax#712 changes the expected jax.lax.scan carry in
progress_bar_scan. Since pymc's external blackjax sampler directly usesprogress_bar_scanit will break whenprogressbar=True. This change switches to use a new wrapper to hide the progress bar details. In addition it enables the use of progress bars underchain_method="parallel".I think any breaking issues can be handled by restricting blackjax version numbers. However, I'm not sure how to properly do that?
And of course for now tests are expected to fail until the changes show in a blackjax release.
PRs that are dependencies:
blackjax-devs/blackjax#712
blackjax-devs/blackjax#716
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7453.org.readthedocs.build/en/7453/