@yallup yallup commented Nov 11, 2024

A few important guidelines and requirements before we can merge your PR:

  • If I add a new sampler, there is an issue discussing it already; Nested Sampling implementation #753
  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;
  • If I add a new sampler, I added/updated related examples

Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.

High level description of changes

The following files are included in the folder ns:

  • base: The base nested sampler, detailed further below; its structure should be somewhat familiar from SMC. Nested sampling acts as an outer kernel: a delete function (the analogue of resampling) removes the lowest-likelihood points, and a vectorized update is then mapped over those “deleted” particles to replace them with new particles subject to the hard likelihood constraint.

  • adaptive: Similar to the SMC inner kernel tuning, wraps the base sampler with a parameter update function to tune the inner kernel parameters.

  • utils: Useful calculations, particularly for extracting log_weights and samples of the target weighted at a specified inverse temperature.

  • vectorized_slice: An inner kernel compatible with the nested sampling kernels. This is non-standard for the rest of the library, so opinions on how best to do this are most welcome. We tried to follow the SMC design of a flexible choice of inner kernel, but currently only this one works. It explicitly takes both the prior logdensity and the loglikelihood as functions, as we would think about them in nested sampling. I suspect there is a clever way to lift this into the mcmc folder and overload the extra loglikelihood condition for use in nested sampling, but for now we have a practical implementation that works for our purpose. Unlike the mh kernels, it does not currently take a proposal distribution, which would allow a more flexible definition of a random slice direction; instead it hardcodes a direction derived from a covariance.
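
To make the delete-then-replace structure of the base kernel concrete, here is a minimal pure-Python sketch. It is not the code in the ns folder: the function names, the toy problem (unnormalized Gaussian likelihood, uniform prior on [-3, 3]) and the brute-force rejection step standing in for the inner kernel are all assumptions made for illustration.

```python
import math
import random


def log_likelihood(x):
    # Unnormalized Gaussian likelihood (toy choice for this sketch).
    return -0.5 * x * x


def sample_prior(rng):
    # Uniform prior on [-3, 3] (toy choice for this sketch).
    return rng.uniform(-3.0, 3.0)


def logaddexp(a, b):
    # Numerically stable log(exp(a) + exp(b)).
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))


def nested_sampling(n_live=100, n_delete=5, n_iter=120, seed=0):
    rng = random.Random(seed)
    live = [sample_prior(rng) for _ in range(n_live)]
    log_z = -math.inf  # running log-evidence
    log_x = 0.0  # log of the remaining prior volume
    for _ in range(n_iter):
        # "Delete" step: remove the n_delete lowest-likelihood particles.
        live.sort(key=log_likelihood)
        dead, live = live[:n_delete], live[n_delete:]
        for i, x in enumerate(dead):
            # Expected log-volume shrinkage while (n_live - i) points are live.
            new_log_x = log_x - 1.0 / (n_live - i)
            log_w = log_x + math.log1p(-math.exp(new_log_x - log_x))
            log_z = logaddexp(log_z, log_w + log_likelihood(x))
            log_x = new_log_x
        threshold = log_likelihood(dead[-1])
        # "Update" step: replace each deleted particle with a prior draw that
        # satisfies the hard constraint. Rejection sampling is a stand-in for
        # the inner kernel (e.g. slice sampling) that would be used in practice.
        while len(live) < n_live:
            x = sample_prior(rng)
            if log_likelihood(x) > threshold:
                live.append(x)
    # Add the contribution of the points still alive at termination.
    for x in live:
        log_z = logaddexp(log_z, log_x - math.log(n_live) + log_likelihood(x))
    return log_z
```

Calling `nested_sampling()` returns an estimate of the log-evidence for the toy problem; the batch size `n_delete` plays the role of the number of particles handed to the vectorized update.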

Out of these there are currently 3 top-level APIs defined (somewhat overkill as things stand, but hopefully it translates). Base and adaptive both have top-level APIs, named generically as per the usual design. Inside adaptive we have put a top-level API for nss, or "nested slice sampling", that explicitly loads the vectorized slice inner kernel and its corresponding tuning.

Example usage

Lastly, there is an example usage script (not to be included in the final PR, but to help demonstrate how we intend these components to be used) on a simple Gaussian-Gaussian integration problem, under docs/examples/nested_sampling.py (this has an external dependency on distrax). I have added a number of inline comments to explain some choices. This would all be extended at some point and folded into the sampling book rather than kept here, but I’ve included it as a tracked file for convenience.

As there are quite a few non-standard parts here, I will submit this as a draft PR for now, hoping for some higher-level feedback before getting too far into the weeds. Hopefully there is enough here for you to go on as an initial look, @AdrienCorenflos, and the example works out of the box for you.

williamjameshandley and others added 30 commits May 26, 2025 21:34
…pler and update Nested Sampling

This commit enhances the Slice Sampler (`blackjax.mcmc.ss`) with the ability to handle generic constraints beyond the log-density slice itself. Nested Sampling (`blackjax.ns`) implementations are refactored to leverage this new capability for enforcing the likelihood constraint.

Slice Sampler (`mcmc.ss`):
- The `kernel` and `horizontal_slice_proposal` functions now accept `constraint_fn`, `constraint`, and `strict` arguments. These allow specifying an additional function whose output must satisfy a given bound for a proposal to be considered "within" the slice.
- During the stepping-out and shrinking procedures, proposed points `x` are now checked against `constraint_fn(x) > constraint` (or `>=` if `strict` is False) in addition to `logdensity_fn(x) >= log_slice_height`.
- `SliceInfo` now includes `constraint` (the value of `constraint_fn` at the accepted point).
- Type hints for `l_steps`, `r_steps`, `s_steps`, and `evals` in `SliceInfo` are changed from `Array` to `int`.

Nested Sampling (`ns`):
- The previous approach in `ns.base` of creating a `constrained_logdensity_fn` (which combined prior and likelihood constraint by returning -inf) has been removed.
- The `inner_kernel` in `ns.base` and `ns.adaptive` (and their initializers) now explicitly receives `logprior_fn`, `loglikelihood_fn`, and `loglikelihood_0`.
- In `ns.nss` (Nested Slice Sampling):
    - The `inner_kernel` now passes `logprior_fn` as the `logdensity_fn` to the slice sampler.
    - The likelihood constraint (`loglikelihood_fn(x) > loglikelihood_0`) is passed as the new explicit constraint to the slice sampler using `constraint_fn`, `constraint`, and `strict`.
    - Introduced `NSSInnerState` and `NSSStepInfo` for better state and information management within the nested slice sampling context.
- `inner_init_fn` signatures in NS modules are updated to reflect the separation of prior, likelihood, and the particle's current state.

This change decouples the likelihood constraint logic from the main log-density function passed to the slice sampler. This leads to a cleaner interface and a more direct way of handling constraints within Nested Sampling, particularly for the likelihood threshold.

It also reduces the number of calls to the likelihood and prior, and gives users access to the slice sampling chain results for further analysis.
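
As a rough illustration of the acceptance rule described above, the following pure-Python sketch checks a candidate against both the slice height and the extra constraint. The helper name `in_slice` and the toy prior and likelihood are hypothetical; only the argument names `constraint_fn`, `constraint` and `strict` are taken from this commit message, and the real `mcmc.ss` code may be organized differently.

```python
def in_slice(x, logdensity_fn, log_slice_height, constraint_fn, constraint, strict=True):
    """Return True if x lies on the slice and also satisfies the constraint."""
    above_slice = logdensity_fn(x) >= log_slice_height
    value = constraint_fn(x)
    # strict=True demands value > constraint; strict=False also accepts equality.
    satisfied = value > constraint if strict else value >= constraint
    return above_slice and satisfied


def toy_logprior(x):
    # Standard normal log-density up to a constant (toy choice).
    return -0.5 * x * x


def toy_loglikelihood(x):
    # Gaussian log-likelihood centred at 1 (toy choice).
    return -0.5 * (x - 1.0) ** 2


# Nested-sampling use: the prior defines the slice, the likelihood threshold
# (here -0.5) is the hard constraint.
accepted = in_slice(1.0, toy_logprior, -1.0, toy_loglikelihood, -0.5)
```

Keeping the two checks separate is what lets the prior act as the `logdensity_fn` while the likelihood threshold is enforced independently.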
…ce sampling

- Add tests/ns/test_nested_sampling.py with 9 tests covering:
  * Base nested sampling initialization and particle deletion
  * Adaptive nested sampling parameter updates
  * Nested slice sampling direction functions and kernel construction
  * Utility functions for log-volume simulation and live point counting

- Add tests/mcmc/test_slice_sampling.py with 11 tests covering:
  * Slice sampler initialization and vertical slice height sampling
  * Multi-dimensional slice sampling (1D, 2D, 5D)
  * Constrained slice sampling and direction generation
  * Hit-and-run slice sampling top-level API
  * Statistical correctness validation with robust error handling

- All tests follow BlackJAX conventions using chex.TestCase, parameterized
  testing, and proper shape/tree validation
- Tests are optimized for fast execution with reduced sample sizes
- Comprehensive coverage of both core functionality and edge cases

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
- Remove incorrect assertion that number of live points must be monotonically decreasing
- In real nested sampling, live points follow a sawtooth pattern as particles die and are replenished
- Fix evidence estimation test to use proper utility functions instead of skipping
- Create more realistic mock data with varied birth likelihoods
- Add proper documentation explaining why monotonic decrease assumption is wrong
- Reference: Fowlie et al's plateau nested sampling shows live points can increase
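
The sawtooth pattern can be reproduced with a small pure-Python sketch. The function `live_point_counts` and the mock birth/death values are hypothetical and are not the utilities or test data in this PR; they assume 4 live points with 2 deleted and replenished per step.

```python
import math


def live_point_counts(births, deaths):
    """Number of live points at each death, in order of increasing death log-likelihood.

    A particle counts as live at level L if it was born below L and has not
    yet died, i.e. birth < L <= death.
    """
    order = sorted(range(len(deaths)), key=lambda i: deaths[i])
    counts = []
    for i in order:
        level = deaths[i]
        counts.append(sum(1 for b, d in zip(births, deaths) if b < level <= d))
    return counts


# Four initial particles (born at -inf); two die per step and are replaced by
# particles born at the threshold of that step.
births = [-math.inf, -math.inf, -math.inf, -math.inf, 2.0, 2.0, 4.0, 4.0]
deaths = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
counts = live_point_counts(births, deaths)
```

Here `counts` is `[4, 3, 4, 3, 4, 3, 2, 1]`: the count drops within each batch, jumps back up on replenishment, and only decreases monotonically once the final live points are drained.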

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
- Remove unused imports from test files (functools, numpy, blackjax, etc.)
- Fix MyPy type annotation issues:
  * Add Dict type annotation for inner_kernel_params in base.py
  * Rename duplicate nested function names in ss.py (shrink_body_fun, shrink_cond_fun)
  * Add proper type annotation for params parameter in nss.py
  * Add None check for optional update_inner_kernel_params_fn in adaptive.py
- Update test comment to clarify scope rather than suggesting "skip"
- Apply Black formatting and fix import ordering
- All pre-commit hooks now pass successfully

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
…sampling

- Replace weak evidence tests with proper analytic validation
- Test unnormalized Gaussian likelihood with uniform prior using exact analytic solution
- Validate evidence estimates against analytic values using statistical consistency
- Generate 500-1000 Monte Carlo evidence samples to test distribution properties
- Check that analytic evidence falls within 95-99% confidence intervals
- Add test cases for:
  * Unnormalized Gaussian exp(-0.5*x²) with uniform prior [-3,3]
  * Narrow prior challenging case with full Gaussian likelihood
  * Constant likelihood case for baseline validation
- Use proper Bayesian evidence formula: Z = ∫ p(x) * L(x) dx
- Statistical validation with confidence intervals rather than arbitrary tolerances
- Addresses requirement for evidence correctness testing with known analytic solutions
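
For the first test case, the analytic value can be worked out in closed form: with a uniform prior p(x) = 1/(b - a) on [a, b] and L(x) = exp(-0.5*x²), Z = sqrt(pi/2) * (erf(b/√2) - erf(a/√2)) / (b - a). The sketch below evaluates this and cross-checks it with a midpoint rule; the function names are made up for illustration and are not taken from the test file.

```python
import math


def analytic_evidence(a=-3.0, b=3.0):
    # Z = (1 / (b - a)) * integral_a^b exp(-x^2 / 2) dx, written with erf.
    integral = math.sqrt(math.pi / 2.0) * (
        math.erf(b / math.sqrt(2.0)) - math.erf(a / math.sqrt(2.0))
    )
    return integral / (b - a)


def quadrature_evidence(a=-3.0, b=3.0, n=10_000):
    # Midpoint-rule check of the same integral.
    h = (b - a) / n
    total = sum(math.exp(-0.5 * (a + (i + 0.5) * h) ** 2) for i in range(n))
    return total * h / (b - a)


z_exact = analytic_evidence()
z_check = quadrature_evidence()
```

For the prior [-3, 3] this gives Z ≈ 0.4166, i.e. log Z ≈ -0.8755, which is the kind of reference value a sampled evidence distribution can be compared against.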

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
This commit refactors state management in the Slice Sampler and simplifies
the inner kernel handling in the Nested Sampling base.

**Slice Sampler (`mcmc.ss`):**
- `SliceState` now includes a `logslice` field (defaulting to `jnp.inf`)
  to store the height sampled during the vertical slice.
- `vertical_slice` now takes the full `SliceState`, calculates `logslice`,
  and returns an updated `SliceState`. Its signature changes from
  `(rng_key, logdensity_fn, position)` to `(rng_key, state)`.
- `horizontal_slice` now takes `SliceState` (which includes `logslice`)
  instead of separate `x0` and `log_slice_height`. Its signature changes
  to accept `state: SliceState` as its second argument.
- These changes centralize slice-related information, improving data flow
  within the slice sampling kernel. `ss.init` remains the primary way to
  initialize `SliceState`, and `ss.build_kernel` API is unchanged.
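
A minimal sketch of the state-passing pattern described above, in pure Python rather than JAX. `SketchSliceState` and its fields are assumptions for illustration (the real `SliceState` may carry other fields), and a `random.Random` instance stands in for `rng_key`.

```python
import math
import random
from typing import NamedTuple


class SketchSliceState(NamedTuple):
    position: float
    logdensity: float
    logslice: float = math.inf  # placeholder until a vertical slice is drawn


def init(position, logdensity_fn):
    # The log-density is evaluated once and carried along in the state.
    return SketchSliceState(position, logdensity_fn(position))


def vertical_slice(rng, state):
    # Draw the slice height uniformly under the density:
    # log y = logdensity(x) + log(u), with u in (0, 1].
    u = 1.0 - rng.random()
    return state._replace(logslice=state.logdensity + math.log(u))


state = init(0.5, lambda x: -0.5 * x * x)
state = vertical_slice(random.Random(0), state)
```

Because the height travels inside the state, a later horizontal step only needs the state itself rather than a separate position and height.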

**Nested Sampling (`ns.base`, `ns.nss`):**
- Removed the `inner_kernel_init_fn` parameter from `ns.base.build_kernel`.
  The `NSInnerState` is now always initialized directly within the NS loop
  before calling the vmapped `inner_kernel`. This is a breaking API change
  for users who might have been providing a custom `inner_kernel_init_fn`.
- Introduced `ns.base.new_state_and_info` helper function to standardize
  the creation of `NSInnerState` and `NSInnerInfo`.
- `ns.nss.inner_kernel` is adapted to use the updated `SliceState` from
  `mcmc.ss` and now utilizes the `new_state_and_info` helper.

These modifications enhance code clarity and maintainability. Tests for
slice sampling and nested sampling have been updated accordingly.
- Rename NSInnerState/NSInnerInfo to PartitionedState/PartitionedInfo for posterior repartitioning
- Add comprehensive docstrings explaining separation of log-prior and log-likelihood components
- Improve type annotations and docstring consistency across slice sampling modules
- Standardize SliceInfo with default values for clean initialization
- Fix function signatures and return types throughout nested sampling codebase
- Enable posterior repartitioning techniques through explicit component separation

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
…ampling

Clean up type annotations by removing redundant # type: ignore directives:
- Remove type ignore from SamplingAlgorithm return in nss.py
- Remove type ignores from NSInfo constructor in utils.py

All type checks now pass without suppression directives.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
Fix broadcasting error by vmapping over both random keys and inner state.
The inner kernel expects single particles, not batches, so we need to
vmap over the PartitionedState structure (axis 0) in addition to the keys.

This resolves the "mul got incompatible shapes for broadcasting" error
that occurred when constraint functions tried to evaluate likelihood
on vectorized inputs.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
feat: Major nested sampling and slice sampling algorithm enhancements
This PR is a large refactor of the base functions presented originally:
The core Nested Sampling framework (base, adaptive, utils) has been significantly refactored for enhanced flexibility, improved state tracking, and a clearer API.
The previous vectorized_slice module is removed, superseded by the more general HRSS implementation, which has been included as a more general mcmc kernel.
Add parameter m (default=10) to control maximum expansion steps during
the stepping-out phase, following Radford Neal's slice sampling algorithm.
This prevents infinite loops when the slice extends indefinitely and
ensures computational efficiency by bounding the interval expansion.

- Add m parameter to build_kernel() and build_hrss_kernel()
- Modify stepping-out loop to use bounded counters j and k
- Remove unused d field from SliceInfo for cleaner interface
- Update horizontal_slice() to implement Neal's bounded expansion
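
For reference, the bounded stepping-out procedure from Neal's slice sampling paper can be sketched in pure Python as follows. This is an illustration of the published algorithm in one dimension, not the JAX code in `ss.py`; the function name `step_out` and its arguments are assumptions.

```python
import random


def step_out(rng, logdensity_fn, x0, log_height, w, m):
    """Find an interval around x0 using at most m - 1 expansions of width w."""
    # Randomly position an initial interval of width w around x0.
    left = x0 - w * rng.random()
    right = left + w
    # Split the expansion budget of m - 1 steps between the two sides.
    j = rng.randrange(m)
    k = (m - 1) - j
    # Expand each side while its end is still inside the slice and budget remains.
    while j > 0 and logdensity_fn(left) >= log_height:
        left -= w
        j -= 1
    while k > 0 and logdensity_fn(right) >= log_height:
        right += w
        k -= 1
    return left, right


# A flat log-density never leaves the slice, so the budget is what stops the loop.
left, right = step_out(random.Random(0), lambda x: 0.0, 0.0, -1.0, w=1.0, m=10)
```

With a flat target the interval grows to exactly `m * w` and then stops, which is the bounded behaviour the `m` parameter is meant to guarantee.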

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
Correct the shrinking loop condition to use max_shrink_steps+1, ensuring
that max_shrink_steps represents the actual number of shrinking attempts
allowed rather than the loop counter limit. This fixes the issue where
max_shrink_steps=1 would incorrectly reject all proposals that required
any shrinking iterations.

- Update shrink_cond_fun to allow exactly max_shrink_steps iterations
- Fix acceptance condition to properly handle step count semantics
- Add static_binomial_sampling for proper proposal acceptance/rejection
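
One reading of the intended semantics, sketched in pure Python: `max_shrink_steps` counts shrink operations, so up to `max_shrink_steps + 1` proposals are tried before giving up. The function `shrink` and its return values are hypothetical and only illustrate the counting; they do not reproduce the `ss.py` implementation.

```python
import random


def shrink(rng, accept_fn, x0, left, right, max_shrink_steps):
    """Propose within [left, right], shrinking towards x0 after each rejection."""
    n_shrinks = 0
    while True:
        x = left + (right - left) * rng.random()
        if accept_fn(x):
            return x, True, n_shrinks
        if n_shrinks == max_shrink_steps:
            # Budget exhausted: stay at the current point and report rejection.
            return x0, False, n_shrinks
        # Shrink the side of the interval that contains the rejected point.
        if x < x0:
            left = x
        else:
            right = x
        n_shrinks += 1


x, accepted, n_shrinks = shrink(
    random.Random(1), lambda x: abs(x) <= 1.0, 0.0, -10.0, 10.0, max_shrink_steps=100
)
```

Under this convention `max_shrink_steps=1` still accepts a proposal that needed exactly one shrink, which is the case the fix above is concerned with.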

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
Rename parameters for better clarity:
- max_steps_out → max_steps (stepping-out phase limit)
- max_shrink_steps → max_shrinkage (shrinking phase limit)

Update function names in horizontal_slice for better organization:
- body_fun → step_body_fun (stepping-out procedure)
- cond_fun → step_cond_fun (stepping-out condition)

Fix docstring parameter descriptions to match actual function signature.
Update both slice sampling (ss.py) and nested slice sampling (nss.py)
modules for consistency with defaults max_steps=10, max_shrinkage=100.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
Implement maximum step limits and acceptance handling for slice sampling
* Refactor nested sampling parallelization strategy

Move vmap from base kernel to algorithm level to enable flexible
parallelization strategies (vmap vs pmap) for varying likelihood costs.

- Remove hardcoded vmap from ns/base.py kernel
- Apply vmap at algorithm level in ns/nss.py
- Maintains identical behavior and results
- Enables future pmap support for load balancing

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

* Refactor nested sampling init parallelization

Move vmap control from adaptive layer to algorithm level for
better parallelization strategy visibility and control.

- Remove hardcoded vmaps from ns/base.py init function
- Move vectorization from adaptive.py to nss.py
- Centralize both init and kernel parallelization in nss.py
- Clean up intermediate variables for better readability
- Maintains identical behavior and results

This completes the parallelization refactor, enabling future
pmap support for load balancing with varying likelihood costs.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

* Fix linting issues and standardize init function signature

- Remove unused jax import from adaptive.py (flake8)
- Update init_fn signature to follow InitFn protocol (mypy)
- Use standard (position, rng_key=None) signature like SMC
- Maintain internal use of 'particles' parameter name
- All pre-commit hooks now pass

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

* Enable CI checks for nested_sampling branch

- Add nested_sampling to test.yml workflow triggers
- Allows CI to run on PRs targeting nested_sampling branch
- Ensures code quality checks for development branch

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

---------

Co-authored-by: Claude <[email protected]>
* Fix state handling bugs in slice sampling

- Fix nss.py: Use new_slice_state.constraint instead of slice_state.constraint when returning loglikelihood
- Fix ss.py: Initialize shrinking loop with is_accepted=False to ensure loop executes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

* Removed defaults

* bugfix

* Fix slice sampling implementation bugs and simplify logic

- Fix undefined variable bug: return new_state instead of slice_state
- Fix dtype bug: strict array must be boolean for JAX compatibility
- Simplify constraint checking without jnp.append for cleaner code
- Remove redundant state selection in build_kernel (already handled in horizontal_slice)
- Fix expansion step counting formula to m + 1 - j - k (off-by-one error)
- Simplify shrink loop condition to n < max_shrinkage (removes edge case)
- Clean up loop variable extraction for better readability

These changes were validated with GPT-5 review and extensive testing.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

---------

Co-authored-by: Claude <[email protected]>
* Adjusted MCLMC (blackjax-devs#675)

* TESTS

* TESTS

* UPDATE DOCSTRING

* ADD STREAMING VERSION

* ADD PRECONDITIONING TO MCLMC

* ADD PRECONDITIONING TO TUNING FOR MCLMC

* UPDATE GITIGNORE

* UPDATE GITIGNORE

* UPDATE TESTS

* UPDATE TESTS

* ADD DOCSTRING

* ADD TEST

* STREAMING AVERAGE

* ADD TEST

* REFACTOR RUN_INFERENCE_ALGORITHM

* UPDATE DOCSTRING

* Precommit

* CLEAN TESTS

* FIX BAD MERGE

* ADJUSTED MCLMC

* REMOVE BENCHMARKS:

* ADD ADJUSTED MCLMC

* GITIGNORE

* PRECOMMIT CLEAN UP

* FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS

* TEMPORARILY ADD BENCHMARKS

* ADD ADJUSTED MCLMC TUNING

* CLEAN

* UNIFY ADJUSTED MCLMC AND MCHMC

* ADD INITIAL_POSITION

* FIX TEST

* CLEAN UP

* REMOVE BENCHMARKS

* ADD TEST

* REMOVE BENCHMARKS

* MODIFY WINDOW ADAPTATION TO TAKE INTEGRATOR

* MODIFY WINDOW ADAPTATION TO TAKE INTEGRATOR

* BUG FIX

* CHANGE PRECISION

* CHANGE PRECISION

* ADD OMELYAN TEST

* ADD ADJUSTED MCLMC TEST

* ADD ADJUSTED MCLMC TEST

* RENAME O

* UPDATE STREAMING AVG

* UPDATE STREAMING AVG

* FIX MERGE

* UPDATE PR

* RENAME STD_MAT

* RENAME STD_MAT

* RENAME STD_MAT

* MERGE MAIN

* REMOVE COEFFICIENT EXPORTS

* REMOVE COEFFICIENT EXPORTS

* RESOLVE MYPY ISSUE

* RESOLVE MYPY ISSUE

* RETURN EXPECTATION HISTORY

* FIX KWARG BUG

* FIX KWARG BUG

* FIX KWARG BUG IN ADJUSTED MCLMC

* MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT

* L_proposal_factor

* SPLIT TUNING FOR AMCLMC INTO SEPARATE FILE

* SPLIT TUNING FOR AMCLMC INTO SEPARATE FILE

* RENAME STREAMING_AVERAGE_UPDATE ARGS IN ADJUSTED MCLMC ADAPTATION

* diagnostics

* fix bugs

* FIX MINOR TUNING BUGS

* UPDATE TUNING

* UPDATE TUNING

* UPDATE TUNING

* names

* test

* tuning

* update

* ready for test

* ready for test

* ready for test

* Update blackjax/adaptation/adjusted_mclmc_adaptation.py

Co-authored-by: Junpeng Lao <[email protected]>

* edit

---------

Co-authored-by: Junpeng Lao <[email protected]>

* SMC Pretuning (blackjax-devs#765)

* extracting taking last

* test passing

* layering

* example

* more

* Adding another example

* tests in place

* rolling back changes

* Adding test for num_mcmc_steps

* format

* better test coverage

* linter

* Flake8

* black

* implementation[

* partial posteriors implementation

* rolling back some changes

* linter

* fixing test

* adding reference

* typo

* exposing in top level api

* reruning precommit

* up to now

* one step working

* fixes

* tests passing

* checkpoint tests passing

* more

* tests passing, implementation in place

* tests passing

* rounding

* adding to init

* rollbacks

* rollback

* rollback

* docs

* precommit

* removing extra parameter

* code review updates

* Remove meeting scheduling (blackjax-devs#768)

* Remove meeting scheduling

* Fix tests

* Adjusted MCLMC (blackjax-devs#771)

* test CI

* test CI

* test CI: add static

* test CI: add static

* test CI: add static tests

* Revert "test CI: add static"

This reverts commit 2db919d.

* Revert "test CI: add static"

This reverts commit fa6558f.

* test CI: add static tests

* test CI: add static tests

* test CI: add static tests

* test CI: old tests

* test CI: old tests

* test CI: old tests with addition

* test CI: old tests with addition of num tuning steps

* test in place (blackjax-devs#772)

* MCLMC adaptation total num steps and initial guess (blackjax-devs#778)

* total_num_tuning_integrator_steps

* Initial params for MCLMC adaptation

* SMC: Joint tuning and pretuning (blackjax-devs#776)

* impl

* rename

* docs

---------

Co-authored-by: Junpeng Lao <[email protected]>

* Energy error monitoring (blackjax-devs#784)

* energy error monitoring

* energy error monitoring

* jnp abs

* ping Jaxopt version to unbreak test (blackjax-devs#789)

* ping Jaxopt version to unbreak test

* lower jaxopt version

* Apply pre-commit formatting fixes

- Fixed import ordering in blackjax/mcmc/ss.py
- Applied Black formatting to mcmc/ss.py, ns/base.py, and ns/nss.py
- Ensures consistent code style across the codebase

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

* Fix CI: Remove parallel testing to avoid pytest-benchmark/xdist conflict

The test suite was failing with an INTERNALERROR because pytest-benchmark
raises a warning when used with xdist (parallel testing), and our
filterwarnings=error configuration turns this into a fatal error.

Since benchmarks are disabled in CI anyway (-m 'not benchmark'),
removing parallel execution is a reasonable tradeoff to get tests running.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>

* Revert "Fix CI: Remove parallel testing to avoid pytest-benchmark/xdist conflict"

This reverts commit d321044.

---------

Co-authored-by: Reuben <[email protected]>
Co-authored-by: Junpeng Lao <[email protected]>
Co-authored-by: Carlos Iguaran <[email protected]>
Co-authored-by: Hugo Simon-Onfroy <[email protected]>
Co-authored-by: Claude <[email protected]>

4 participants