diff --git a/.github/workflows/schedule-meeting.yml b/.github/workflows/schedule-meeting.yml deleted file mode 100644 index 0575bd20f..000000000 --- a/.github/workflows/schedule-meeting.yml +++ /dev/null @@ -1,18 +0,0 @@ -# Open a Meeting issue the 25th day of the month. -# Meetings happen on the first Friday of the month -name: Open a meeting issue -on: - schedule: - - cron: '0 0 20 * *' - workflow_dispatch: - -jobs: - create-meeting-issue: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: JasonEtco/create-an-issue@v2 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - filename: .github/ISSUE_TEMPLATE/meeting.md diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 93a9cb5d6..c21ac7cea 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,9 +2,9 @@ name: Tests on: pull_request: - branches: [main] + branches: [main, nested_sampling] push: - branches: [main] + branches: [main, nested_sampling] jobs: style: diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 5858c34aa..45555900f 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -3,6 +3,7 @@ from blackjax._version import __version__ +from .adaptation.adjusted_mclmc_adaptation import adjusted_mclmc_find_L_and_step_size from .adaptation.chees_adaptation import chees_adaptation from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size from .adaptation.meads_adaptation import meads_adaptation @@ -11,6 +12,8 @@ from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat +from .mcmc import adjusted_mclmc as _adjusted_mclmc +from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc from .mcmc import elliptical_slice as _elliptical_slice @@ -22,12 +25,14 @@ from .mcmc import nuts as _nuts from .mcmc import periodic_orbital, random_walk from .mcmc import rmhmc as _rmhmc +from .mcmc import ss from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk from .mcmc.random_walk import ( irmh_as_top_level_api, normal_random_walk, rmh_as_top_level_api, ) +from .ns import nss as _nss from .optimizers import dual_averaging, lbfgs from .sgmcmc import csgld as _csgld from .sgmcmc import sghmc as _sghmc @@ -36,6 +41,7 @@ from .smc import adaptive_tempered from .smc import inner_kernel_tuning as _inner_kernel_tuning from .smc import partial_posteriors_path as _partial_posteriors_smc +from .smc import pretuning as _pretuning from .smc import tempered from .vi import meanfield_vi as _meanfield_vi from .vi import pathfinder as _pathfinder @@ -109,7 +115,11 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) +hrss = GenerateSamplingAPI(ss.hrss_as_top_level_api, ss.init, ss.build_kernel) + mclmc = generate_top_level_api_from(_mclmc) +adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic) +adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) @@ -121,10 +131,16 @@ def generate_top_level_api_from(module): tempered_smc = generate_top_level_api_from(tempered) inner_kernel_tuning = generate_top_level_api_from(_inner_kernel_tuning) partial_posteriors_smc = 
generate_top_level_api_from(_partial_posteriors_smc) +pretuning = generate_top_level_api_from(_pretuning) smc_family = [tempered_smc, adaptive_tempered_smc, partial_posteriors_smc] "Step_fn returning state has a .particles attribute" +# NS +nss = generate_top_level_api_from(_nss) + +ns_family = [nss] + # stochastic gradient mcmc sgld = generate_top_level_api_from(_sgld) sghmc = generate_top_level_api_from(_sghmc) @@ -160,6 +176,7 @@ def generate_top_level_api_from(module): "chees_adaptation", "pathfinder_adaptation", "mclmc_find_L_and_step_size", # mclmc adaptation + "adjusted_mclmc_find_L_and_step_size", # adjusted mclmc adaptation "ess", # diagnostics "rhat", ] diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py new file mode 100644 index 000000000..408c31383 --- /dev/null +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -0,0 +1,404 @@ +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + +from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState +from blackjax.adaptation.step_size import ( + DualAveragingAdaptationState, + dual_averaging_adaptation, +) +from blackjax.diagnostics import effective_sample_size +from blackjax.util import incremental_value_update, pytree_size + +Lratio_lowerbound = 0.0 +Lratio_upperbound = 2.0 + + +def adjusted_mclmc_find_L_and_step_size( + mclmc_kernel, + num_steps, + state, + rng_key, + target, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.0, + diagonal_preconditioning=True, + params=None, + max="avg", + num_windows=1, + tuning_factor=1.3, +): + """ + Finds the optimal value of the parameters for the MH-MCHMC algorithm. + + Parameters + ---------- + mclmc_kernel + The kernel function used for the MCMC algorithm. + num_steps + The number of MCMC steps that will subsequently be run, after tuning. + state + The initial state of the MCMC algorithm. + rng_key + The random number generator key. + target + The target acceptance rate for the step size adaptation. + frac_tune1 + The fraction of tuning for the first step of the adaptation. + frac_tune2 + The fraction of tuning for the second step of the adaptation. + frac_tune3 + The fraction of tuning for the third step of the adaptation. + diagonal_preconditioning + Whether to do diagonal preconditioning (i.e. a mass matrix) + params + Initial params to start tuning from (optional) + max + whether to calculate L from maximum or average eigenvalue. Average is advised. + num_windows + how many iterations of the tuning are carried out + tuning_factor + multiplicative factor for L + + + Returns + ------- + A tuple containing the final state of the MCMC algorithm and the final hyperparameters. 
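+
+    Example
+    -------
+    An illustrative sketch only: ``logdensity_fn``, ``initial_state``,
+    ``tune_key`` and ``num_steps`` are placeholders, and the wrapper below is
+    just one way to build a kernel with the
+    ``(rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix)``
+    signature that this tuner calls.
+
+    .. code::
+
+        from blackjax.mcmc.adjusted_mclmc_dynamic import rescale
+
+        def kernel(rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix):
+            # one possible dynamic adjusted MCLMC kernel, with the average
+            # number of integration steps drawn at random around the target
+            return blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel(
+                integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
+                integration_steps_fn=lambda k: jnp.ceil(
+                    jax.random.uniform(k) * rescale(avg_num_integration_steps)
+                ).astype(jnp.int32),
+                inverse_mass_matrix=inverse_mass_matrix,
+            )(
+                rng_key=rng_key,
+                state=state,
+                step_size=step_size,
+                logdensity_fn=logdensity_fn,
+            )
+
+        state, params, num_tuning_steps = adjusted_mclmc_find_L_and_step_size(
+            mclmc_kernel=kernel,
+            num_steps=num_steps,
+            state=initial_state,
+            rng_key=tune_key,
+            target=0.9,
+        )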
+ """ + + frac_tune1 /= num_windows + frac_tune2 /= num_windows + frac_tune3 /= num_windows + + dim = pytree_size(state.position) + if params is None: + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,)) + ) + + part1_key, part2_key = jax.random.split(rng_key, 2) + + total_num_tuning_integrator_steps = 0 + for i in range(num_windows): + window_key = jax.random.fold_in(part1_key, i) + ( + state, + params, + eigenvector, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=frac_tune2, + target=target, + diagonal_preconditioning=diagonal_preconditioning, + max=max, + tuning_factor=tuning_factor, + )( + state, params, num_steps, window_key + ) + total_num_tuning_integrator_steps += num_tuning_integrator_steps + + if frac_tune3 != 0: + for i in range(num_windows): + part2_key = jax.random.fold_in(part2_key, i) + part2_key1, part2_key2 = jax.random.split(part2_key, 2) + + ( + state, + params, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_adaptation_L( + mclmc_kernel, + frac=frac_tune3, + Lfactor=0.5, + max=max, + eigenvector=eigenvector, + )( + state, params, num_steps, part2_key1 + ) + + total_num_tuning_integrator_steps += num_tuning_integrator_steps + + ( + state, + params, + _, + num_tuning_integrator_steps, + ) = adjusted_mclmc_make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1, + frac_tune2=0, + target=target, + fix_L_first_da=True, + diagonal_preconditioning=diagonal_preconditioning, + max=max, + tuning_factor=tuning_factor, + )( + state, params, num_steps, part2_key2 + ) + + total_num_tuning_integrator_steps += num_tuning_integrator_steps + + return state, params, total_num_tuning_integrator_steps + + +def adjusted_mclmc_make_L_step_size_adaptation( + kernel, + dim, + frac_tune1, + frac_tune2, + target, + diagonal_preconditioning, + fix_L_first_da=False, + max="avg", + tuning_factor=1.0, +): + """Adapts the stepsize and L of the MCLMC kernel. 
Designed for adjusted MCLMC""" + + def dual_avg_step(fix_L, update_da): + """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize""" + + def step(iteration_state, weight_and_key): + mask, rng_key = weight_and_key + ( + previous_state, + params, + (adaptive_state, step_size_max), + previous_weight_and_average, + ) = iteration_state + + avg_num_integration_steps = params.L / params.step_size + + state, info = kernel( + rng_key=rng_key, + state=previous_state, + avg_num_integration_steps=avg_num_integration_steps, + step_size=params.step_size, + inverse_mass_matrix=params.inverse_mass_matrix, + ) + + # step updating + success, state, step_size_max, energy_change = handle_nans( + previous_state, + state, + params.step_size, + step_size_max, + info.energy, + ) + + with_mask = lambda x, y: mask * x + (1 - mask) * y + + log_step_size, log_step_size_avg, step, avg_error, mu = update_da( + adaptive_state, info.acceptance_rate + ) + + adaptive_state = DualAveragingAdaptationState( + with_mask(log_step_size, adaptive_state.log_step_size), + with_mask(log_step_size_avg, adaptive_state.log_step_size_avg), + with_mask(step, adaptive_state.step), + with_mask(avg_error, adaptive_state.avg_error), + with_mask(mu, adaptive_state.mu), + ) + + step_size = jax.lax.clamp( + 1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1 + ) + adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) + + x = ravel_pytree(state.position)[0] + + # update the running average of x, x^2 + previous_weight_and_average = incremental_value_update( + expectation=jnp.array([x, jnp.square(x)]), + incremental_val=previous_weight_and_average, + weight=(1 - mask) * success * step_size, + zero_prevention=mask, + ) + + params = params._replace(step_size=with_mask(step_size, params.step_size)) + if not fix_L: + params = params._replace( + L=with_mask(params.L * (step_size / params.step_size), params.L), + ) + + state_position = state.position + + return ( + state, + params, + (adaptive_state, step_size_max), + previous_weight_and_average, + ), ( + info, + state_position, + ) + + return step + + def step_size_adaptation(mask, state, params, keys, fix_L, initial_da, update_da): + return jax.lax.scan( + dual_avg_step(fix_L, update_da), + init=( + state, + params, + (initial_da(params.step_size), jnp.inf), # step size max + (0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), + ), + xs=(mask, keys), + ) + + def L_step_size_adaptation(state, params, num_steps, rng_key): + num_steps1, num_steps2 = int(num_steps * frac_tune1), int( + num_steps * frac_tune2 + ) + + check_key, rng_key = jax.random.split(rng_key, 2) + + rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) + L_step_size_adaptation_keys_pass1 = jax.random.split( + rng_key_pass1, num_steps1 + num_steps2 + ) + L_step_size_adaptation_keys_pass2 = jax.random.split(rng_key_pass2, num_steps1) + + # determine which steps to ignore in the streaming average + mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) + + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, position_samples), + ) = step_size_adaptation( + mask, + state, + params, + L_step_size_adaptation_keys_pass1, + fix_L=fix_L_first_da, + initial_da=initial_da, + update_da=update_da, + ) + + num_tuning_integrator_steps = info.num_integration_steps.sum() + final_stepsize = final_da(dual_avg_state) + params = 
params._replace(step_size=final_stepsize) + + # determine L + eigenvector = None + if num_steps2 != 0.0: + x_average, x_squared_average = average[0], average[1] + variances = x_squared_average - jnp.square(x_average) + + if max == "max": + contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor + + elif max == "avg": + contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor + + else: + raise ValueError("max should be either 'max' or 'avg'") + + change = jax.lax.clamp( + Lratio_lowerbound, + contract(variances) / params.L, + Lratio_upperbound, + ) + params = params._replace( + L=params.L * change, step_size=params.step_size * change + ) + if diagonal_preconditioning: + params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) + + initial_da, update_da, final_da = dual_averaging_adaptation(target=target) + ( + (state, params, (dual_avg_state, step_size_max), (_, average)), + (info, params_history), + ) = step_size_adaptation( + jnp.ones(num_steps1), + state, + params, + L_step_size_adaptation_keys_pass2, + fix_L=True, + update_da=update_da, + initial_da=initial_da, + ) + + num_tuning_integrator_steps += info.num_integration_steps.sum() + + params = params._replace(step_size=final_da(dual_avg_state)) + + return state, params, eigenvector, num_tuning_integrator_steps + + return L_step_size_adaptation + + +def adjusted_mclmc_make_adaptation_L( + kernel, frac, Lfactor, max="avg", eigenvector=None +): + """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" + + def adaptation_L(state, params, num_steps, key): + num_steps = int(num_steps * frac) + adaptation_L_keys = jax.random.split(key, num_steps) + + def step(state, key): + next_state, info = kernel( + rng_key=key, + state=state, + step_size=params.step_size, + avg_num_integration_steps=params.L / params.step_size, + inverse_mass_matrix=params.inverse_mass_matrix, + ) + return next_state, (next_state.position, info) + + state, (samples, info) = jax.lax.scan( + f=step, + init=state, + xs=adaptation_L_keys, + ) + + if max == "max": + contract = jnp.min + else: + contract = jnp.mean + + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) + + if eigenvector is not None: + flat_samples = jnp.expand_dims( + jnp.einsum("ij,j", flat_samples, eigenvector), 1 + ) + + # number of effective samples per 1 actual sample + ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps + + return ( + state, + params._replace( + L=jnp.clip( + Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound + ) + ), + info.num_integration_steps.sum(), + ) + + return adaptation_L + + +def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change): + """if there are nans, let's reduce the stepsize, and not update the state. 
The + function returns the old state in this case.""" + + reduced_step_size = 0.8 + p, unravel_fn = ravel_pytree(next_state.position) + nonans = jnp.all(jnp.isfinite(p)) + state, step_size, kinetic_change = jax.tree_util.tree_map( + lambda new, old: jax.lax.select(nonans, jnp.nan_to_num(new), old), + (next_state, step_size_max, kinetic_change), + (previous_state, step_size * reduced_step_size, 0.0), + ) + + return nonans, state, step_size, kinetic_change diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 831586201..fa644898a 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - sqrt_diag_cov + inverse_mass_matrix A matrix used for preconditioning. """ L: float step_size: float - sqrt_diag_cov: float + inverse_mass_matrix: float def mclmc_find_L_and_step_size( @@ -51,6 +51,7 @@ def mclmc_find_L_and_step_size( trust_in_estimate=1.5, num_effective_samples=150, diagonal_preconditioning=True, + params=None, ): """ Finds the optimal value of the parameters for the MCLMC algorithm. @@ -77,6 +78,10 @@ def mclmc_find_L_and_step_size( The trust in the estimate of optimal stepsize. num_effective_samples The number of effective samples for the MCMC algorithm. + diagonal_preconditioning + Whether to do diagonal preconditioning (i.e. a mass matrix) + params + Initial params to start tuning from (optional) Returns ------- @@ -85,10 +90,10 @@ def mclmc_find_L_and_step_size( Example ------- .. code:: - kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, - std_mat=std_mat, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -103,10 +108,19 @@ def mclmc_find_L_and_step_size( ) """ dim = pytree_size(state.position) - params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) - ) + if params is None: + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) + ) + part1_key, part2_key = jax.random.split(rng_key, 2) + total_num_tuning_integrator_steps = 0 + + num_steps1, num_steps2 = round(num_steps * frac_tune1), round( + num_steps * frac_tune2 + ) + num_steps2 += diagonal_preconditioning * (num_steps2 // 3) + num_steps3 = round(num_steps * frac_tune3) state, params = make_L_step_size_adaptation( kernel=mclmc_kernel, @@ -118,13 +132,15 @@ def mclmc_find_L_and_step_size( num_effective_samples=num_effective_samples, diagonal_preconditioning=diagonal_preconditioning, )(state, params, num_steps, part1_key) + total_num_tuning_integrator_steps += num_steps1 + num_steps2 - if frac_tune3 != 0: + if num_steps3 >= 2: # at least 2 samples for ESS estimation state, params = make_adaptation_L( - mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) + total_num_tuning_integrator_steps += num_steps3 - return state, params + return state, params, total_num_tuning_integrator_steps def make_L_step_size_adaptation( @@ -137,7 +153,7 @@ def make_L_step_size_adaptation( trust_in_estimate=1.5, num_effective_samples=150, ): - """Adapts the stepsize and L of the MCLMC kernel. 
Designed for the unadjusted MCLMC""" + """Adapts the stepsize and L of the MCLMC kernel. Designed for unadjusted MCLMC""" decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) @@ -150,7 +166,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): rng_key, nan_key = jax.random.split(rng_key) # dynamics - next_state, info = kernel(params.sqrt_diag_cov)( + next_state, info = kernel(params.inverse_mass_matrix)( rng_key=rng_key, state=previous_state, L=params.L, @@ -223,10 +239,10 @@ def step(iteration_state, weight_and_key): )[0] def L_step_size_adaptation(state, params, num_steps, rng_key): - num_steps1, num_steps2 = ( - int(num_steps * frac_tune1) + 1, - int(num_steps * frac_tune2) + 1, + num_steps1, num_steps2 = round(num_steps * frac_tune1), round( + num_steps * frac_tune2 ) + L_step_size_adaptation_keys = jax.random.split( rng_key, num_steps1 + num_steps2 + 1 ) @@ -245,25 +261,25 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L - sqrt_diag_cov = params.sqrt_diag_cov + inverse_mass_matrix = params.inverse_mass_matrix if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - sqrt_diag_cov = jnp.sqrt(variances) - params = params._replace(sqrt_diag_cov=sqrt_diag_cov) + inverse_mass_matrix = variances + params = params._replace(inverse_mass_matrix=inverse_mass_matrix) L = jnp.sqrt(dim) # readjust the stepsize - steps = num_steps2 // 3 # we do some small number of steps + steps = round(num_steps2 / 3) # we do some small number of steps keys = jax.random.split(final_key, steps) state, params, _, (_, average) = run_steps( xs=(jnp.ones(steps), keys), state=state, params=params ) - return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) + return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) return L_step_size_adaptation @@ -272,8 +288,8 @@ def make_adaptation_L(kernel, frac, Lfactor): """determine L by the autocorrelations (around 10 effective samples are needed for this to be accurate)""" def adaptation_L(state, params, num_steps, key): - num_steps = int(num_steps * frac) - adaptation_L_keys = jax.random.split(key, num_steps) + num_steps_3 = round(num_steps * frac) + adaptation_L_keys = jax.random.split(key, num_steps_3) def step(state, key): next_state, _ = kernel( @@ -295,7 +311,7 @@ def step(state, key): ess = effective_sample_size(flat_samples[None, ...]) return state, params._replace( - L=Lfactor * params.step_size * jnp.mean(num_steps / ess) + L=Lfactor * params.step_size * jnp.mean(num_steps_3 / ess) ) return adaptation_L diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py index 93480302e..257ce759c 100644 --- a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -115,6 +115,9 @@ def effective_sample_size( sample_axis = sample_axis if sample_axis >= 0 else len(input_shape) + sample_axis num_chains = input_shape[chain_axis] num_samples = input_shape[sample_axis] + assert ( + num_samples > 1 + ), f"The input array must have at least 2 samples, got only {num_samples}." mean_across_chain = input_array.mean(axis=sample_axis, keepdims=True) # Compute autocovariance estimates for every lag for the input array using FFT. diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 6e207741d..8acb28274 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -1,4 +1,6 @@ from . 
import ( + adjusted_mclmc, + adjusted_mclmc_dynamic, barker, elliptical_slice, ghmc, @@ -24,4 +26,6 @@ "marginal_latent_gaussian", "random_walk", "mclmc", + "adjusted_mclmc_dynamic", + "adjusted_mclmc", ] diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py new file mode 100644 index 000000000..f390402f2 --- /dev/null +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -0,0 +1,242 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin". + +NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module, which is primarily intended for use in parallelized versions of the algorithm. + +""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.hmc import HMCInfo, HMCState +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = ["init", "build_kernel", "as_top_level_api"] + + +def init(position: ArrayLikeTree, logdensity_fn: Callable): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return HMCState(position, logdensity, logdensity_grad) + + +def build_kernel( + logdensity_fn: Callable, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + inverse_mass_matrix=1.0, +): + """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. 
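+
+    Example
+    -------
+    A minimal sketch; ``logdensity_fn``, ``initial_position`` and ``rng_key``
+    are placeholders, and the step size and number of integration steps are
+    arbitrary values, not recommendations.
+
+    .. code::
+
+        kernel = build_kernel(logdensity_fn)
+        state = init(initial_position, logdensity_fn)
+        new_state, info = kernel(
+            rng_key, state, step_size=0.1, num_integration_steps=10
+        )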
+ """ + + def kernel( + rng_key: PRNGKey, + state: HMCState, + step_size: float, + num_integration_steps: int, + L_proposal_factor: float = jnp.inf, + ) -> tuple[HMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + HMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + inverse_mass_matrix=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + num_integration_steps, +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the MHMCHMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + logdensity_fn=logdensity_fn, + integrator=integrator, + inverse_mass_matrix=inverse_mass_matrix, + divergence_threshold=divergence_threshold, + ) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key=rng_key, + state=state, + step_size=step_size, + num_integration_steps=num_integration_steps, + L_proposal_factor=L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. 
+ step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. + + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal_factor, rng_key + ) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py new file mode 100644 index 000000000..1a69e1a28 --- /dev/null +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -0,0 +1,259 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. 
Hence "Hamiltonian" and not "Langevin".""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence +from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = ["init", "build_kernel", "as_top_level_api"] + + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + + +def build_kernel( + integration_steps_fn, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + inverse_mass_matrix=1.0, +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + """ + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + num_integration_steps = integration_steps_fn(state.random_generator_arg) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + DynamicHMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + inverse_mass_matrix=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. 
+ + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + integration_steps_fn=integration_steps_fn, + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + inverse_mass_matrix=inverse_mass_matrix, + divergence_threshold=divergence_threshold, + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. 
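+
+    Example
+    -------
+    A sketch of how this factory is wired inside the kernel above
+    (``logdensity_fn``, ``rng_key`` and ``integrator_state`` are placeholders):
+
+    .. code::
+
+        step_fn = integrators.with_isokinetic_maruyama(
+            integrators.isokinetic_mclachlan(logdensity_fn)
+        )
+        proposal_generator = adjusted_mclmc_proposal(
+            integrator=step_fn,
+            step_size=0.1,
+            L_proposal_factor=jnp.inf,
+            num_integration_steps=10,
+        )
+        new_state, info, _ = proposal_generator(rng_key, integrator_state)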
+ + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal_factor, rng_key + ) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. + """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index e9d19e3dc..0effa204e 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -311,7 +311,9 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): +def esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0): + sqrt_inverse_mass_matrix = jnp.sqrt(inverse_mass_matrix) + def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -330,7 +332,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * sqrt_diag_cov + flatten_grads = flatten_grads * sqrt_inverse_mass_matrix flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -342,7 +344,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) + gr = unravel_fn(new_momentum_normalized * sqrt_inverse_mass_matrix) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -374,11 +376,11 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0 + logdensity_fn: Callable, inverse_mass_matrix: ArrayTree = 1.0 ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov), + esh_dynamics_momentum_update_one_step(inverse_mass_matrix), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, @@ -414,11 +416,19 @@ def partially_refresh_momentum(momentum, 
rng_key, step_size, L): ------- momentum with random change in angle """ + m, unravel_fn = ravel_pytree(momentum) dim = m.shape[0] nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) - return unravel_fn((m + z) / jnp.linalg.norm(m + z)) + new_momentum = unravel_fn((m + z) / jnp.linalg.norm(m + z)) + # return new_momentum + return jax.lax.cond( + jnp.isinf(L), + lambda _: momentum, + lambda _: new_momentum, + operand=None, + ) def with_isokinetic_maruyama(integrator): diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e7a69849b..d4a235770 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -15,6 +15,7 @@ from typing import Callable, NamedTuple import jax +import jax.numpy as jnp from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import ( @@ -60,7 +61,13 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): +def build_kernel( + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4, +): """Build a HMC kernel. Parameters @@ -81,7 +88,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """ step = with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) ) def kernel( @@ -91,14 +98,33 @@ def kernel( state, step_size, L, rng_key ) - return IntegratorState( - position, momentum, logdensity, logdensitygrad - ), MCLMCInfo( - logdensity=logdensity, - energy_change=kinetic_change - logdensity + state.logdensity, - kinetic_change=kinetic_change, + energy_error = kinetic_change - logdensity + state.logdensity + + eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var + ndims = pytree_size(position) + + new_state, new_info = jax.lax.cond( + jnp.abs(energy_error) > jnp.sqrt(ndims * eev_max_per_dim), + lambda: ( + state, + MCLMCInfo( + logdensity=state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + ), + ), + lambda: ( + IntegratorState(position, momentum, logdensity, logdensitygrad), + MCLMCInfo( + logdensity=logdensity, + energy_change=energy_error, + kinetic_change=kinetic_change, + ), + ), ) + return new_state, new_info + return kernel @@ -107,7 +133,8 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, + desired_energy_var_max_ratio=jnp.inf, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,7 +182,12 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) + kernel = build_kernel( + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=desired_energy_var_max_ratio, + ) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/blackjax/mcmc/ss.py b/blackjax/mcmc/ss.py new file mode 100644 index 000000000..50764ad4a --- /dev/null +++ b/blackjax/mcmc/ss.py @@ -0,0 +1,390 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hit-and-Run Slice Sampling. + +This module implements the Hit-and-Run Slice Sampling algorithm as described by +Neal (2003) [1]_. Slice sampling is an MCMC method that adapts its step size +automatically and can be efficient for sampling from distributions with complex +geometries. The "hit-and-run" variant involves proposing a direction and then +finding an acceptable point along that direction within a slice defined by the +current auxiliary variable. + +References +---------- +.. [1] Neal, R. M. (2003). Slice sampling. The Annals of Statistics, 31(3), 705-767. + +""" + +from functools import partial +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey + +__all__ = [ + "SliceState", + "SliceInfo", + "init", + "build_kernel", + "build_hrss_kernel", + "hrss_as_top_level_api", +] + + +class SliceState(NamedTuple): + """State of the Slice Sampling algorithm. + + Attributes + ---------- + position + The current position of the chain. + logdensity + The log-density of the target distribution at the current position. + """ + + position: ArrayLikeTree + logdensity: float + constraint: Array + + +class SliceInfo(NamedTuple): + """Additional information about the Slice Sampling transition. + + This information can be used for diagnostics and monitoring the sampler's + performance. + + Attributes + ---------- + is_accepted + A boolean indicating whether the proposed sample was accepted. + constraint + The constraint values at the final accepted position. + num_steps + The number of steps taken to expand the interval during the "stepping-out" phase. + num_shrink + The number of steps taken during the "shrinking" phase to find an + acceptable sample. + """ + + is_accepted: bool + num_steps: int + num_shrink: int + + +def init( + position: ArrayTree, logdensity_fn: Callable, constraint_fn: Callable +) -> SliceState: + """Initialize the Slice Sampler state. + + Parameters + ---------- + position + The initial position of the chain. + logdensity_fn + A function that computes the log-density of the target distribution. + + Returns + ------- + SliceState + The initial state of the Slice Sampler. + """ + return SliceState(position, logdensity_fn(position), constraint_fn(position)) + + +def build_kernel( + stepper_fn: Callable, + max_steps: int = 10, + max_shrinkage: int = 100, +) -> Callable: + """Build a Slice Sampling kernel. + + This kernel performs one step of Slice Sampling algorithm, which involves + defining a vertical slice, stepping out to define an interval, and then + shrinking that interval to find an acceptable new sample. + + Parameters + ---------- + stepper_fn + A function that computes a new position given an initial position, + direction `d` and a slice parameter `t`. + `(x0, d, t) -> x_new` where e.g. `x_new = x0 + t * d`. 
+ + Returns + ------- + Callable + A kernel function that takes a PRNG key, the current `SliceState`, + the log-density function, direction `d`, constraint function, constraint + values, and strict flags, and returns a new `SliceState` and `SliceInfo`. + + References + ---------- + .. [1] Neal, R. M. (2003). Slice sampling. The Annals of Statistics, 31(3), 705-767. + """ + + def kernel( + rng_key: PRNGKey, + state: SliceState, + logdensity_fn: Callable, + d: ArrayTree, + constraint_fn: Callable, + constraint: Array, + strict: Array, + ) -> tuple[SliceState, SliceInfo]: + vs_key, hs_key = jax.random.split(rng_key) + logslice = state.logdensity + jnp.log(jax.random.uniform(vs_key)) + vertical_is_accepted = logslice < state.logdensity + + def slicer(t) -> tuple[SliceState, SliceInfo]: + x, step_accepted = stepper_fn(state.position, d, t) + new_state = init(x, logdensity_fn, constraint_fn) + constraints_ok = jnp.all( + jnp.where( + strict, + new_state.constraint > constraint, + new_state.constraint >= constraint, + ) + ) + in_slice = new_state.logdensity >= logslice + is_accepted = in_slice & constraints_ok & step_accepted + return new_state, is_accepted + + new_state, info = horizontal_slice( + hs_key, slicer, state, max_steps, max_shrinkage + ) + info = info._replace(is_accepted=info.is_accepted & vertical_is_accepted) + return new_state, info + + return kernel + + +def horizontal_slice( + rng_key: PRNGKey, + slicer: Callable, + state: SliceState, + m: int, + max_shrinkage: int, +) -> tuple[SliceState, SliceInfo]: + """Propose a new sample using the stepping-out and shrinking procedures. + + This function implements the core of the Hit-and-Run Slice Sampling algorithm. + It first expands an interval (`[l, r]`) along the slice starting + from `x0` and proceeding along direction `d` until both ends are outside + the slice defined by `logslice` (stepping-out). Then, it samples + points uniformly from this interval and shrinks the interval until a point + is found that lies within the slice (shrinking). + + Parameters + ---------- + rng_key + A JAX PRNG key. + slicer + A function that takes a scalar `t` and returns a state and info on the + slice. + state + The current slice sampling state. + m + The maximum number of steps to take when expanding the interval in + each direction during the stepping-out phase. + max_shrinkage + The maximum number of shrinking steps to perform to avoid infinite loops. + + Returns + ------- + tuple[SliceState, SliceInfo] + A tuple containing the new state (with the accepted sample and its + log-density) and information about the sampling process (number of + expansion and shrinkage steps). 
+ """ + # Initial bounds + rng_key, subkey = jax.random.split(rng_key) + u, v = jax.random.uniform(subkey, 2) + j = jnp.floor(m * v).astype(int) + k = (m - 1) - j + + # Expand + def step_body_fun(carry): + i, s, t, _ = carry + t += s + _, is_accepted = slicer(t) + i -= 1 + return i, s, t, is_accepted + + def step_cond_fun(carry): + i, _, _, is_accepted = carry + return is_accepted & (i > 0) + + j, _, l, _ = jax.lax.while_loop( + step_cond_fun, step_body_fun, (j + 1, -1, 1 - u, True) + ) + k, _, r, _ = jax.lax.while_loop(step_cond_fun, step_body_fun, (k + 1, +1, -u, True)) + + # Shrink + def shrink_body_fun(carry): + n, rng_key, l, r, state, is_accepted = carry + + rng_key, subkey = jax.random.split(rng_key) + u = jax.random.uniform(subkey, minval=l, maxval=r) + + new_state, is_accepted = slicer(u) + n += 1 + + l = jnp.where(u < 0, u, l) + r = jnp.where(u > 0, u, r) + + return n, rng_key, l, r, new_state, is_accepted + + def shrink_cond_fun(carry): + n, _, _, _, _, is_accepted = carry + return ~is_accepted & (n < max_shrinkage) + + carry = 0, rng_key, l, r, state, False + carry = jax.lax.while_loop(shrink_cond_fun, shrink_body_fun, carry) + n, _, _, _, new_state, is_accepted = carry + new_state, (is_accepted, _, _) = static_binomial_sampling( + rng_key, jnp.log(is_accepted), state, new_state + ) + slice_info = SliceInfo(is_accepted, m + 1 - j - k, n) + return new_state, slice_info + + +def build_hrss_kernel( + generate_slice_direction_fn: Callable, + stepper_fn: Callable, + max_steps: int = 10, +) -> Callable: + """Build a Hit-and-Run Slice Sampling kernel. + + This kernel performs one step of the Hit-and-Run Slice Sampling algorithm, + which involves defining a vertical slice, proposing a direction, stepping out + to define an interval, and then shrinking that interval to find an acceptable + new sample. + + Parameters + ---------- + generate_slice_direction_fn + A function that, given a PRNG key, generates a direction vector (PyTree + with the same structure as the position) for the "hit-and-run" part of + the algorithm. This direction is typically normalized. + + stepper_fn + A function that computes a new position given an initial position, a + direction, and a step size `t`. It should implement something analogous + to `x_new = x_initial + t * direction`. + + Returns + ------- + Callable + A kernel function that takes a PRNG key, the current `SliceState`, and + the log-density function, and returns a new `SliceState` and `SliceInfo`. + """ + slice_kernel = build_kernel(stepper_fn, max_steps) + + def kernel( + rng_key: PRNGKey, state: SliceState, logdensity_fn: Callable + ) -> tuple[SliceState, SliceInfo]: + rng_key, prop_key = jax.random.split(rng_key, 2) + d = generate_slice_direction_fn(prop_key) + constraint_fn = lambda x: jnp.array([]) + constraint = jnp.array([]) + strict = jnp.array([], dtype=bool) + return slice_kernel( + rng_key, state, logdensity_fn, d, constraint_fn, constraint, strict + ) + + return kernel + + +def default_stepper_fn(x: ArrayTree, d: ArrayTree, t: float) -> ArrayTree: + """A simple stepper function that moves from `x` along direction `d` by `t` units. + + Implements the operation: `x_new = x + t * d`. + + Parameters + ---------- + x + The starting position (PyTree). + d + The direction of movement (PyTree, same structure as `x`). + t + The scalar step size or distance along the direction. 
+ + Returns + ------- + position, is_accepted + """ + return jax.tree.map(lambda x, d: x + t * d, x, d), True + + +def sample_direction_from_covariance(rng_key: PRNGKey, cov: Array) -> Array: + """Generates a random direction vector, normalized, from a multivariate Gaussian. + + This function samples a direction `d` from a zero-mean multivariate Gaussian + distribution with covariance matrix `cov`, and then normalizes `d` to be a + unit vector with respect to the Mahalanobis norm defined by `inv(cov)`. + That is, `d_normalized^T @ inv(cov) @ d_normalized = 1`. + + Parameters + ---------- + rng_key + A JAX PRNG key. + cov + The covariance matrix for the multivariate Gaussian distribution from which + the initial direction is sampled. Assumed to be a 2D array. + + Returns + ------- + Array + A normalized direction vector (1D array). + """ + d = jax.random.multivariate_normal(rng_key, mean=jnp.zeros(cov.shape[0]), cov=cov) + invcov = jnp.linalg.inv(cov) + norm = jnp.sqrt(jnp.einsum("...i,...ij,...j", d, invcov, d)) + d = d / norm[..., None] + return d + + +def hrss_as_top_level_api( + logdensity_fn: Callable, + cov: Array, +) -> SamplingAlgorithm: + """Creates a Hit-and-Run Slice Sampling algorithm. + + This function serves as a convenience wrapper to easily construct a + Hit-and-Run Slice Sampler using a default direction proposal mechanism + based on a multivariate Gaussian distribution with the provided covariance. + + Parameters + ---------- + logdensity_fn + The log-density function of the target distribution to sample from. + cov + The covariance matrix used by the default direction proposal function + (`default_proposal_distribution`). This matrix shapes the random + directions proposed for the slice sampling steps. + + Returns + ------- + SamplingAlgorithm + A `SamplingAlgorithm` tuple containing `init` and `step` functions for + the configured Hit-and-Run Slice Sampler. + """ + generate_slice_direction_fn = partial(sample_direction_from_covariance, cov=cov) + kernel = build_hrss_kernel(generate_slice_direction_fn, default_stepper_fn) + init_fn = partial(init, logdensity_fn=logdensity_fn) + step_fn = partial(kernel, logdensity_fn=logdensity_fn) + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/ns/__init__.py b/blackjax/ns/__init__.py new file mode 100644 index 000000000..d9c37553d --- /dev/null +++ b/blackjax/ns/__init__.py @@ -0,0 +1,42 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Nested Sampling Algorithms in BlackJAX. + +This subpackage provides implementations of Nested Sampling algorithms, +including a base version, an adaptive version, and Nested Slice Sampling (NSS). + +Nested Sampling is a Monte Carlo method for Bayesian computation, primarily +used for evidence (marginal likelihood) calculation and posterior sampling. +It is particularly well-suited for problems with multi-modal posteriors or +complex likelihood landscapes. 
+ +Available modules: +------------------ +- `adaptive`: Implements an adaptive Nested Sampling algorithm where inner + kernel parameters are tuned at each iteration. +- `base`: Provides core components and a non-adaptive Nested Sampling kernel. +- `nss`: Implements Nested Slice Sampling, using Hit-and-Run Slice Sampling as + the inner kernel with adaptive tuning of its proposal mechanism. +- `utils`: Contains utility functions for processing and analyzing Nested + Sampling results. + +""" +from . import adaptive, base, nss, utils + +__all__ = [ + "adaptive", + "base", + "utils", + "nss", +] diff --git a/blackjax/ns/adaptive.py b/blackjax/ns/adaptive.py new file mode 100644 index 000000000..1bddb6296 --- /dev/null +++ b/blackjax/ns/adaptive.py @@ -0,0 +1,156 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Adaptive Nested Sampling for BlackJAX. + +This module provides an adaptive version of the Nested Sampling algorithm. +In this variant, the parameters of the inner kernel, which is used to +sample new live points, are updated (tuned) at each iteration of the +Nested Sampling loop. This adaptation is based on the information from the +current set of live particles or the history of the sampling process, +allowing the kernel to adjust to the changing characteristics of the +constrained prior distribution as the likelihood threshold increases. +""" + +from typing import Callable, Dict, Optional + +import jax.numpy as jnp + +from blackjax.ns.base import NSInfo, NSState +from blackjax.ns.base import build_kernel as base_build_kernel +from blackjax.ns.base import init as base_init +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey + +__all__ = ["init", "build_kernel"] + + +def init( + particles: ArrayLikeTree, + logprior_fn: Callable, + loglikelihood_fn: Callable, + loglikelihood_birth: Array = -jnp.nan, + update_inner_kernel_params_fn: Optional[Callable] = None, +) -> NSState: + """Initializes the Nested Sampler state. + + Parameters + ---------- + particles + An initial set of particles (PyTree of arrays) drawn from the prior + distribution. The leading dimension of each leaf array must be equal to + the number of particles. + loglikelihood_fn + A function that computes the log-likelihood of a single particle. + logprior_fn + A function that computes the log-prior of a single particle. + loglikelihood_birth + The initial log-likelihood birth threshold. Defaults to -NaN, which + implies no initial likelihood constraint beyond the prior. + update_inner_kernel_params_fn + A function that takes the `NSState`, `NSInfo` from the completed NS + step, and the current inner kernel parameters dictionary, and returns + a dictionary of parameters to be used for the kernel in the *next* NS step. + + Returns + ------- + NSState + The initial state of the Nested Sampler. 
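+
+    Examples
+    --------
+    For illustration, an update function might simply re-estimate a covariance
+    from the live points (a sketch assuming particles are stored as a flat
+    `(n_particles, ndim)` array; the `"cov"` key is only meaningful to an inner
+    kernel that expects it, and `blackjax.ns.nss.compute_covariance_from_particles`
+    is the general PyTree-aware version)::
+
+        import jax.numpy as jnp
+
+        def update_inner_kernel_params_fn(state, info, params):
+            # `info` is None on the very first call made by `init`.
+            return {"cov": jnp.atleast_2d(jnp.cov(state.particles, rowvar=False))}
+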
+ """ + state = base_init(particles, logprior_fn, loglikelihood_fn, loglikelihood_birth) + if update_inner_kernel_params_fn is not None: + inner_kernel_params = update_inner_kernel_params_fn(state, None, {}) + state = state._replace(inner_kernel_params=inner_kernel_params) + return state + + +def build_kernel( + logprior_fn: Callable, + loglikelihood_fn: Callable, + delete_fn: Callable, + inner_kernel: Callable, + update_inner_kernel_params_fn: Callable[ + [NSState, NSInfo, Dict[str, ArrayTree]], Dict[str, ArrayTree] + ], +) -> Callable: + """Build an adaptive Nested Sampling kernel. + + This kernel extends the base Nested Sampling kernel by re-computing/tuning + the parameters for the inner kernel at each step. The `update_inner_kernel_params_fn` + is called after each NS step to determine the parameters for the *next* NS + step. + + Parameters + ---------- + logprior_fn + A function that computes the log-prior probability of a single particle. + loglikelihood_fn + A function that computes the log-likelihood of a single particle. + delete_fn + this particle deletion function has the signature + `(rng_key, current_state) -> (dead_idx, target_update_idx, start_idx)` + and identifies particles to be deleted, particles to be updated, and + selects live particles to be starting points for the inner kernel + for new particle generation. + inner_kernel + This kernel function has the signature + `(rng_key, inner_state, logprior_fn, loglikelihood_fn, loglikelihood_0, inner_kernel_params) -> (new_inner_state, inner_info)`, + and is used to generate new particles. + update_inner_kernel_params_fn + A function that takes the `NSState`, `NSInfo` from the completed NS + step, and the current inner kernel parameters dictionary, and returns + a dictionary of parameters to be used for the kernel in the *next* NS step. + + Returns + ------- + Callable + A kernel function for adaptive Nested Sampling. It takes an `rng_key` and the + current `NSState` and returns a tuple containing the new `NSState` and + the `NSInfo` for the step. + """ + + base_kernel = base_build_kernel( + logprior_fn, + loglikelihood_fn, + delete_fn, + inner_kernel, + ) + + def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]: + """Performs one step of adaptive Nested Sampling. + + This involves running a step of the base Nested Sampling algorithm using + the current inner kernel parameters, and then updating these parameters + for the next step. + + Parameters + ---------- + rng_key + A JAX PRNG key. + state + The current `NSState`. + + Returns + ------- + tuple[NSState, NSInfo] + A tuple with the new `NSState` (including updated inner kernel + parameters) and the `NSInfo` for this step. + """ + new_state, info = base_kernel(rng_key, state) + + inner_kernel_params = update_inner_kernel_params_fn( + new_state, info, new_state.inner_kernel_params + ) + new_state = new_state._replace(inner_kernel_params=inner_kernel_params) + return new_state, info + + return kernel diff --git a/blackjax/ns/base.py b/blackjax/ns/base.py new file mode 100644 index 000000000..978f242d4 --- /dev/null +++ b/blackjax/ns/base.py @@ -0,0 +1,448 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base components for Nested Sampling algorithms in BlackJAX. + +This module provides the fundamental data structures (`NSState`, `NSInfo`) and +a basic, non-adaptive kernel for Nested Sampling. Nested Sampling is a +Monte Carlo method primarily aimed at Bayesian evidence (marginal likelihood) +computation and posterior sampling, particularly effective for multi-modal +distributions. + +The core idea is to transform the multi-dimensional evidence integral into a +one-dimensional integral over the prior volume, ordered by likelihood. This is +achieved by iteratively replacing the point with the lowest likelihood among a +set of "live" points with a new point sampled from the prior, subject to the +constraint that its likelihood must be higher than the one just discarded. + +This base implementation uses a provided kernel to perform the constrained +sampling. +""" + +from typing import Callable, Dict, NamedTuple, Optional + +import jax +import jax.numpy as jnp +from jax.scipy.special import logsumexp + +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey + +__all__ = ["init", "build_kernel", "NSState", "NSInfo", "delete_fn"] + + +class NSState(NamedTuple): + """State of the Nested Sampler. + + Attributes + ---------- + particles + A PyTree of arrays, where each leaf array has a leading dimension + equal to the number of live particles. Stores the current positions of + the live particles. + loglikelihood + An array of log-likelihood values, one for each live particle. + loglikelihood_birth + An array storing the log-likelihood threshold that each current live + particle was required to exceed when it was "born" (i.e., sampled). + This is used for reconstructing the nested sampling path. + logprior + An array of log-prior values, one for each live particle. + pid + Particle ID. An array of integers tracking the identity or lineage of + particles, primarily for diagnostic purposes. + logX + The log of the current prior volume estimate. + logZ + The accumulated log evidence estimate from the "dead" points . + logZ_live + The current estimate of the log evidence contribution from the live points. + inner_kernel_params + A dictionary of parameters for the inner kernel. + """ + + particles: ArrayLikeTree + loglikelihood: Array # The log-likelihood of the particles + loglikelihood_birth: Array # The log-likelihood threshold at particle birth + logprior: Array # The log-prior density of the particles + pid: Array # particle IDs + logX: Array # The current log-volume estimate + logZ: Array # The accumulated evidence estimate + logZ_live: Array # The current evidence estimate + inner_kernel_params: Dict # Parameters for the inner kernel + + +class NSInfo(NamedTuple): + """Additional information returned at each step of the Nested Sampling algorithm. + + Attributes + ---------- + particles + The PyTree of particles that were marked as "dead" (replaced) in the + current step. + loglikelihood + The log-likelihood values of the dead particles. + loglikelihood_birth + The birth log-likelihood thresholds of the dead particles. 
+ logprior + The log-prior values of the dead particles. + inner_kernel_info + A NamedTuple (or any PyTree) containing information from the update step + (inner kernel) used to generate new live particles. The content + depends on the specific inner kernel used. + """ + + particles: ArrayTree + loglikelihood: Array # The log-likelihood of the particles + loglikelihood_birth: Array # The log-likelihood threshold at particle birth + logprior: Array # The log-prior density of the particles + inner_kernel_info: NamedTuple # Information from the inner kernel update step + + +class PartitionedState(NamedTuple): + """State container that partitions out the loglikelihood and logprior. + + This intermediate construction wraps around the usual State of an MCMC chain + so that the loglikelihood and logprior can be efficiently recorded, a + necessary step for the Parition function reconstruction that Nested + Sampling builds + + + Attributes + ---------- + position + A PyTree of arrays representing the current positions of the particles. + Each leaf array has a leading dimension corresponding to the number of particles. + logprior + An array of log-prior density values evaluated at the particle positions. + Shape: (n_particles,) + loglikelihood + An array of log-likelihood values evaluated at the particle positions. + Shape: (n_particles,) + """ + + position: ArrayLikeTree # Current positions of particles in the inner kernel + logprior: Array # Log-prior values for particles in the inner kernel + loglikelihood: Array # Log-likelihood values for particles in the inner kernel + + +class PartitionedInfo(NamedTuple): + """Transition information that additionally records a partitioned loglikelihood + and logprior. + + See PartitionedState + + Attributes + ---------- + position + A PyTree of arrays representing the final positions after the transition step. + Structure matches the input particle positions. + logprior + An array of log-prior density values at the final positions. + Kept separate to support posterior repartitioning schemes. + Shape: (n_particles,) + loglikelihood + An array of log-likelihood values at the final positions. + Kept separate to support posterior repartitioning schemes. + Shape: (n_particles,) + info + Additional transition-specific diagnostic information from the step. + The content and structure depend on the specific transition implementation + (e.g., acceptance rates, step sizes, number of evaluations, etc.). + """ + + position: ArrayTree + logprior: ArrayTree + loglikelihood: ArrayTree + info: NamedTuple + + +def new_state_and_info(position, logprior, loglikelihood, info): + """Create new PartitionedState and PartitionedInfo from transition results. + + This utility function packages the results of a transition into the standard + partitioned state and info containers, maintaining the separation of logprior + and loglikelihood components. + + Parameters + ---------- + position + The particle positions after the transition step. + logprior + The log-prior densities at the new positions. + loglikelihood + The log-likelihood values at the new positions. + info + Additional transition-specific information from the step. + + Returns + ------- + tuple[PartitionedState, PartitionedInfo] + A tuple containing the new partitioned state and associated information. 
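+
+    Examples
+    --------
+    A deliberately trivial inner kernel that proposes no move at all would package
+    its output as follows (a sketch; a real kernel replaces the placeholder with a
+    constrained update satisfying `loglikelihood > loglikelihood_0`)::
+
+        def identity_inner_kernel(
+            rng_key, state, logprior_fn, loglikelihood_fn, loglikelihood_0, params
+        ):
+            # Keep the current position and its cached densities; no diagnostics.
+            return new_state_and_info(
+                state.position, state.logprior, state.loglikelihood, info=None
+            )
+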
+ """ + new_state = PartitionedState( + position=position, + logprior=logprior, + loglikelihood=loglikelihood, + ) + info = PartitionedInfo( + position=position, + logprior=logprior, + loglikelihood=loglikelihood, + info=info, + ) + return new_state, info + + +def init( + particles: ArrayLikeTree, + logprior_fn: Callable, + loglikelihood_fn: Callable, + loglikelihood_birth: Array = -jnp.nan, + logX: Optional[Array] = 0.0, + logZ: Optional[Array] = -jnp.inf, +) -> NSState: + """Initializes the Nested Sampler state. + + Parameters + ---------- + particles + An initial set of particles (PyTree of arrays) drawn from the prior + distribution. The leading dimension of each leaf array must be equal to + the number of particles. + logprior_fn + A function that computes the log-prior of a single particle. + loglikelihood_fn + A function that computes the log-likelihood of a single particle. + loglikelihood_birth + The initial log-likelihood birth threshold. Defaults to -NaN, which + implies no initial likelihood constraint beyond the prior. + logX + The initial log prior volume estimate. Defaults to 0.0. + logZ + The initial log evidence estimate. Defaults to -inf. + + Returns + ------- + NSState + The initial state of the Nested Sampler. + """ + loglikelihood = loglikelihood_fn(particles) + loglikelihood_birth = loglikelihood_birth * jnp.ones_like(loglikelihood) + logprior = logprior_fn(particles) + pid = jnp.arange(len(loglikelihood)) + dtype = loglikelihood.dtype + logX = jnp.array(logX, dtype=dtype) + logZ = jnp.array(logZ, dtype=dtype) + logZ_live = logmeanexp(loglikelihood) + logX + inner_kernel_params: Dict = {} + return NSState( + particles, + loglikelihood, + loglikelihood_birth, + logprior, + pid, + logX, + logZ, + logZ_live, + inner_kernel_params, + ) + + +def build_kernel( + logprior_fn: Callable, + loglikelihood_fn: Callable, + delete_fn: Callable, + inner_kernel: Callable, +) -> Callable: + """Build a generic Nested Sampling kernel. + + This kernel implements one step of the Nested Sampling algorithm. In each step: + 1. A set of particles with the lowest log-likelihoods are identified and + marked as "dead" using `delete_fn`. The log-likelihood of the "worst" + of these dead particles (i.e., max among the lowest ones) defines the new + likelihood constraint `loglikelihood_0`. + 2. Live particles are selected (typically with replacement from the remaining + live particles, determined by `delete_fn`) to act as starting points for + the updates. + 3. These selected live particles are evolved using an kernel + `inner_kernel`. The sampling is constrained to the region where + `loglikelihood(new_particle) > loglikelihood_0`. + 4. The newly generated particles replace particles marked for replacement, + (typically the ones that have just been deleted). + 5. The prior volume `logX` and evidence `logZ` are updated based on the + number of deleted particles and their likelihoods. + + This base version does not adapt the kernel parameters. + + Parameters + ---------- + logprior_fn + A function that computes the log-prior probability of a single particle. + loglikelihood_fn + A function that computes the log-likelihood of a single particle. + delete_fn + this particle deletion function has the signature + `(rng_key, current_state) -> (dead_idx, target_update_idx, start_idx)` + and identifies particles to be deleted, particles to be updated, and + selects live particles to be starting points for the inner kernel + for new particle generation. 
+ inner_kernel + This kernel function has the signature + `(rng_key, inner_state, logprior_fn, loglikelihood_fn, loglikelihood_0, params) -> (new_inner_state, inner_info)`, + and is used to generate new particles. + + Returns + ------- + Callable + A kernel function for Nested Sampling: + `(rng_key, state) -> (new_state, ns_info)`. + """ + + def kernel(rng_key: PRNGKey, state: NSState) -> tuple[NSState, NSInfo]: + # Delete, and grab all the dead information + rng_key, delete_fn_key = jax.random.split(rng_key) + dead_idx, target_update_idx, start_idx = delete_fn(delete_fn_key, state) + dead_particles = jax.tree.map(lambda x: x[dead_idx], state.particles) + dead_loglikelihood = state.loglikelihood[dead_idx] + dead_loglikelihood_birth = state.loglikelihood_birth[dead_idx] + dead_logprior = state.logprior[dead_idx] + + # Resample the live particles + loglikelihood_0 = dead_loglikelihood.max() + rng_key, sample_key = jax.random.split(rng_key) + sample_keys = jax.random.split(sample_key, len(start_idx)) + particles = jax.tree.map(lambda x: x[start_idx], state.particles) + logprior = state.logprior[start_idx] + loglikelihood = state.loglikelihood[start_idx] + inner_state = PartitionedState(particles, logprior, loglikelihood) + new_inner_state, inner_info = inner_kernel( + sample_keys, + inner_state, + logprior_fn, + loglikelihood_fn, + loglikelihood_0, + state.inner_kernel_params, + ) + + # Update the particles + particles = jax.tree_util.tree_map( + lambda p, n: p.at[target_update_idx].set(n), + state.particles, + new_inner_state.position, + ) + loglikelihood = state.loglikelihood.at[target_update_idx].set( + new_inner_state.loglikelihood + ) + loglikelihood_birth = state.loglikelihood_birth.at[target_update_idx].set( + loglikelihood_0 * jnp.ones(len(target_update_idx)) + ) + logprior = state.logprior.at[target_update_idx].set(new_inner_state.logprior) + pid = state.pid.at[target_update_idx].set(state.pid[start_idx]) + + # Update the run-time information + logX, logZ, logZ_live = update_ns_runtime_info( + state.logX, state.logZ, loglikelihood, dead_loglikelihood + ) + + # Return updated state and info + state = NSState( + particles, + loglikelihood, + loglikelihood_birth, + logprior, + pid, + logX, + logZ, + logZ_live, + state.inner_kernel_params, + ) + info = NSInfo( + dead_particles, + dead_loglikelihood, + dead_loglikelihood_birth, + dead_logprior, + inner_info, + ) + return state, info + + return kernel + + +def delete_fn( + rng_key: PRNGKey, state: NSState, num_delete: int +) -> tuple[Array, Array, Array]: + """Identifies particles to be deleted and selects live particles for resampling. + + This function implements a common strategy in Nested Sampling: + 1. Identify the `num_delete` particles with the lowest log-likelihoods. These + are marked as "dead". + 2. From the remaining live particles (those not marked as dead), `num_delete` + particles are chosen (typically with replacement, weighted by their + current importance weights, here it is uniform from survivors) + to serve as starting points for generating new particles. + + Parameters + ---------- + rng_key + A JAX PRNG key, used here for choosing live particles. + state + The current state of the Nested Sampler. + num_delete + The number of particles to delete and subsequently replace. + + Returns + ------- + tuple[Array, Array, Array] + A tuple containing: + - `dead_idx`: An array of indices corresponding to the particles + marked for deletion. 
+ - `target_update_idx`: An array of indices corresponding to the + particles to be updated (same as dead_idx in this implementation). + - `start_idx`: An array of indices corresponding to the particles + selected for initialization. + """ + loglikelihood = state.loglikelihood + neg_dead_loglikelihood, dead_idx = jax.lax.top_k(-loglikelihood, num_delete) + constraint_loglikelihood = loglikelihood > -neg_dead_loglikelihood.min() + weights = jnp.array(constraint_loglikelihood, dtype=jnp.float32) + weights = jnp.where(weights.sum() > 0.0, weights, jnp.ones_like(weights)) + start_idx = jax.random.choice( + rng_key, + len(weights), + shape=(num_delete,), + p=weights / weights.sum(), + replace=True, + ) + target_update_idx = dead_idx + return dead_idx, target_update_idx, start_idx + + +def update_ns_runtime_info( + logX: Array, logZ: Array, loglikelihood: Array, dead_loglikelihood: Array +) -> tuple[Array, Array, Array]: + num_particles = len(loglikelihood) + num_deleted = len(dead_loglikelihood) + num_live = jnp.arange(num_particles, num_particles - num_deleted, -1) + delta_logX = -1 / num_live + logX = logX + jnp.cumsum(delta_logX) + log_delta_X = logX + jnp.log(1 - jnp.exp(delta_logX)) + log_delta_Z = dead_loglikelihood + log_delta_X + + delta_logZ = logsumexp(log_delta_Z) + logZ = jnp.logaddexp(logZ, delta_logZ) + logZ_live = logmeanexp(loglikelihood) + logX[-1] + return logX[-1], logZ, logZ_live + + +def logmeanexp(x: Array) -> Array: + return logsumexp(x) - jnp.log(len(x)) diff --git a/blackjax/ns/nss.py b/blackjax/ns/nss.py new file mode 100644 index 000000000..236138fa2 --- /dev/null +++ b/blackjax/ns/nss.py @@ -0,0 +1,324 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Nested Slice Sampling (NSS) algorithm. + +This module implements the Nested Slice Sampling algorithm, which combines the +Nested Sampling framework with an inner Hit-and-Run Slice Sampling (HRSS) kernel +for exploring the constrained prior distribution at each likelihood level. + +The key idea is to leverage the efficiency of slice sampling for constrained +sampling tasks. The parameters of the HRSS kernel, specifically the covariance +matrix for proposing slice directions, are adaptively tuned based on the current +set of live particles. 
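+
+A minimal end-to-end sketch (the toy Gaussian likelihood, the uniform prior and
+all numbers below are purely illustrative)::
+
+    import jax
+    import jax.numpy as jnp
+    import blackjax
+
+    loglikelihood_fn = lambda x: -0.5 * jnp.sum(x**2)
+    logprior_fn = lambda x: jnp.sum(jax.scipy.stats.uniform.logpdf(x, -5.0, 10.0))
+
+    rng_key = jax.random.PRNGKey(0)
+    init_key, run_key = jax.random.split(rng_key)
+    particles = jax.random.uniform(init_key, (500, 2), minval=-5.0, maxval=5.0)
+
+    algo = blackjax.nss(logprior_fn, loglikelihood_fn, num_inner_steps=10, num_delete=50)
+    state = algo.init(particles)
+
+    def one_step(state, key):
+        state, dead_info = algo.step(key, state)
+        return state, dead_info
+
+    # Fixed number of steps for the sketch; a real run keeps iterating until
+    # state.logZ_live is negligible relative to state.logZ.
+    state, dead = jax.lax.scan(one_step, state, jax.random.split(run_key, 100))
+    logZ_estimate = jnp.logaddexp(state.logZ, state.logZ_live)
+    # `dead` stacks the per-step NSInfo; see `blackjax.ns.utils` for post-processing.
+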
+""" + +from functools import partial +from typing import Callable, Dict, Optional + +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + +from blackjax import SamplingAlgorithm +from blackjax.mcmc.ss import SliceState +from blackjax.mcmc.ss import build_kernel as build_slice_kernel +from blackjax.mcmc.ss import default_stepper_fn +from blackjax.mcmc.ss import ( + sample_direction_from_covariance as ss_sample_direction_from_covariance, +) +from blackjax.ns.adaptive import build_kernel as build_adaptive_kernel +from blackjax.ns.adaptive import init +from blackjax.ns.base import NSInfo, NSState +from blackjax.ns.base import delete_fn as default_delete_fn +from blackjax.ns.base import new_state_and_info +from blackjax.ns.utils import get_first_row, repeat_kernel +from blackjax.smc.tuning.from_particles import ( + particles_as_rows, + particles_covariance_matrix, +) +from blackjax.types import ArrayTree, PRNGKey + +__all__ = [ + "init", + "as_top_level_api", + "build_kernel", +] + + +def sample_direction_from_covariance( + rng_key: PRNGKey, params: Dict[str, ArrayTree] +) -> ArrayTree: + """Default function to generate a normalized slice direction for NSS. + + This function is designed to work with covariance parameters adapted by + `default_adapt_direction_params_fn`. It expects `params` to contain + 'cov', a PyTree structured identically to a single particle. Each leaf + of this 'cov' PyTree contains rows of the full covariance matrix that + correspond to that leaf's elements in the flattened particle vector. + (Specifically, if the full DxD covariance matrix of flattened particles is + `M_flat`, and `unravel_fn` un-flattens a D-vector to the particle PyTree, + then the input `cov` is effectively `jax.vmap(unravel_fn)(M_flat)`). + + The function reassembles the full (D,D) covariance matrix from this + PyTree structure. It then samples a flat direction vector `d_flat` from + a multivariate Gaussian $\\mathcal{N}(0, M_{reassembled})$, normalizes + `d_flat` using the Mahalanobis norm defined by $M_{reassembled}^{-1}$, + and finally un-flattens this normalized direction back into the + particle's PyTree structure using an `unravel_fn` derived from the + particle structure. + + Parameters + ---------- + rng_key + A JAX PRNG key. + params + Keyword arguments, must contain: + - `cov`: A PyTree (structured like a particle) whose leaves are rows + of the covariance matrix, typically output by + `compute_covariance_from_particles`. + + Returns + ------- + ArrayTree + A Mahalanobis-normalized direction vector (PyTree, matching the + structure of a single particle), to be used by the slice sampler. + """ + cov = params["cov"] + row = get_first_row(cov) + _, unravel_fn = ravel_pytree(row) + cov = particles_as_rows(cov) + d = ss_sample_direction_from_covariance(rng_key, cov) + return unravel_fn(d) + + +def compute_covariance_from_particles( + state: NSState, + info: NSInfo, + inner_kernel_params: Optional[Dict[str, ArrayTree]] = None, +) -> Dict[str, ArrayTree]: + """Default function to adapt/tune the slice direction proposal parameters. + + This function computes the empirical covariance matrix from the current set of + live particles in `state.particles`. This covariance matrix is then returned + and can be used by the slice direction generation function (e.g., + `default_generate_slice_direction_fn`) in the next Nested Sampling iteration. + + Parameters + ---------- + state + The current `NSState` of the Nested Sampler, containing the live particles. 
+ info + The `NSInfo` from the last Nested Sampling step (currently unused by this function). + inner_kernel_params + A dictionary of parameters for the inner kernel (currently unused by this function). + + Returns + ------- + Dict[str, ArrayTree] + A dictionary `{'cov': cov_pytree}`. `cov_pytree` is a PyTree with the + same structure as a single particle. If the full DxD covariance matrix + of the flattened particles is `M_flat`, and `unravel_fn` is the function + to un-flatten a D-vector to the particle's PyTree structure, then + `cov_pytree` is equivalent to `jax.vmap(unravel_fn)(M_flat)`. + This means each leaf of `cov_pytree` will have a shape `(D, *leaf_original_dims)`. + """ + cov_matrix = jnp.atleast_2d(particles_covariance_matrix(state.particles)) + single_particle = get_first_row(state.particles) + _, unravel_fn = ravel_pytree(single_particle) + cov_pytree = jax.vmap(unravel_fn)(cov_matrix) + return {"cov": cov_pytree} + + +def build_kernel( + logprior_fn: Callable, + loglikelihood_fn: Callable, + num_inner_steps: int, + num_delete: int = 1, + stepper_fn: Callable = default_stepper_fn, + adapt_direction_params_fn: Callable = compute_covariance_from_particles, + generate_slice_direction_fn: Callable = sample_direction_from_covariance, + max_steps: int = 10, + max_shrinkage: int = 100, +) -> Callable: + """Builds the Nested Slice Sampling kernel. + + This function creates a Nested Slice Sampling kernel that uses + Hit-and-Run Slice Sampling (HRSS) as its inner kernel. The parameters + for the HRSS direction proposal (specifically, the covariance matrix) + are adaptively tuned at each step using `adapt_direction_params_fn`. + + Parameters + ---------- + logprior_fn + A function that computes the log-prior probability of a single particle. + loglikelihood_fn + A function that computes the log-likelihood of a single particle. + num_inner_steps + The number of HRSS steps to run for each new particle generation. + This should be a multiple of the dimension of the parameter space. + num_delete + The number of particles to delete and replace at each NS step. + Defaults to 1. + stepper_fn + The stepper function `(x, direction, t) -> x_new` for the HRSS kernel. + Defaults to `default_stepper_fn`. + adapt_direction_params_fn + A function `(ns_state, ns_info) -> dict_of_params` that computes/adapts + the parameters (e.g., covariance matrix) for the slice direction proposal, + based on the current NS state. Defaults to `compute_covariance_from_particles`. + generate_slice_direction_fn + A function `(rng_key, **params) -> direction_pytree` that generates a + normalized direction for HRSS, using parameters from `adapt_direction_params_fn`. + Defaults to `sample_direction_from_covariance`. + max_steps + The maximum number of steps to take when expanding the interval in + each direction during the stepping-out phase. Defaults to 10. + max_shrinkage + The maximum number of shrinking steps to perform to avoid infinite loops. + Defaults to 100. + + Returns + ------- + Callable + A kernel function for Nested Slice Sampling that takes an `rng_key` and + the current `NSState` and returns a tuple containing the new `NSState` and + the `NSInfo` for the step. 
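+
+    Examples
+    --------
+    A direction proposal that ignores cross-correlations can be swapped in (a
+    sketch assuming particles are flat 1-d arrays, so that `params["cov"]` is a
+    plain `(ndim, ndim)` matrix; `logprior_fn` and `loglikelihood_fn` are assumed
+    to be defined elsewhere)::
+
+        import jax
+        import jax.numpy as jnp
+
+        def diagonal_direction_fn(rng_key, params):
+            scales = jnp.sqrt(jnp.diag(params["cov"]))
+            d = jax.random.normal(rng_key, scales.shape) * scales
+            # Normalise with respect to the diagonal Mahalanobis norm.
+            return d / jnp.sqrt(jnp.sum((d / scales) ** 2))
+
+        kernel = build_kernel(
+            logprior_fn,
+            loglikelihood_fn,
+            num_inner_steps=10,
+            generate_slice_direction_fn=diagonal_direction_fn,
+        )
+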
+ """ + + slice_kernel = build_slice_kernel(stepper_fn, max_steps, max_shrinkage) + + @repeat_kernel(num_inner_steps) + def inner_kernel( + rng_key, state, logprior_fn, loglikelihood_fn, loglikelihood_0, params + ): + # Do constrained slice sampling + slice_state = SliceState( + position=state.position, + logdensity=state.logprior, + constraint=jnp.array([state.loglikelihood]), + ) + rng_key, prop_key = jax.random.split(rng_key, 2) + d = generate_slice_direction_fn(prop_key, params) + logdensity_fn = logprior_fn + constraint_fn = lambda x: jnp.array([loglikelihood_fn(x)]) + constraint = jnp.array([loglikelihood_0]) + strict = jnp.array([True]) + new_slice_state, slice_info = slice_kernel( + rng_key, slice_state, logdensity_fn, d, constraint_fn, constraint, strict + ) + + # Pass the relevant information back to PartitionedState and PartitionedInfo + return new_state_and_info( + position=new_slice_state.position, + logprior=new_slice_state.logdensity, + loglikelihood=new_slice_state.constraint[0], + info=slice_info, + ) + + delete_fn = partial(default_delete_fn, num_delete=num_delete) + + # Vectorize the inner kernel for parallel execution + in_axes = (0, 0, None, None, None, None) + + update_inner_kernel_params_fn = adapt_direction_params_fn + kernel = build_adaptive_kernel( + logprior_fn, + loglikelihood_fn, + delete_fn, + jax.vmap(inner_kernel, in_axes=in_axes), + update_inner_kernel_params_fn, + ) + return kernel + + +def as_top_level_api( + logprior_fn: Callable, + loglikelihood_fn: Callable, + num_inner_steps: int, + num_delete: int = 1, + stepper_fn: Callable = default_stepper_fn, + adapt_direction_params_fn: Callable = compute_covariance_from_particles, + generate_slice_direction_fn: Callable = sample_direction_from_covariance, + max_steps: int = 10, + max_shrinkage: int = 100, +) -> SamplingAlgorithm: + """Creates an adaptive Nested Slice Sampling (NSS) algorithm. + + This function configures a Nested Sampling algorithm that uses Hit-and-Run + Slice Sampling (HRSS) as its inner kernel. The parameters for the HRSS + direction proposal (specifically, the covariance matrix) are adaptively tuned + at each step using `adapt_direction_params_fn`. + + Parameters + ---------- + logprior_fn + A function that computes the log-prior probability of a single particle. + loglikelihood_fn + A function that computes the log-likelihood of a single particle. + num_inner_steps + The number of HRSS steps to run for each new particle generation. + This should be a multiple of the dimension of the parameter space. + num_delete + The number of particles to delete and replace at each NS step. + Defaults to 1. + stepper_fn + The stepper function `(x, direction, t) -> x_new` for the HRSS kernel. + Defaults to `default_stepper`. + adapt_direction_params_fn + A function `(ns_state, ns_info) -> dict_of_params` that computes/adapts + the parameters (e.g., covariance matrix) for the slice direction proposal, + based on the current NS state. Defaults to `compute_covariance_from_particles`. + generate_slice_direction_fn + A function `(rng_key, **params) -> direction_pytree` that generates a + normalized direction for HRSS, using parameters from `adapt_direction_params_fn`. + Defaults to `sample_direction_from_covariance`. + max_steps + The maximum number of steps to take when expanding the interval in + each direction during the stepping-out phase. Defaults to 10. + max_shrinkage + The maximum number of shrinking steps to perform to avoid infinite loops. + Defaults to 100. 
+ + Returns + ------- + SamplingAlgorithm + A `SamplingAlgorithm` tuple containing `init` and `step` functions for + the configured Nested Slice Sampler. The state managed by this + algorithm is `NSState`. + """ + + kernel = build_kernel( + logprior_fn, + loglikelihood_fn, + num_inner_steps, + num_delete, + stepper_fn=stepper_fn, + adapt_direction_params_fn=adapt_direction_params_fn, + generate_slice_direction_fn=generate_slice_direction_fn, + max_steps=max_steps, + max_shrinkage=max_shrinkage, + ) + + def init_fn(position, rng_key=None): + # Vectorize the functions for parallel evaluation over particles + return init( + position, + logprior_fn=jax.vmap(logprior_fn), + loglikelihood_fn=jax.vmap(loglikelihood_fn), + update_inner_kernel_params_fn=adapt_direction_params_fn, + ) + + step_fn = kernel + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/ns/utils.py b/blackjax/ns/utils.py new file mode 100644 index 000000000..15f3ef35b --- /dev/null +++ b/blackjax/ns/utils.py @@ -0,0 +1,398 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions for Nested Sampling. + +This module provides helper functions for common tasks associated with Nested +Sampling, such as calculating log-volumes, log-weights, effective sample sizes, +and post-processing of results. +""" + +import functools +from typing import Callable, Dict, Tuple + +import jax +import jax.numpy as jnp + +from blackjax.ns.base import NSInfo, NSState +from blackjax.types import Array, ArrayTree, PRNGKey + + +def log1mexp(x: Array) -> Array: + """Computes log(1 - exp(x)) in a numerically stable way. + + This function implements the algorithm from Mächler (2012) [1]_ for computing + log(1 - exp(x)) while avoiding precision issues, especially when x is close to 0. + + Parameters + ---------- + x + Input array or scalar. Values in x should be less than or equal to 0; + the function returns `jnp.nan` for `x > 0`. + + Returns + ------- + Array + The value of log(1 - exp(x)). + + References + ---------- + .. [1] Mächler, M. (2012). Accurately computing log(1-exp(-|a|)). + CRAN R project, package Rmpfr, vignette log1mexp-note.pdf. + https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf + """ + return jnp.where( + x > -0.6931472, # approx log(2) + jnp.log(-jnp.expm1(x)), + jnp.log1p(-jnp.exp(x)), + ) + + +def compute_num_live(info: NSInfo) -> Array: + """Compute the effective number of live points at each death contour. + + In Nested Sampling, especially with batch deletions (k > 1), the conceptual + number of live points changes with each individual particle considered "dead" + within that batch. + + The function works by: + 1. Creating "birth" events (particle added to live set, count +1) and "death" + events (particle removed, count -1). + 2. Sorting all events by their log-likelihood. In case of ties, birth events + can be processed before death events by sorting on the count type (1 before -1), + though the primary sort is logL. + 3. 
Computing the cumulative sum of these +1/-1 counts. This gives the number + of particles with log-likelihood greater than or equal to the current event's logL. + 4. For each death event, this cumulative sum (plus 1, because the dead particle itself + was live just before its "death") represents `m*_i`. + + Parameters + ---------- + info + An `NSInfo` object (or a PyTree with compatible `loglikelihood_birth` + and `loglikelihood` fields, typically from a concatenated history of NS steps) + containing the birth and death log-likelihoods of particles. + + Returns + ------- + Array + An array where each element `num_live[j]` is the effective number of live + points `m*_i` when the j-th particle (in the sorted list of dead particles) + was considered "dead". + """ + birth_logL = info.loglikelihood_birth + death_logL = info.loglikelihood + + birth_events = jnp.column_stack( + (birth_logL, jnp.ones_like(birth_logL, dtype=jnp.int32)) + ) + death_events = jnp.column_stack( + (death_logL, -jnp.ones_like(death_logL, dtype=jnp.int32)) + ) + combined = jnp.concatenate([birth_events, death_events], axis=0) + logL_col = combined[:, 0] + n_col = combined[:, 1] + not_nan_sort_key = ~jnp.isnan(logL_col) + logL_sort_key = logL_col + n_sort_key = n_col + sorted_indices = jnp.lexsort((n_sort_key, logL_sort_key, not_nan_sort_key)) + sorted_n_col = n_col[sorted_indices] + cumsum = jnp.cumsum(sorted_n_col) + cumsum = jnp.maximum(cumsum, 0) + death_mask_sorted = sorted_n_col == -1 + num_live = cumsum[death_mask_sorted] + 1 + + return num_live + + +def logX(rng_key: PRNGKey, dead_info: NSInfo, shape: int = 100) -> tuple[Array, Array]: + """Simulate the stochastic evolution of log prior volumes. + + This function estimates the sequence of log prior volumes `logX_i` and the + log prior volume elements `log(dX_i)` associated with each dead particle. + For each dead particle `i`, the change in log volume is modeled as + `delta_logX_i = log(u_i) / m*_i`, where `u_i` is a standard uniform random + variable and `m*_i` is the effective number of live points when particle `i` died. + + Parameters + ---------- + rng_key + A JAX PRNG key for generating uniform random variates. + dead_info + An `NSInfo` object (or compatible PyTree) containing `loglikelihood_birth` + and `loglikelihood` for all dead particles accumulated during an NS run. + It's assumed these particles are already sorted by their death log-likelihood. + shape + The shape of Monte Carlo samples to generate for the stochastic + log-volume sequence. Each sample represents one possible path of + volume shrinkage. Default is 100. + + Returns + ------- + tuple[Array, Array] + - `logX_cumulative`: An array of shape `(num_dead_particles, *shape)` + containing `shape` simulated sequences of cumulative log prior volumes `log(X_i)`. + - `log_dX_elements`: An array of shape `(num_dead_particles, *shape)` + containing `shape` simulated sequences of log prior volume elements `log(dX_i)`. + `dX_i` is approximately `X_i - X_{i+1}`. 
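+
+    Examples
+    --------
+    Log-evidence samples can be formed by combining the simulated volume elements
+    with the dead log-likelihoods (a sketch; `dead_info` is assumed to be an
+    `NSInfo` of all dead points, sorted by death log-likelihood)::
+
+        import jax
+        from jax.scipy.special import logsumexp
+
+        _, log_dX = logX(jax.random.PRNGKey(0), dead_info, shape=200)
+        logZ_samples = logsumexp(log_dX + dead_info.loglikelihood[:, None], axis=0)
+        logZ_mean, logZ_err = logZ_samples.mean(), logZ_samples.std()
+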
+ """ + rng_key, subkey = jax.random.split(rng_key) + min_val = jnp.finfo(dead_info.loglikelihood.dtype).tiny + r = jnp.log( + jax.random.uniform( + subkey, shape=(dead_info.loglikelihood.shape[0], shape) + ).clip(min_val, 1 - min_val) + ) + + num_live = compute_num_live(dead_info) + t = r / num_live[:, jnp.newaxis] + logX = jnp.cumsum(t, axis=0) + + logXp = jnp.concatenate([jnp.zeros((1, logX.shape[1])), logX[:-1]], axis=0) + logXm = jnp.concatenate([logX[1:], jnp.full((1, logX.shape[1]), -jnp.inf)], axis=0) + log_diff = logXm - logXp + logdX = log1mexp(log_diff) + logXp - jnp.log(2) + return logX, logdX + + +def log_weights( + rng_key: PRNGKey, dead_info: NSInfo, shape: int = 100, beta: float = 1.0 +) -> Array: + """Calculate the log importance weights for Nested Sampling results. + + The importance weight for each dead particle `i` is `w_i = dX_i * L_i^beta`, + where `dX_i` is the prior volume element associated with the particle and + `L_i` is its likelihood. This function computes `log(w_i)` using stochastically + simulated `log(dX_i)` values. + + Parameters + ---------- + rng_key + A JAX PRNG key for simulating `log(dX_i)`. + dead_info + An `NSInfo` object (or compatible PyTree) containing `loglikelihood_birth` + and `loglikelihood` for all dead particles. + shape + The shape of Monte Carlo samples to use for simulating `log(dX_i)`. + Default is 100. + beta + The inverse temperature. Typically 1.0 for standard evidence calculation. + Allows for reweighting to different temperatures. + + Returns + ------- + Array + An array of log importance weights, shape `(num_dead_particles, *shape)`. + The original order of particles in `dead_info` is preserved. + """ + sort_indices = jnp.argsort(dead_info.loglikelihood) + unsort_indices = jnp.empty_like(sort_indices) + unsort_indices = unsort_indices.at[sort_indices].set(jnp.arange(len(sort_indices))) + dead_info_sorted = jax.tree.map(lambda x: x[sort_indices], dead_info) + _, log_dX = logX(rng_key, dead_info_sorted, shape) + log_w = log_dX + beta * dead_info_sorted.loglikelihood[..., jnp.newaxis] + return log_w[unsort_indices] + + +def finalise(live: NSState, dead: list[NSInfo]) -> NSInfo: + """Combines the history of dead particle information with the final live points. + + At the end of a Nested Sampling run, the remaining live points are treated + as if they were the next set of "dead" points to complete the evidence + integral and posterior sample set. This function concatenates the `NSInfo` + objects accumulated for dead particles throughout the run with a new `NSInfo` + object created from the final live particles in `live`. + + Parameters + ---------- + live + The final `NSState` of the Nested Sampler, containing the live particles. + dead + A list of `NSInfo` objects, where each object contains information + about the particles that "died" at one step of the NS algorithm. + + Returns + ------- + NSInfo + A single `NSInfo` object where all fields are concatenations of the + corresponding fields from `dead` and the final live points. + The `update_info` from the last element of `dead` is used + for the final live points' `update_info` (as a placeholder). 
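+
+    Examples
+    --------
+    When the dead-point history comes from `jax.lax.scan`, it is a single stacked
+    `NSInfo` with leading axes `(num_steps, num_delete)`; it can be flattened and
+    passed as a one-element list (a sketch; `final_state` and `stacked_dead` are
+    assumed to come from the sampling loop)::
+
+        import jax
+
+        flat_dead = jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:]), stacked_dead)
+        all_dead = finalise(final_state, [flat_dead])
+
+        # Post-processing with the utilities defined below.
+        rng_key = jax.random.PRNGKey(1)
+        n_eff = ess(rng_key, all_dead)
+        posterior_samples = sample(rng_key, all_dead, shape=1000)
+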
+ """ + + all_pytrees_to_combine = dead + [ + NSInfo( + live.particles, + live.loglikelihood, + live.loglikelihood_birth, + live.logprior, + dead[-1].inner_kernel_info, + ) + ] + combined_dead_info = jax.tree.map( + lambda *args: jnp.concatenate(args), + all_pytrees_to_combine[0], + *all_pytrees_to_combine[1:], + ) + return combined_dead_info + + +def ess(rng_key: PRNGKey, dead_info_map: NSInfo) -> Array: + """Computes the Effective Sample Size (ESS) from log-weights. + + The ESS is a measure of the quality of importance samples, indicating + how many independent samples the weighted set is equivalent to. + It's calculated as `(sum w_i)^2 / sum (w_i^2)`. This function computes + the mean ESS across multiple stochastic log-weight samples. + + Parameters + ---------- + rng_key + A JAX PRNG key, used by `log_weights`. + dead_info_map + An `NSInfo` object containing the full set of dead (and final live) + particles, typically the output of `finalise`. + + Returns + ------- + Array + The mean Effective Sample Size, a scalar float. + """ + logw = log_weights(rng_key, dead_info_map).mean(axis=-1) + logw -= logw.max() + l_sum_w = jax.scipy.special.logsumexp(logw) + l_sum_w_sq = jax.scipy.special.logsumexp(2 * logw) + ess = jnp.exp(2 * l_sum_w - l_sum_w_sq) + return ess + + +def sample(rng_key: PRNGKey, dead_info_map: NSInfo, shape: int = 1000) -> ArrayTree: + """Resamples particles according to their importance weights. + + This function takes the full set of dead (and final live) particles and + their computed importance weights, and draws `shape` particles with + replacement, where the probability of drawing each particle is proportional + to its weight. This produces an unweighted sample from the target posterior + distribution. + + Parameters + ---------- + rng_key + A JAX PRNG key, used for both `log_weights` and `jax.random.choice`. + dead_info_map + An `NSInfo` object containing the full set of dead (and final live) + particles, typically the output of `finalise`. + shape + The number of posterior samples to draw. Defaults to 1000. + + Returns + ------- + ArrayTree + A PyTree of resampled particles, where each leaf has `shape`. + """ + logw = log_weights(rng_key, dead_info_map).mean(axis=-1) + indices = jax.random.choice( + rng_key, + dead_info_map.loglikelihood.shape[0], + p=jnp.exp(logw.squeeze() - jnp.max(logw)), + shape=(shape,), + replace=True, + ) + return jax.tree.map(lambda leaf: leaf[indices], dead_info_map.particles) + + +def get_first_row(x: ArrayTree) -> ArrayTree: + """Extracts the first "row" (element along the leading axis) of each leaf in a PyTree. + + This is typically used to get a single particle's structure or values from + a PyTree representing a collection of particles, where the leading dimension + of each leaf array corresponds to the particle index. + + Parameters + ---------- + x + A PyTree of arrays, where each leaf array has a leading dimension. + + Returns + ------- + ArrayTree + A PyTree with the same structure as `x`, but where each leaf is the + first slice `leaf[0]` of the corresponding leaf in `x`. 
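+
+    Examples
+    --------
+    For a toy PyTree of 100 particles::
+
+        import jax.numpy as jnp
+
+        particles = {"x": jnp.zeros((100, 3)), "sigma": jnp.zeros((100,))}
+        single = get_first_row(particles)  # {"x": shape (3,), "sigma": shape ()}
+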
+ """ + return jax.tree.map(lambda x: x[0], x) + + +def repeat_kernel(num_repeats: int): + """Decorator to repeat a kernel function multiple times.""" + + def decorator(kernel): + @functools.wraps(kernel) + def repeated_kernel(rng_key: PRNGKey, state, *args, **kwargs): + def body_fn(state, rng_key): + return kernel(rng_key, state, *args, **kwargs) + + keys = jax.random.split(rng_key, num_repeats) + return jax.lax.scan(body_fn, state, keys) + + return repeated_kernel + + return decorator + + +def uniform_prior( + rng_key: PRNGKey, num_live: int, bounds: Dict[str, Tuple[float, float]] +) -> Tuple[ArrayTree, Callable]: + """Helper function to create a uniform prior for parameters. + + This function generates a set of initial parameter samples uniformly + distributed within specified bounds. It also provides a log-prior + function that computes the log-prior probability for a given set of + parameters. + + Parameters + ---------- + rng_key + A JAX PRNG key for random number generation. + num_live + The number of live particles to sample. + bounds + A dictionary mapping parameter names to their bounds (tuples of min and max). + Each parameter will be sampled uniformly within these bounds. + Example: {'param1': (0.0, 1.0), 'param2': (-5.0, 5.0)} + + Returns + ------- + tuple + - `particles`: A PyTree of sampled parameters, where each leaf has shape `(num_live,)`. + - `logprior_fn`: A function that computes the log-prior probability + for a given set of parameters. + """ + + def logprior_fn(params): + logprior = 0.0 + for p, (a, b) in bounds.items(): + x = params[p] + logprior += jax.scipy.stats.uniform.logpdf(x, a, b - a) + return logprior + + def prior_sample(rng_key): + init_keys = jax.random.split(rng_key, len(bounds)) + params = {} + for rng_key, (p, (a, b)) in zip(init_keys, bounds.items()): + params[p] = jax.random.uniform(rng_key, minval=a, maxval=b) + return params + + init_keys = jax.random.split(rng_key, num_live) + particles = jax.vmap(prior_sample)(init_keys) + + return particles, logprior_fn diff --git a/blackjax/smc/from_mcmc.py b/blackjax/smc/from_mcmc.py index 0e60b5968..75e5c34a6 100644 --- a/blackjax/smc/from_mcmc.py +++ b/blackjax/smc/from_mcmc.py @@ -8,6 +8,23 @@ from blackjax.types import PRNGKey +def unshared_parameters_and_step_fn(mcmc_parameters, mcmc_step_fn): + """Splits MCMC parameters into two dictionaries. The shared dictionary + represents the parameters common to all chains, and the unshared are + different per chain. + Binds the step fn using the shared parameters. + """ + shared_mcmc_parameters = {} + unshared_mcmc_parameters = {} + for k, v in mcmc_parameters.items(): + if v.shape[0] == 1: + shared_mcmc_parameters[k] = v[0, ...] + else: + unshared_mcmc_parameters[k] = v + shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + return unshared_mcmc_parameters, shared_mcmc_step_fn + + def build_kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable, @@ -34,15 +51,9 @@ def step( logposterior_fn: Callable, log_weights_fn: Callable, ) -> tuple[smc.base.SMCState, smc.base.SMCInfo]: - shared_mcmc_parameters = {} - unshared_mcmc_parameters = {} - for k, v in mcmc_parameters.items(): - if v.shape[0] == 1: - shared_mcmc_parameters[k] = v[0, ...] 
- else: - unshared_mcmc_parameters[k] = v - - shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters) + unshared_mcmc_parameters, shared_mcmc_step_fn = unshared_parameters_and_step_fn( + mcmc_parameters, mcmc_step_fn + ) update_fn, num_resampled = update_strategy( mcmc_init_fn, diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 2a63fd1ce..334a1488c 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -1,5 +1,7 @@ from typing import Callable, Dict, NamedTuple, Tuple +import jax + from blackjax.base import SamplingAlgorithm from blackjax.smc.base import SMCInfo, SMCState from blackjax.types import ArrayTree, PRNGKey @@ -28,8 +30,11 @@ def build_kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], + mcmc_parameter_update_fn: Callable[ + [PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree] + ], num_mcmc_steps: int = 10, + smc_returns_state_with_parameter_override=False, **extra_parameters, ) -> Callable: """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner @@ -40,7 +45,8 @@ def build_kernel( ---------- smc_algorithm Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of - a sampling algorithm that returns an SMCState and SMCInfo pair). + a sampling algorithm that returns an SMCState and SMCInfo pair). It is also possible for this + to return a StateWithParameterOverride, in which case smc_returns_state_with_parameter_override must be True. logprior_fn A function that computes the log density of the prior distribution loglikelihood_fn A function that returns the probability at a given position. @@ -54,7 +60,30 @@ def build_kernel( A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. extra_parameters: parameters to be used for the creation of the smc_algorithm. + smc_returns_state_with_parameter_override: + a boolean indicating that the underlying smc_algorithm returns a StateWithParameterOverride. + This is used to compose different adaptation mechanisms, such as pretuning with tuning.
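+
+    For example, under the new signature the update function also receives a PRNG
+    key, which allows stochastic tuning rules (illustrative only; `"step_size"`
+    stands for whatever keyword the chosen mcmc_step_fn expects, and the leading
+    axis of size 1 marks the value as shared across chains)::
+
+        import jax
+        import jax.numpy as jnp
+
+        def mcmc_parameter_update_fn(rng_key, state, info):
+            jitter = jax.random.uniform(rng_key, (), minval=0.9, maxval=1.1)
+            return {"step_size": jnp.full((1,), 0.01 * jitter)}
+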
""" + if smc_returns_state_with_parameter_override: + + def extract_state_for_delegate(state): + return state + + def compose_new_state(new_state, new_parameter_override): + composed_parameter_override = ( + new_state.parameter_override | new_parameter_override + ) + return StateWithParameterOverride( + new_state.sampler_state, composed_parameter_override + ) + + else: + + def extract_state_for_delegate(state): + return state.sampler_state + + def compose_new_state(new_state, new_parameter_override): + return StateWithParameterOverride(new_state, new_parameter_override) def kernel( rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters @@ -69,9 +98,14 @@ def kernel( num_mcmc_steps=num_mcmc_steps, **extra_parameters, ).step - new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters) - new_parameter_override = mcmc_parameter_update_fn(new_state, info) - return StateWithParameterOverride(new_state, new_parameter_override), info + parameter_update_key, step_key = jax.random.split(rng_key, 2) + new_state, info = step_fn( + step_key, extract_state_for_delegate(state), **extra_step_parameters + ) + new_parameter_override = mcmc_parameter_update_fn( + parameter_update_key, new_state, info + ) + return compose_new_state(new_state, new_parameter_override), info return kernel @@ -83,9 +117,12 @@ def as_top_level_api( mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], + mcmc_parameter_update_fn: Callable[ + [PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree] + ], initial_parameter_value, num_mcmc_steps: int = 10, + smc_returns_state_with_parameter_override=False, **extra_parameters, ) -> SamplingAlgorithm: """In the context of an SMC sampler (whose step_fn returning state @@ -130,6 +167,7 @@ def as_top_level_api( resampling_fn, mcmc_parameter_update_fn, num_mcmc_steps, + smc_returns_state_with_parameter_override, **extra_parameters, ) diff --git a/blackjax/smc/pretuning.py b/blackjax/smc/pretuning.py new file mode 100644 index 000000000..374b8f425 --- /dev/null +++ b/blackjax/smc/pretuning.py @@ -0,0 +1,352 @@ +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple + +import jax +import jax.numpy as jnp +import jax.random +from jax._src.flatten_util import ravel_pytree + +from blackjax import SamplingAlgorithm, smc +from blackjax.smc.base import SMCInfo, update_and_take_last +from blackjax.smc.from_mcmc import build_kernel as smc_from_mcmc +from blackjax.smc.from_mcmc import unshared_parameters_and_step_fn +from blackjax.smc.inner_kernel_tuning import StateWithParameterOverride +from blackjax.smc.resampling import stratified +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_gaussian_noise + + +class SMCInfoWithParameterDistribution(NamedTuple): + """Stores both the sampling status and also a dictionary + with parameter names as keys and (n_particles, *) arrays as values. + The latter represents a parameter per chain for the next mutation step. + """ + + smc_info: SMCInfo + parameter_override: Dict[str, ArrayTree] + + +def esjd(m): + """Implements ESJD (expected squared jumping distance). Inner Mahalanobis distance + is computed using the Cholesky decomposition of M=LLt, and then inverting L. + Whenever M is symmetrical definite positive then it must exist a Cholesky Decomposition. 
+ For example, if M is the Covariance Matrix of Metropolis-Hastings or + the Inverse Mass Matrix of Hamiltonian Monte Carlo. + """ + L = jnp.linalg.cholesky(m) + + def measure(previous_position, next_position, acceptance_probability): + difference = ravel_pytree(previous_position)[0] - ravel_pytree(next_position)[0] + difference_by_matrix = jnp.matmul(L, difference) + norm = jnp.linalg.norm(difference_by_matrix, 2) + return acceptance_probability * jnp.power(norm, 2) + + return jax.vmap(measure) + + +def update_parameter_distribution( + key: PRNGKey, + previous_param_samples: ArrayLikeTree, + previous_particles: ArrayLikeTree, + latest_particles: ArrayLikeTree, + measure_of_chain_mixing: Callable, + alpha: float, + sigma_parameters: ArrayLikeTree, + acceptance_probability: Array, +): + """Given an existing parameter distribution that was used to mutate previous_particles + into latest_particles, updates that parameter distribution by resampling from previous_param_samples after adding + noise to those samples. The weights used are a linear function of the measure of chain mixing. + Only works with float parameters, not integers. + See Equation 4 in https://arxiv.org/pdf/1005.1193.pdf + + Parameters + ---------- + previous_param_samples: + samples of the parameters of SMC inner MCMC chains. To be updated. + previous_particles: + particles from which the kernel step started + latest_particles: + particles after the step was performed + measure_of_chain_mixing: Callable + a callable that can compute a performance measure per chain + alpha: + a scalar to add to the weighting. See paper for details + sigma_parameters: + noise to add to the population of parameters to mutate them. must have the same shape of + previous_param_samples. + acceptance_probability: + the energy difference for each of the chains when taking a step from previous_particles + into latest_particles. + """ + noise_key, resampling_key = jax.random.split(key, 2) + + noises = jax.tree.map( + lambda x, s: generate_gaussian_noise(noise_key, x.astype("float32"), sigma=s), + previous_param_samples, + sigma_parameters, + ) + new_samples = jax.tree.map(lambda x, y: x + y, noises, previous_param_samples) + + chain_mixing_measurement = measure_of_chain_mixing( + previous_particles, latest_particles, acceptance_probability + ) + weights = alpha + chain_mixing_measurement + weights = weights / jnp.sum(weights) + resampling_idx = stratified(resampling_key, weights, len(chain_mixing_measurement)) + return ( + jax.tree.map(lambda x: x[resampling_idx], new_samples), + chain_mixing_measurement, + ) + + +def default_measure_factory(state): + inverse_mass_matrix = state.parameter_override["inverse_mass_matrix"] + if not (len(inverse_mass_matrix.shape) == 3 and inverse_mass_matrix.shape[0] == 1): + raise ValueError("ESJD only works if chains share the inverse_mass_matrix.") + + return esjd(inverse_mass_matrix[0]) + + +def build_pretune( + mcmc_init_fn: Callable, + mcmc_step_fn: Callable, + alpha: float, + sigma_parameters: ArrayLikeTree, + n_particles: int, + performance_of_chain_measure_factory: Callable = default_measure_factory, + natural_parameters: Optional[List[str]] = None, + positive_parameters: Optional[List[str]] = None, +): + """Implements Buchholz et al https://arxiv.org/pdf/1808.07730 pretuning procedure. + The goal is to maintain a probability distribution of parameters, in order + to assign different values to each inner MCMC chain. 
+ To have performant parameters for the distribution at step t, it takes a single step, measures + the chain mixing, and reweights the probability distribution of parameters accordingly. + Note that although similar, this strategy is different than inner_kernel_tuning. The latter updates + the parameters based on the particles and transition information after the SMC step is executed. This + implementation runs a single MCMC step which gets discarded, to then proceed with the SMC step execution. + """ + if natural_parameters is None: + round_to_integer_fn = lambda x: x + else: + + def round_to_integer_fn(x): + for k in natural_parameters: + x[k] = jax.tree.map(lambda a: jnp.abs(jnp.round(a).astype(int)), x[k]) + return x + + if positive_parameters is None: + make_positive_fn = lambda x: x + else: + + def make_positive_fn(x): + for k in positive_parameters: + x[k] = jax.tree.map(jnp.abs, x[k]) + return x + + def pretune(key, state, logposterior): + unshared_mcmc_parameters, shared_mcmc_step_fn = unshared_parameters_and_step_fn( + state.parameter_override, mcmc_step_fn + ) + + one_step_fn, _ = update_and_take_last( + mcmc_init_fn, logposterior, shared_mcmc_step_fn, 1, n_particles + ) + + new_state, info = one_step_fn( + jax.random.split(key, n_particles), + state.sampler_state.particles, + unshared_mcmc_parameters, + ) + + performance_of_chain_measure = performance_of_chain_measure_factory(state) + + ( + new_parameter_distribution, + chain_mixing_measurement, + ) = update_parameter_distribution( + key, + previous_param_samples={ + key: state.parameter_override[key] for key in sigma_parameters + }, + previous_particles=state.sampler_state.particles, + latest_particles=new_state, + measure_of_chain_mixing=performance_of_chain_measure, + alpha=alpha, + sigma_parameters=sigma_parameters, + acceptance_probability=info.acceptance_rate, + ) + + return ( + make_positive_fn(round_to_integer_fn(new_parameter_distribution)), + chain_mixing_measurement, + ) + + def pretune_and_update(key, state: StateWithParameterOverride, logposterior): + """ + Updates the parameters that need to be pretuned and returns the rest. + """ + new_parameter_distribution, chain_mixing_measurement = pretune( + key, state, logposterior + ) + old_parameter_distribution = state.parameter_override + updated_parameter_distribution = old_parameter_distribution + for k in new_parameter_distribution: + updated_parameter_distribution[k] = new_parameter_distribution[k] + + return updated_parameter_distribution + + return pretune_and_update + + +def build_kernel( + smc_algorithm, + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + pretune_fn: Callable, + num_mcmc_steps: int = 10, + update_strategy=update_and_take_last, + **extra_parameters, +) -> Callable: + """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner + MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC, + based on particles. The parameter type must be a valid JAX type. + + Parameters + ---------- + smc_algorithm + Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of + a sampling algorithm that returns an SMCState and SMCInfo pair). + logprior_fn + A function that computes the log density of the prior distribution + loglikelihood_fn + A function that returns the probability at a given position. 
+ mcmc_step_fn: + The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. + mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()) + mcmc_init_fn + A callable that initializes the inner kernel + pretune_fn: + A callable that can update the probability distribution of parameters. + extra_parameters: + parameters to be used for the creation of the smc_algorithm. + """ + delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy) + + def pretuned_step( + rng_key: PRNGKey, + state, + num_mcmc_steps: int, + mcmc_parameters: dict, + logposterior_fn: Callable, + log_weights_fn: Callable, + ) -> tuple[smc.base.SMCState, SMCInfoWithParameterDistribution]: + """Wraps the output of smc.from_mcmc.build_kernel into a pretuning + step method. + This one should be a subtype of the former, in the sense that a usage of the former + can be replaced with an instance of this one. + """ + + pretune_key, step_key = jax.random.split(rng_key, 2) + pretuned_parameters = pretune_fn( + pretune_key, + StateWithParameterOverride(state, mcmc_parameters), + logposterior_fn, + ) + state, info = delegate( + rng_key, + state, + num_mcmc_steps, + pretuned_parameters, + logposterior_fn, + log_weights_fn, + ) + return state, SMCInfoWithParameterDistribution(info, pretuned_parameters) + + def kernel( + rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters + ) -> Tuple[StateWithParameterOverride, SMCInfo]: + extra_parameters["update_particles_fn"] = pretuned_step + step_fn = smc_algorithm( + logprior_fn=logprior_fn, + loglikelihood_fn=loglikelihood_fn, + mcmc_step_fn=mcmc_step_fn, + mcmc_init_fn=mcmc_init_fn, + mcmc_parameters=state.parameter_override, + resampling_fn=resampling_fn, + num_mcmc_steps=num_mcmc_steps, + **extra_parameters, + ).step + new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters) + return ( + StateWithParameterOverride(new_state, info.parameter_override), + info.smc_info, + ) + + return kernel + + +def init(alg_init_fn, position, initial_parameter_value): + return StateWithParameterOverride(alg_init_fn(position), initial_parameter_value) + + +def as_top_level_api( + smc_algorithm, + logprior_fn: Callable, + loglikelihood_fn: Callable, + mcmc_step_fn: Callable, + mcmc_init_fn: Callable, + resampling_fn: Callable, + num_mcmc_steps: int, + initial_parameter_value: ArrayLikeTree, + pretune_fn: Callable, + **extra_parameters, +): + """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner + MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC, + based on particles. The parameter type must be a valid JAX type. + + Parameters + ---------- + smc_algorithm + Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of + a sampling algorithm that returns an SMCState and SMCInfo pair). + logprior_fn + A function that computes the log density of the prior distribution + loglikelihood_fn + A function that returns the probability at a given position. + mcmc_step_fn: + The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. + mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()) + mcmc_init_fn + A callable that initializes the inner kernel + pretune_fn: + A callable that can update the probability distribution of parameters. 
+ extra_parameters: + parameters to be used for the creation of the smc_algorithm. + """ + + kernel = build_kernel( + smc_algorithm, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + pretune_fn, + num_mcmc_steps, + **extra_parameters, + ) + + def init_fn(position, rng_key=None): + del rng_key + return init(smc_algorithm.init, position, initial_parameter_value) + + def step_fn( + rng_key: PRNGKey, state, **extra_step_parameters + ) -> Tuple[StateWithParameterOverride, SMCInfo]: + return kernel(rng_key, state, **extra_step_parameters) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 88539deaa..350037f9c 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -55,6 +55,7 @@ def build_kernel( mcmc_init_fn: Callable, resampling_fn: Callable, update_strategy: Callable = update_and_take_last, + update_particles_fn: Optional[Callable] = None, ) -> Callable: """Build the base Tempered SMC kernel. @@ -92,8 +93,12 @@ def build_kernel( information about the transition. """ - delegate = smc_from_mcmc.build_kernel( - mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy + update_particles = ( + smc_from_mcmc.build_kernel( + mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy + ) + if update_particles_fn is None + else update_particles_fn ) def kernel( @@ -135,7 +140,7 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood - smc_state, info = delegate( + smc_state, info = update_particles( rng_key, state, num_mcmc_steps, @@ -162,6 +167,7 @@ def as_top_level_api( resampling_fn: Callable, num_mcmc_steps: Optional[int] = 10, update_strategy=update_and_take_last, + update_particles_fn=None, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. @@ -196,6 +202,7 @@ def as_top_level_api( mcmc_init_fn, resampling_fn, update_strategy, + update_particles_fn, ) def init_fn(position: ArrayLikeTree, rng_key=None): diff --git a/blackjax/smc/tuning/from_kernel_info.py b/blackjax/smc/tuning/from_kernel_info.py index a039e66c1..fa2c7054c 100644 --- a/blackjax/smc/tuning/from_kernel_info.py +++ b/blackjax/smc/tuning/from_kernel_info.py @@ -1,4 +1,5 @@ """ +static (all particles get the same value) strategies to tune the parameters of mcmc kernels used within smc, based on MCMC states """ diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index 4c8ca98da..505e7f3a1 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -1,5 +1,5 @@ """ -strategies to tune the parameters of mcmc kernels +static (all particles get the same value) strategies to tune the parameters of mcmc kernels used within SMC, based on particles. """ import jax @@ -12,7 +12,7 @@ "particles_means", "particles_stds", "particles_covariance_matrix", - "mass_matrix_from_particles", + "inverse_mass_matrix_from_particles", ] @@ -28,18 +28,16 @@ def particles_covariance_matrix(particles): return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False) -def mass_matrix_from_particles(particles) -> Array: +def inverse_mass_matrix_from_particles(particles) -> Array: """ Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf - Computing a mass matrix to be used in HMC from particles. 
- Given the particles covariance matrix, set all non-diagonal elements as zero, - take the inverse, and keep the diagonal. + Computing an inverse mass matrix to be used in HMC from particles. Returns ------- - A mass Matrix + An inverse mass matrix """ - return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0)) + return jnp.diag(jnp.var(particles_as_rows(particles), axis=0)) def particles_as_rows(particles): diff --git a/docs/examples/nested_sampling.py b/docs/examples/nested_sampling.py new file mode 100644 index 000000000..bb8f05e82 --- /dev/null +++ b/docs/examples/nested_sampling.py @@ -0,0 +1,120 @@ +import jax +import jax.numpy as jnp +import tqdm +from jax.scipy.linalg import inv, solve + +import blackjax +from blackjax.ns.utils import finalise, log_weights + +# jax.config.update("jax_enable_x64", True) + +rng_key = jax.random.PRNGKey(0) +d = 5 + +C = jax.random.normal(rng_key, (d, d)) * 0.1 +like_cov = C @ C.T +like_mean = jax.random.normal(rng_key, (d,)) +prior_mean = jnp.zeros(d) +prior_cov = jnp.eye(d) * 1 +logprior_fn = lambda x: jax.scipy.stats.multivariate_normal.logpdf( + x, prior_mean, prior_cov +) + + +def loglikelihood_fn(x): + return jax.scipy.stats.multivariate_normal.logpdf(x, mean=like_mean, cov=like_cov) + + +def compute_logZ(mu_L, Sigma_L, logLmax=0, mu_pi=None, Sigma_pi=None): + Sigma_P = inv(inv(Sigma_pi) + inv(Sigma_L)) + mu_P = jnp.dot(Sigma_P, (solve(Sigma_pi, mu_pi) + solve(Sigma_L, mu_L))) + logdet_Sigma_P = jnp.linalg.slogdet(Sigma_P)[1] + logdet_Sigma_pi = jnp.linalg.slogdet(Sigma_pi)[1] + + return ( + logLmax + + logdet_Sigma_P / 2 + - logdet_Sigma_pi / 2 + - jnp.dot((mu_P - mu_pi), solve(Sigma_pi, mu_P - mu_pi)) / 2 + - jnp.dot((mu_P - mu_L), solve(Sigma_L, mu_P - mu_L)) / 2 + ) + + +log_analytic_evidence = compute_logZ( + like_mean, + like_cov, + mu_pi=prior_mean, + Sigma_pi=prior_cov, + logLmax=loglikelihood_fn(like_mean), +) + +############################################ +# Nested Sampling algorithm definition +############################################ + +# We use the loaded `nested slice sampling` here, bypassing the choice of inner kernel and +# inner kernel tuning, in favour of a simpler UI that loads the vectorized slice sampler + +# n_live is the number of live samples to draw initially and maintain through the run +n_live = 500 +# num_delete is the number of samples to delete each outer kernel iteration, as the inner kernel is parallelised we do this +# to update all of these points in parallel, useful for GPU acceleration hopefully. +num_delete = 20 +# num_inner_steps is the number of MCMC steps to perform with the inner kernel in order to decorrelate the resampled points +# we set this conservatively high here at 5 times the dimension of the parameter space +num_inner_steps = d * 5 + +algo = blackjax.nss( + logprior_fn=logprior_fn, + loglikelihood_fn=loglikelihood_fn, + num_delete=num_delete, + num_inner_steps=num_inner_steps, +) + +rng_key, init_key, sample_key = jax.random.split(rng_key, 3) + +initial_particles = jax.random.multivariate_normal( + init_key, prior_mean, prior_cov, (n_live,) +) + +# We can run the algorithm for a fixed number of steps but we run into a quirk of nested sampling here. The state after N iterations +# does not necessarily contain any useful posterior points, it will have accumulated an estimate of the marginal likelihood, and this +# is what is usefully tracked. 
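# A hedged sketch of the fixed-length alternative mentioned above, assuming the
# `algo`, `initial_particles` and `rng_key` already defined in this example;
# the termination-based Python loop below is what the example actually runs.
def ns_scan_step(carry, _):
    state, key = carry
    key, subkey = jax.random.split(key)
    state, dead_info = algo.step(subkey, state)
    return (state, key), dead_info


(live_fixed, _), dead_fixed = jax.lax.scan(
    ns_scan_step, (algo.init(initial_particles), rng_key), length=1000
)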
+ + +# n_steps = 1000 +# (live, _), dead = jax.lax.scan((one_step), (state, rng_key), length=n_steps) + + +# Also typically we would wrap the outer in a while loop, as the compression of nested sampling can push well past the posterior typical +# set if left fixed. This leaves a slightly strange construction, but works well. We want to accumulate the algorithm info (as this is) +# how we will reconstruct posterior points, but the lax while loop wrapper won't accumulate well. So we will jit compile the outer step +# and run it in a python loop + +live = algo.init(initial_particles) +step_fn = jax.jit(algo.step) +dead = [] +# with jax.disable_jit(): +for _ in tqdm.trange(1000): + # We track the estimate of the evidence in the live points as logZ_live, and the accumulated sum across all steps in logZ + # this gives a handy termination that allows us to stop early + if live.logZ_live - live.logZ < -3: # type: ignore[attr-defined] + break + rng_key, subkey = jax.random.split(rng_key, 2) + live, dead_info = step_fn(subkey, live) + dead.append(dead_info) + +# It is now not too bad to remap the list of NSInfos into a single instance +# note in theory we should include the live points, but assuming we have done things correctly and hit the termination criteria, +# they will contain negligible weight +# dead = jax.tree.map(lambda *args: jnp.concatenate(args), *dead) + +# From here we can use the utils to compute the log weights and the evidence of the accumulated dead points +# sampling log weights lets us get a sensible error on the evidence estimate +nested_samples = finalise(live, dead) +logw = log_weights(rng_key, nested_samples) +logZs = jax.scipy.special.logsumexp(logw, axis=0) + +print(f"Analytic evidence: {log_analytic_evidence:.2f}") +print(f"Runtime evidence: {live.logZ:.2f}") # type: ignore[attr-defined] +print(f"Estimated evidence: {logZs.mean():.2f} +- {logZs.std():.2f}") diff --git a/pyproject.toml b/pyproject.toml index 0739361e2..cbd2cefd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "fastprogress>=1.0.0", "jax>=0.4.16", "jaxlib>=0.4.16", - "jaxopt>=0.8", + "jaxopt<=0.8.3", "optax>=0.1.7", "typing-extensions>=4.4.0", ] diff --git a/requirements-doc.txt b/requirements-doc.txt index 83af1ffe3..fe8089cf4 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -6,7 +6,7 @@ flax ipython jax>=0.4.25 jaxlib>=0.4.25 -jaxopt +jaxopt<=0.8.3 jupytext myst_nb>=1.0.0 numba diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index c38009e5e..fd7af4450 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -77,10 +77,23 @@ def kinetic_energy(p, position=None): "c": jnp.ones((2, 1)), } _, unravel_fn = ravel_pytree(mvnormal_position_init) -key0, key1 = jax.random.split(jax.random.key(52)) -mvnormal_momentum_init = unravel_fn(jax.random.normal(key0, (6,))) -a = jax.random.normal(key1, (6, 6)) -cov = jnp.matmul(a.T, a) +mvnormal_momentum_init = { + "a": jnp.asarray(0.53288144), + "b": jnp.asarray([0.25310317, 1.3788314, -0.13486017]), + "c": jnp.asarray([[-0.59082425], [1.2088736]]), +} + +cov = jnp.asarray( + [ + [5.9959664, 1.1494889, -1.0420643, -0.6328479, -0.20363973, 2.1600752], + [1.1494889, 1.3504763, -0.3601517, -0.98311526, 1.1569028, -1.4185406], + [-1.0420643, -0.3601517, 6.3011055, -2.0662997, -0.10126236, 1.2898219], + [-0.6328479, -0.98311526, -2.0662997, 4.82699, -2.575554, 2.5724294], + [-0.20363973, 1.1569028, -0.10126236, -2.575554, 3.35319, -2.9411654], + [2.1600752, -1.4185406, 
1.2898219, 2.5724294, -2.9411654, 6.3740206], + ] +) + # Validated numerically mvnormal_position_end = unravel_fn( jnp.asarray([0.38887993, 0.85231394, 2.7879136, 3.0339851, 0.5856687, 1.9291426]) @@ -238,7 +251,7 @@ def test_esh_momentum_update(self, dims): # Efficient implementation update_stable = self.variant( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -263,7 +276,7 @@ def test_isokinetic_velocity_verlet(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + op1 = esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_proposal.py b/tests/mcmc/test_proposal.py index 3a0c3ac38..391a66656 100644 --- a/tests/mcmc/test_proposal.py +++ b/tests/mcmc/test_proposal.py @@ -2,6 +2,7 @@ import jax import numpy as np import pytest +from absl.testing import parameterized from jax import numpy as jnp from blackjax.mcmc.random_walk import normal @@ -10,25 +11,18 @@ class TestNormalProposalDistribution(chex.TestCase): def setUp(self): super().setUp() - self.key = jax.random.key(20220611) + self.key = jax.random.key(20250120) - def test_normal_univariate(self): + @parameterized.parameters([10.0, 15000.0]) + def test_normal_univariate(self, initial_position): """ Move samples are generated in the univariate case, with std following sigma, and independently of the position. 
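# A standalone NumPy sketch (not part of the patch) of why the np.atleast_2d
# guard added to _check_mean_and_std below is needed once the univariate case
# goes through the shared helper: np.cov collapses a single variable to a 0-d
# array, which np.diag rejects.
import numpy as np

samples = np.random.default_rng(0).normal(size=(200, 1))
print(np.cov(samples.T).shape)                 # () -- np.diag would raise here
print(np.atleast_2d(np.cov(samples.T)).shape)  # (1, 1) -- np.diag now works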
""" - key1, key2 = jax.random.split(self.key) + keys = jax.random.split(self.key, 200) proposal = normal(sigma=jnp.array([1.0])) - samples_from_initial_position = [ - proposal(key, jnp.array([10.0])) for key in jax.random.split(key1, 100) - ] - samples_from_another_position = [ - proposal(key, jnp.array([15000.0])) for key in jax.random.split(key2, 100) - ] - - for samples in [samples_from_initial_position, samples_from_another_position]: - np.testing.assert_allclose(0.0, np.mean(samples), rtol=1e-2, atol=1e-1) - np.testing.assert_allclose(1.0, np.std(samples), rtol=1e-2, atol=1e-1) + samples = [proposal(key, jnp.array([initial_position])) for key in keys] + self._check_mean_and_std(jnp.array([0.0]), jnp.array([1.0]), samples) def test_normal_multivariate(self): proposal = normal(sigma=jnp.array([1.0, 2.0])) @@ -61,7 +55,7 @@ def _check_mean_and_std(expected_mean, expected_std, samples): ) np.testing.assert_allclose( expected_std, - np.sqrt(np.diag(np.cov(np.array(samples).T))), + np.sqrt(np.diag(np.atleast_2d(np.cov(np.array(samples).T)))), rtol=1e-2, atol=1e-1, ) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 98572cabc..4d8a9fa61 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -15,6 +15,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info +from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -112,15 +113,16 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + _, ) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -133,7 +135,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, samples = run_inference_algorithm( @@ -146,6 +148,147 @@ def run_mclmc( return samples + def run_adjusted_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + integrator = isokinetic_mclachlan + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc_dynamic.init( + position=initial_position, + logdensity_fn=logdensity_fn, + random_generator_arg=init_key, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( + integrator=integrator, + integration_steps_fn=lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), + inverse_mass_matrix=inverse_mass_matrix, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=logdensity_fn, + ) + + target_acc_rate = 0.65 + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + _, + ) = blackjax.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + 
rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc_dynamic( + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=lambda key: jnp.ceil( + jax.random.uniform(key) * rescale(L / step_size) + ), + integrator=integrator, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, + ) + + _, out = run_inference_algorithm( + rng_key=run_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda state, _: state.position, + progress_bar=False, + ) + + return out + + def run_adjusted_mclmc_static( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + integrator = isokinetic_mclachlan + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc.init( + position=initial_position, + logdensity_fn=logdensity_fn, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( + integrator=integrator, + inverse_mass_matrix=inverse_mass_matrix, + logdensity_fn=logdensity_fn, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + num_integration_steps=avg_num_integration_steps, + ) + + target_acc_rate = 0.9 + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + _, + ) = blackjax.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + num_integration_steps=L / step_size, + integrator=integrator, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, + ) + + _, out = run_inference_algorithm( + rng_key=run_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda state, _: state.position, + progress_bar=False, + ) + + return out + @parameterized.parameters( itertools.product( regression_test_cases, [True, False], window_adaptation_filters @@ -259,8 +402,58 @@ def test_mclmc(self): coefs_samples = states["coefs"][3000:] scale_samples = np.exp(states["log_scale"][3000:]) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) + + def test_adjusted_mclmc(self): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, 
+ ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) + + def test_adjusted_mclmc_static(self): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc_static( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) def test_mclmc_preconditioning(self): class IllConditionedGaussian: @@ -302,7 +495,7 @@ def __init__(self, d, condition_number): integrator = isokinetic_mclachlan - def get_sqrt_diag_cov(): + def get_inverse_mass_matrix(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -313,16 +506,13 @@ def get_sqrt_diag_cov(): rng_key=init_key, ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) - ( - _, - blackjax_mclmc_sampler_params, - ) = blackjax.mclmc_find_L_and_step_size( + (_, blackjax_mclmc_sampler_params, _) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, state=initial_state, @@ -330,13 +520,14 @@ def get_sqrt_diag_cov(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.sqrt_diag_cov + return blackjax_mclmc_sampler_params.inverse_mass_matrix - sqrt_diag_cov = get_sqrt_diag_cov() + inverse_mass_matrix = get_inverse_mass_matrix() assert ( jnp.abs( jnp.dot( - (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), + (inverse_mass_matrix**2) + / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 @@ -510,8 +701,8 @@ def test_barker(self): coefs_samples = states["coefs"][3000:] scale_samples = np.exp(states["log_scale"][3000:]) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + np.testing.assert_allclose(np.mean(scale_samples), 1.0, rtol=1e-2, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, rtol=1e-2, atol=1e-1) class SGMCMCTest(chex.TestCase): @@ -764,7 +955,7 @@ def test_irmh(self): @chex.all_variants(with_pmap=False) def test_nuts(self): inference_algorithm = blackjax.nuts( - self.normal_logprob, step_size=4.0, inverse_mass_matrix=jnp.array([1.0]) + self.normal_logprob, step_size=1.0, inverse_mass_matrix=jnp.array([1.0]) ) initial_state = inference_algorithm.init(jnp.array(3.0)) @@ -924,7 +1115,7 @@ def test_barker(self): }, { "algorithm": blackjax.barker_proposal, - "parameters": {"step_size": 0.5}, + "parameters": {"step_size": 0.45}, "is_mass_matrix_diagonal": None, }, ] diff --git a/tests/mcmc/test_slice_sampling.py 
b/tests/mcmc/test_slice_sampling.py new file mode 100644 index 000000000..54c2a721f --- /dev/null +++ b/tests/mcmc/test_slice_sampling.py @@ -0,0 +1,236 @@ +"""Test the Slice Sampling algorithm""" +import chex +import jax +import jax.numpy as jnp +import jax.scipy.stats as stats +from absl.testing import absltest, parameterized + +import blackjax +from blackjax.mcmc import ss + + +def logdensity_fn(x): + """Standard normal density""" + return stats.norm.logpdf(x).sum() + + +def multimodal_logdensity(x): + """Mixture of two Gaussians""" + mode1 = stats.norm.logpdf(x - 2.0) + mode2 = stats.norm.logpdf(x + 2.0) + return jnp.logaddexp(mode1, mode2).sum() + + +def constrained_logdensity(x): + """Truncated normal (x > 0)""" + return jnp.where(x > 0, stats.norm.logpdf(x), -jnp.inf).sum() + + +class SliceSamplingTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + def test_slice_init(self): + """Test slice sampler initialization""" + position = jnp.array([1.0, 2.0]) + state = ss.init(position, logdensity_fn) + + chex.assert_trees_all_close(state.position, position) + expected_logdensity = logdensity_fn(position) + chex.assert_trees_all_close(state.logdensity, expected_logdensity) + + def test_vertical_slice(self): + """Test vertical slice height sampling""" + key = jax.random.key(123) + position = jnp.array([0.0]) + state = ss.init(position, logdensity_fn) + + # Sample many slice heights + keys = jax.random.split(key, 1000) + new_state, info = jax.vmap(ss.vertical_slice, in_axes=(0, None))(keys, state) + + # Heights should be below log density at position + logdens_at_pos = logdensity_fn(position) + self.assertTrue(jnp.all(new_state.logslice <= logdens_at_pos)) + + # Heights should be reasonably distributed + mean_height = jnp.mean(new_state.logslice) + expected_mean = logdens_at_pos - 1.0 # E[log(U)] = -1 for U~Uniform(0,1) + chex.assert_trees_all_close(mean_height, expected_mean, atol=0.1) + + @parameterized.parameters([1, 2, 5]) + def test_slice_sampling_dimensions(self, ndim): + """Test slice sampling in different dimensions""" + key = jax.random.key(456) + position = jnp.zeros(ndim) + + # Simple step function + def stepper_fn(x, d, t): + return x + t * d + + # Build kernel + def direction_fn(rng_key): + return jax.random.normal(rng_key, (ndim,)) + + kernel = ss.build_hrss_kernel(direction_fn, stepper_fn) + state = ss.init(position, logdensity_fn) + + # Take one step + new_state, info = kernel(key, state, logdensity_fn) + + chex.assert_shape(new_state.position, (ndim,)) + self.assertIsInstance(new_state.logdensity, (float, jax.Array)) + + def test_constrained_slice_sampling(self): + """Test slice sampling with constraints""" + key = jax.random.key(789) + position = jnp.array([1.0]) # Start in valid region + + def stepper_fn(x, d, t): + return x + t * d + + kernel = ss.build_kernel(stepper_fn) + state = ss.init(position, constrained_logdensity) + + # Direction pointing outward + direction = jnp.array([1.0]) + + # Constraint function + def constraint_fn(x): + return jnp.array([]) # No additional constraints for this test + + new_state, info = kernel( + key, + state, + constrained_logdensity, + direction, + constraint_fn, + jnp.array([]), + jnp.array([]), + ) + + # Should remain in valid region + self.assertTrue(jnp.all(new_state.position > 0)) + + def test_default_direction_generation(self): + """Test default direction generation function""" + key = jax.random.key(101112) + cov = jnp.eye(3) * 2.0 + + direction = ss.sample_direction_from_covariance(key, 
cov) + + chex.assert_shape(direction, (3,)) + + # Direction should be normalized in Mahalanobis sense + invcov = jnp.linalg.inv(cov) + mahal_norm = jnp.sqrt(jnp.einsum("i,ij,j", direction, invcov, direction)) + chex.assert_trees_all_close(mahal_norm, 1.0, atol=1e-6) + + def test_hrss_top_level_api(self): + """Test hit-and-run slice sampling top-level API""" + cov = jnp.eye(2) + algorithm = ss.hrss_as_top_level_api(logdensity_fn, cov) + + # Check it returns a SamplingAlgorithm + self.assertIsInstance(algorithm, blackjax.base.SamplingAlgorithm) + + # Test init and step functions + position = jnp.array([0.0, 0.0]) + state = algorithm.init(position) + + key = jax.random.key(123) + new_state, info = algorithm.step(key, state) + + chex.assert_shape(new_state.position, (2,)) + + def test_slice_sampling_statistical_correctness(self): + """Test that slice sampling produces correct statistics""" + n_samples = 100 # Reduced significantly for faster testing + key = jax.random.key(42) + + # Use HRSS for sampling from standard normal + cov = jnp.eye(1) + algorithm = ss.hrss_as_top_level_api(logdensity_fn, cov) + + # Run inference + initial_position = jnp.array([0.0]) + initial_state = algorithm.init(initial_position) + + # Simple sampling loop with progress tracking + samples = [] + state = initial_state + keys = jax.random.split(key, n_samples) + + for i, sample_key in enumerate(keys): + state, info = algorithm.step(sample_key, state) + samples.append(state.position) + # Early exit if we get stuck + if i > 0 and jnp.isnan(state.position).any(): + break + + if len(samples) < 10: # If we got very few samples, skip statistical test + self.skipTest("Not enough samples generated") + + samples = jnp.array(samples) + + # Check basic properties + self.assertFalse(jnp.isnan(samples).any(), "Samples contain NaN") + self.assertFalse(jnp.isinf(samples).any(), "Samples contain Inf") + + # Very loose statistical checks for small sample size + sample_mean = jnp.mean(samples) + sample_std = jnp.std(samples) + + # Just check that mean is reasonable and std is positive + self.assertLess(abs(sample_mean), 2.0, "Mean is unreasonably far from 0") + self.assertGreater(sample_std, 0.1, "Standard deviation is too small") + self.assertLess(sample_std, 5.0, "Standard deviation is too large") + + def test_default_stepper_fn(self): + """Test default stepper function""" + x = jnp.array([1.0, 2.0]) + d = jnp.array([0.5, -0.5]) + t = 2.0 + + result = ss.default_stepper_fn(x, d, t) + expected = x + t * d + + chex.assert_trees_all_close(result, expected) + + def test_slice_info_structure(self): + """Test that SliceInfo contains expected fields""" + key = jax.random.key(789) + position = jnp.array([0.0]) + + def stepper_fn(x, d, t): + return x + t * d + + kernel = ss.build_kernel(stepper_fn) + state = ss.init(position, logdensity_fn) + direction = jnp.array([1.0]) + + def constraint_fn(x): + return jnp.array([]) + + new_state, info = kernel( + key, + state, + logdensity_fn, + direction, + constraint_fn, + jnp.array([]), + jnp.array([]), + ) + + # Check that info has expected structure + self.assertIsInstance(info, ss.SliceInfo) + self.assertTrue(hasattr(info, "constraint")) + self.assertTrue(hasattr(info, "l_steps")) + self.assertTrue(hasattr(info, "r_steps")) + self.assertTrue(hasattr(info, "s_steps")) + self.assertTrue(hasattr(info, "evals")) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/ns/__init__.py b/tests/ns/__init__.py new file mode 100644 index 000000000..7d1e4dbc3 --- /dev/null +++ 
b/tests/ns/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/ns/test_nested_sampling.py b/tests/ns/test_nested_sampling.py new file mode 100644 index 000000000..4280a9d8c --- /dev/null +++ b/tests/ns/test_nested_sampling.py @@ -0,0 +1,669 @@ +"""Test the Nested Sampling algorithms""" +import functools + +import chex +import jax +import jax.numpy as jnp +import jax.scipy.stats as stats +from absl.testing import absltest, parameterized + +from blackjax.ns import adaptive, base, nss, utils + + +def gaussian_logprior(x): + """Standard normal prior""" + return stats.norm.logpdf(x).sum() + + +def gaussian_loglikelihood(x): + """Gaussian likelihood with offset""" + return stats.norm.logpdf(x - 1.0).sum() + + +def uniform_logprior_2d(x): + """Uniform prior on [-5, 5]^2""" + return jnp.where(jnp.all(jnp.abs(x) <= 5.0), 0.0, -jnp.inf) + + +def gaussian_mixture_loglikelihood(x): + """2D Gaussian mixture for multi-modal testing""" + mixture1 = stats.norm.logpdf(x - jnp.array([2.0, 0.0])).sum() + mixture2 = stats.norm.logpdf(x - jnp.array([-2.0, 0.0])).sum() + return jnp.logaddexp(mixture1, mixture2) + + +class NestedSamplingTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + def test_base_ns_init(self): + """Test basic NS initialization""" + key = jax.random.key(123) + num_live = 50 + + # Generate initial particles + particles = jax.random.normal(key, (num_live,)) + + # Initialize NS state + state = base.init(particles, gaussian_logprior, gaussian_loglikelihood) + + # Check state structure + chex.assert_shape(state.particles, (num_live,)) + chex.assert_shape(state.loglikelihood, (num_live,)) + chex.assert_shape(state.logprior, (num_live,)) + chex.assert_shape(state.pid, (num_live,)) + + # Check that loglikelihood and logprior are properly computed + expected_loglik = jax.vmap(gaussian_loglikelihood)(particles) + expected_logprior = jax.vmap(gaussian_logprior)(particles) + + chex.assert_trees_all_close(state.loglikelihood, expected_loglik) + chex.assert_trees_all_close(state.logprior, expected_logprior) + + def test_delete_fn(self): + """Test particle deletion function""" + key = jax.random.key(456) + num_live = 20 + num_delete = 3 + + particles = jax.random.normal(key, (num_live,)) + state = base.init(particles, gaussian_logprior, gaussian_loglikelihood) + + dead_idx, target_idx, start_idx = base.delete_fn(key, state, num_delete) + + # Check correct number of deletions + chex.assert_shape(dead_idx, (num_delete,)) + chex.assert_shape(target_idx, (num_delete,)) + chex.assert_shape(start_idx, (num_delete,)) + + # Check that worst particles are selected + worst_loglik = jnp.sort(state.loglikelihood)[:num_delete] + selected_loglik = state.loglikelihood[dead_idx] + chex.assert_trees_all_close(jnp.sort(selected_loglik), worst_loglik) + + @parameterized.parameters([1, 2, 5]) + def test_ns_step_consistency(self, num_delete): + """Test NS step maintains particle count""" + key = 
jax.random.key(789) + num_live = 50 + + particles = jax.random.normal(key, (num_live, 2)) + state = base.init( + particles, uniform_logprior_2d, gaussian_mixture_loglikelihood + ) + + # Mock inner kernel for testing + def mock_inner_kernel( + rng_key, inner_state, logprior_fn, loglikelihood_fn, loglikelihood_0, params + ): + # Simple random walk for testing + new_pos = ( + inner_state["position"] + + jax.random.normal(rng_key, inner_state["position"].shape) * 0.1 + ) + new_logprior = logprior_fn(new_pos) + new_loglik = loglikelihood_fn(new_pos) + + new_inner_state = { + "position": new_pos, + "logprior": new_logprior, + "loglikelihood": new_loglik, + } + return new_inner_state, {} + + delete_fn = functools.partial(base.delete_fn, num_delete=num_delete) + kernel = base.build_kernel( + uniform_logprior_2d, + gaussian_mixture_loglikelihood, + delete_fn, + mock_inner_kernel, + ) + + # Test that the kernel can be constructed with mock components + # Full execution would require more complex mocking of inner kernel behavior + self.assertTrue(callable(kernel)) + + # Test delete function works + dead_idx, target_idx, start_idx = base.delete_fn(key, state, num_delete) + chex.assert_shape(dead_idx, (num_delete,)) + chex.assert_shape(target_idx, (num_delete,)) + chex.assert_shape(start_idx, (num_delete,)) + + def test_utils_functions(self): + """Test utility functions""" + key = jax.random.key(101112) + + # Create mock dead info + n_dead = 20 + dead_loglik = jnp.sort(jax.random.uniform(key, (n_dead,))) * 10 - 5 + dead_loglik_birth = jnp.full_like(dead_loglik, -jnp.inf) + + mock_info = base.NSInfo( + particles=jnp.zeros((n_dead, 2)), + loglikelihood=dead_loglik, + loglikelihood_birth=dead_loglik_birth, + logprior=jnp.zeros(n_dead), + inner_kernel_info={}, + ) + + # Test compute_num_live + num_live = utils.compute_num_live(mock_info) + chex.assert_shape(num_live, (n_dead,)) + + # Test logX simulation + logX_seq, logdX_seq = utils.logX(key, mock_info, shape=10) + chex.assert_shape(logX_seq, (n_dead, 10)) + chex.assert_shape(logdX_seq, (n_dead, 10)) + + # Check logX is decreasing + self.assertTrue(jnp.all(logX_seq[1:] <= logX_seq[:-1])) + + +class AdaptiveNestedSamplingTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + def test_adaptive_init(self): + """Test adaptive NS initialization""" + key = jax.random.key(123) + num_live = 30 + + particles = jax.random.normal(key, (num_live,)) + + def mock_update_params_fn(state, info, current_params): + return {"test_param": 1.0} + + state = adaptive.init( + particles, + gaussian_logprior, + gaussian_loglikelihood, + update_inner_kernel_params_fn=mock_update_params_fn, + ) + + # Check that inner kernel params were set + self.assertEqual(state.inner_kernel_params["test_param"], 1.0) + + +class NestedSliceSamplingTest(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + def test_nss_direction_functions(self): + """Test NSS direction generation functions""" + key = jax.random.key(456) + + # Test covariance computation + particles = jax.random.normal(key, (50, 3)) + state = base.init(particles, gaussian_logprior, gaussian_loglikelihood) + + params = nss.compute_covariance_from_particles(state, None, {}) + + # Check that covariance is computed + self.assertIn("cov", params) + cov_pytree = params["cov"] + chex.assert_shape(cov_pytree, (3, 3)) + + # Test direction sampling + direction = nss.sample_direction_from_covariance(key, params) + chex.assert_shape(direction, (3,)) + + def 
test_nss_kernel_construction(self): + """Test NSS kernel can be constructed""" + kernel = nss.build_kernel( + gaussian_logprior, gaussian_loglikelihood, num_inner_steps=10 + ) + + # Test that kernel is callable + self.assertTrue(callable(kernel)) + + +class NestedSamplingStatisticalTest(chex.TestCase): + """Statistical correctness tests for nested sampling algorithms.""" + + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + def test_1d_gaussian_evidence_estimation(self): + """Test evidence estimation with analytic validation for unnormalized Gaussian.""" + + # Simple case: unnormalized Gaussian likelihood exp(-0.5*x²), uniform prior [-3,3] + prior_a, prior_b = -3.0, 3.0 + + def logprior_fn(x): + return jnp.where( + (x >= prior_a) & (x <= prior_b), -jnp.log(prior_b - prior_a), -jnp.inf + ) + + def loglikelihood_fn(x): + # Unnormalized Gaussian: exp(-0.5 * x²) + return -0.5 * x**2 + + # Analytic evidence: Z = ∫[-3,3] (1/6) * exp(-0.5*x²) dx + # = (1/6) * √(2π) * [Φ(3) - Φ(-3)] + from scipy.stats import norm + + prior_width = prior_b - prior_a + integral_part = jnp.sqrt(2 * jnp.pi) * (norm.cdf(3.0) - norm.cdf(-3.0)) + analytical_evidence = integral_part / prior_width + analytical_log_evidence = jnp.log(analytical_evidence) + + # Generate mock nested sampling data + num_steps = 60 + key = jax.random.key(42) + + # Create positions spanning the prior range + positions = jnp.linspace(prior_a + 0.05, prior_b - 0.05, num_steps).reshape( + -1, 1 + ) + dead_loglik = jax.vmap(loglikelihood_fn)(positions.flatten()) + dead_logprior = jax.vmap(logprior_fn)(positions.flatten()) + + # Sort by likelihood (as NS naturally produces) + sorted_indices = jnp.argsort(dead_loglik) + dead_loglik = dead_loglik[sorted_indices] + positions = positions[sorted_indices] + dead_logprior = dead_logprior[sorted_indices] + + # Birth likelihoods - start from prior + dead_loglik_birth = jnp.full_like(dead_loglik, -jnp.inf) + + # Create NSInfo object + mock_info = base.NSInfo( + particles=positions, + loglikelihood=dead_loglik, + loglikelihood_birth=dead_loglik_birth, + logprior=dead_logprior, + inner_kernel_info={}, + ) + + # Generate many evidence estimates for statistical testing + n_evidence_samples = 500 + key = jax.random.key(789) + keys = jax.random.split(key, n_evidence_samples) + + def single_evidence_estimate(rng_key): + log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=15) + return jax.scipy.special.logsumexp(log_weights_matrix, axis=0) + + # Compute evidence estimates + log_evidence_samples = jax.vmap(single_evidence_estimate)(keys) + log_evidence_samples = log_evidence_samples.flatten() + + # Statistical validation + mean_estimate = jnp.mean(log_evidence_samples) + std_estimate = jnp.std(log_evidence_samples) + + # Check statistical consistency with 95% confidence interval + # For mock data with simplified NS, expect some bias but should be in ballpark + tolerance = 2.0 * std_estimate # 95% CI + bias = jnp.abs(mean_estimate - analytical_log_evidence) + + self.assertLess( + bias, + tolerance, + f"Evidence estimate {mean_estimate:.3f} vs analytic {analytical_log_evidence:.3f} " + f"differs by {bias:.3f}, which exceeds 2σ = {tolerance:.3f}", + ) + + # Also test that individual estimates are reasonable + self.assertFalse( + jnp.any(jnp.isnan(log_evidence_samples)), + "No evidence estimates should be NaN", + ) + self.assertFalse( + jnp.any(jnp.isinf(log_evidence_samples)), + "No evidence estimates should be infinite", + ) + + # Check that estimates are in a reasonable range + 
self.assertGreater( + mean_estimate, analytical_log_evidence - 1.0, "Mean estimate not too low" + ) + self.assertLess( + mean_estimate, analytical_log_evidence + 1.0, "Mean estimate not too high" + ) + + def test_uniform_prior_evidence(self): + """Test evidence estimation for uniform prior with simple likelihood.""" + + # Setup: Uniform prior on [0, 1], simple likelihood + def logprior_fn(x): + return jnp.where((x >= 0.0) & (x <= 1.0), 0.0, -jnp.inf) + + def loglikelihood_fn(x): + # Simple quadratic likelihood peaked at 0.5 + return -10.0 * (x - 0.5) ** 2 + + # Analytical evidence can be computed numerically for comparison + # Z = integral_0^1 exp(-10(x-0.5)^2) dx ≈ sqrt(π/10) * erf(...) + + num_live = 50 + key = jax.random.key(456) + + # Initialize particles uniformly in [0, 1] + particles = jax.random.uniform(key, (num_live,)) + state = base.init(particles, logprior_fn, loglikelihood_fn) + + # Check that initialization worked correctly + self.assertTrue(jnp.all(state.particles >= 0.0)) + self.assertTrue(jnp.all(state.particles <= 1.0)) + self.assertFalse(jnp.any(jnp.isinf(state.logprior))) + self.assertFalse(jnp.any(jnp.isnan(state.loglikelihood))) + + # Test evidence contribution from live points + logZ_live_contribution = state.logZ_live + self.assertIsInstance(logZ_live_contribution, (float, jax.Array)) + self.assertFalse(jnp.isnan(logZ_live_contribution)) + + def test_evidence_monotonicity(self): + """Test that evidence estimates are monotonically increasing during NS run.""" + + # Simple setup for testing monotonicity + def logprior_fn(x): + return stats.norm.logpdf(x) + + def loglikelihood_fn(x): + return -0.5 * x**2 # Simple quadratic + + num_live = 30 + key = jax.random.key(789) + + particles = jax.random.normal(key, (num_live,)) + initial_state = base.init(particles, logprior_fn, loglikelihood_fn) + + # Test that we can track evidence during run + logZ_sequence = [initial_state.logZ] + + # Simulate a few evidence updates manually + for i in range(5): + # Simulate removing worst particle and updating evidence + worst_idx = jnp.argmin(initial_state.loglikelihood) + dead_loglik = initial_state.loglikelihood[worst_idx] + + # Update evidence (simplified) + delta_logX = -1.0 / num_live # Approximate volume decrease + new_logZ = jnp.logaddexp(initial_state.logZ, dead_loglik + delta_logX) + logZ_sequence.append(new_logZ) + + # Update for next iteration (simplified) + new_loglik = jnp.concatenate( + [ + initial_state.loglikelihood[:worst_idx], + initial_state.loglikelihood[worst_idx + 1 :], + jnp.array([dead_loglik + 0.1]), # Mock new particle + ] + ) + initial_state = initial_state._replace(loglikelihood=new_loglik) + + # Check monotonicity + logZ_array = jnp.array(logZ_sequence) + differences = logZ_array[1:] - logZ_array[:-1] + self.assertTrue( + jnp.all(differences >= -1e-10), + "Evidence should be monotonically increasing", + ) + + def test_nested_sampling_utils_statistical_properties(self): + """Test statistical properties of nested sampling utility functions.""" + key = jax.random.key(101112) + + # Create realistic mock data + n_dead = 100 + + # Generate realistic loglikelihood sequence (increasing) + base_loglik = jnp.linspace(-10, -1, n_dead) + noise = jax.random.normal(key, (n_dead,)) * 0.1 + dead_loglik = jnp.sort(base_loglik + noise) + + # Create more realistic birth likelihoods that reflect actual NS behavior + # Particles can be born at various levels, not just at previous death + key, subkey = jax.random.split(key) + birth_noise = jax.random.uniform(subkey, (n_dead,)) * 
2.0 - 1.0 # [-1, 1] + dead_loglik_birth = jnp.concatenate( + [ + jnp.array([-jnp.inf]), # First particle born from prior + dead_loglik[:-1] + birth_noise[1:] * 0.5, # Others with some variation + ] + ) + # Ensure birth likelihoods don't exceed death likelihoods + dead_loglik_birth = jnp.minimum(dead_loglik_birth, dead_loglik - 0.01) + + mock_info = base.NSInfo( + particles=jnp.zeros((n_dead, 2)), + loglikelihood=dead_loglik, + loglikelihood_birth=dead_loglik_birth, + logprior=jnp.zeros(n_dead), + inner_kernel_info={}, + ) + + # Test compute_num_live + num_live = utils.compute_num_live(mock_info) + chex.assert_shape(num_live, (n_dead,)) + + # Basic sanity checks for number of live points + # NOTE: num_live should NOT be monotonically decreasing in general NS! + # It follows a sawtooth pattern as particles die and are replenished + self.assertTrue( + jnp.all(num_live >= 1), "Should always have at least 1 live point" + ) + self.assertTrue( + jnp.all(num_live <= 1000), # Reasonable upper bound + "Number of live points should be reasonable", + ) + self.assertFalse( + jnp.any(jnp.isnan(num_live)), "Number of live points should not be NaN" + ) + + # Test logX simulation + n_samples = 50 + logX_seq, logdX_seq = utils.logX(key, mock_info, shape=n_samples) + chex.assert_shape(logX_seq, (n_dead, n_samples)) + chex.assert_shape(logdX_seq, (n_dead, n_samples)) + + # Log volumes should be decreasing + self.assertTrue( + jnp.all(logX_seq[1:] <= logX_seq[:-1]), "Log volumes should be decreasing" + ) + + # All log volume elements should be negative (since dX < X) + finite_logdX = logdX_seq[jnp.isfinite(logdX_seq)] + if len(finite_logdX) > 0: + self.assertTrue( + jnp.all(finite_logdX <= 0.0), "Log volume elements should be negative" + ) + + # Test log_weights function + log_weights_matrix = utils.log_weights(key, mock_info, shape=n_samples) + chex.assert_shape(log_weights_matrix, (n_dead, n_samples)) + + # Weights should be finite for most particles + finite_weights = jnp.isfinite(log_weights_matrix) + self.assertGreater( + jnp.sum(finite_weights), + n_dead * n_samples * 0.5, + "Most weights should be finite", + ) + + def test_gaussian_evidence_narrow_prior(self): + """Test evidence estimation with narrow prior for challenging case.""" + + # Setup: Gaussian likelihood with narrow uniform prior (more challenging) + mu_true = 1.2 + sigma_true = 0.6 + prior_a, prior_b = 0.8, 1.6 # Narrow prior around the mean + + def logprior_fn(x): + return jnp.where( + (x >= prior_a) & (x <= prior_b), -jnp.log(prior_b - prior_a), -jnp.inf + ) + + def loglikelihood_fn(x): + return -0.5 * ((x - mu_true) / sigma_true) ** 2 - 0.5 * jnp.log( + 2 * jnp.pi * sigma_true**2 + ) + + # Analytic evidence + from scipy.stats import norm + + analytical_evidence = ( + norm.cdf((prior_b - mu_true) / sigma_true) + - norm.cdf((prior_a - mu_true) / sigma_true) + ) / (prior_b - prior_a) + analytical_log_evidence = jnp.log(analytical_evidence) + + # Generate mock NS data with higher resolution for narrow prior + num_steps = 60 + key = jax.random.key(12345) + + # Dense sampling in the narrow prior region + positions = jnp.linspace(prior_a + 0.01, prior_b - 0.01, num_steps).reshape( + -1, 1 + ) + dead_loglik = jax.vmap(loglikelihood_fn)(positions.flatten()) + dead_logprior = jax.vmap(logprior_fn)(positions.flatten()) + + # Sort by likelihood + sorted_indices = jnp.argsort(dead_loglik) + dead_loglik = dead_loglik[sorted_indices] + positions = positions[sorted_indices] + dead_logprior = dead_logprior[sorted_indices] + + # Birth likelihoods + key, 
subkey = jax.random.split(key) + birth_noise = jax.random.uniform(subkey, (num_steps,)) * 0.3 - 0.15 + dead_loglik_birth = jnp.concatenate( + [jnp.array([-jnp.inf]), dead_loglik[:-1] + birth_noise[1:]] + ) + dead_loglik_birth = jnp.minimum(dead_loglik_birth, dead_loglik - 0.01) + + mock_info = base.NSInfo( + particles=positions, + loglikelihood=dead_loglik, + loglikelihood_birth=dead_loglik_birth, + logprior=dead_logprior, + inner_kernel_info={}, + ) + + # Generate evidence estimates for statistical testing + n_evidence_samples = 800 + key = jax.random.key(555) + keys = jax.random.split(key, n_evidence_samples) + + def single_evidence_estimate(rng_key): + log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=15) + return jax.scipy.special.logsumexp(log_weights_matrix, axis=0) + + log_evidence_samples = jax.vmap(single_evidence_estimate)(keys) + log_evidence_samples = log_evidence_samples.flatten() + + # Statistical validation + mean_estimate = jnp.mean(log_evidence_samples) + std_estimate = jnp.std(log_evidence_samples) + + # 99% confidence interval test + lower_bound = mean_estimate - 2.576 * std_estimate # 99% CI + upper_bound = mean_estimate + 2.576 * std_estimate + + self.assertGreater( + analytical_log_evidence, + lower_bound, + f"Analytic evidence {analytical_log_evidence:.3f} below 99% CI lower bound {lower_bound:.3f}", + ) + self.assertLess( + analytical_log_evidence, + upper_bound, + f"Analytic evidence {analytical_log_evidence:.3f} above 99% CI upper bound {upper_bound:.3f}", + ) + + def test_evidence_integration_simple_case(self): + """Test evidence calculation for a simple analytical case with constant likelihood.""" + # Test case: uniform prior on [0,2], constant likelihood + # Evidence = ∫[0,2] (1/width) * exp(loglik_constant) dx = exp(loglik_constant) + + loglik_constant = -1.5 + prior_width = 2.0 # Prior on [0, 2] + n_dead = 40 + + # Analytic answer: evidence = ∫[0,2] (1/2) * exp(-1.5) dx = exp(-1.5) + analytical_log_evidence = loglik_constant + + # Mock data: all particles have same likelihood (constant function) + dead_loglik = jnp.full(n_dead, loglik_constant) + dead_loglik_birth = jnp.full(n_dead, -jnp.inf) # All from prior + + mock_info = base.NSInfo( + particles=jnp.zeros((n_dead, 1)), + loglikelihood=dead_loglik, + loglikelihood_birth=dead_loglik_birth, + logprior=jnp.full( + n_dead, -jnp.log(prior_width) + ), # Uniform prior log density + inner_kernel_info={}, + ) + + # Generate many evidence estimates + n_samples = 500 + key = jax.random.key(999) + keys = jax.random.split(key, n_samples) + + def single_evidence_estimate(rng_key): + log_weights_matrix = utils.log_weights(rng_key, mock_info, shape=25) + return jax.scipy.special.logsumexp(log_weights_matrix, axis=0) + + log_evidence_samples = jax.vmap(single_evidence_estimate)(keys) + log_evidence_samples = log_evidence_samples.flatten() + + mean_estimate = jnp.mean(log_evidence_samples) + std_estimate = jnp.std(log_evidence_samples) + + # For constant likelihood case, should be very accurate + # 95% confidence interval + lower_bound = mean_estimate - 1.96 * std_estimate + upper_bound = mean_estimate + 1.96 * std_estimate + + self.assertGreater( + analytical_log_evidence, + lower_bound, + f"Analytic evidence {analytical_log_evidence:.3f} below 95% CI", + ) + self.assertLess( + analytical_log_evidence, + upper_bound, + f"Analytic evidence {analytical_log_evidence:.3f} above 95% CI", + ) + + def test_effective_sample_size_calculation(self): + """Test effective sample size calculation.""" + key = 
jax.random.key(67890) + + # Create mock data with varying weights + n_dead = 50 + dead_loglik = jax.random.uniform(key, (n_dead,)) * 5 - 10 # Range [-10, -5] + dead_loglik_birth = jnp.full(n_dead, -jnp.inf) + + mock_info = base.NSInfo( + particles=jnp.zeros((n_dead, 1)), + loglikelihood=jnp.sort(dead_loglik), # Ensure increasing + loglikelihood_birth=dead_loglik_birth, + logprior=jnp.zeros(n_dead), + inner_kernel_info={}, + ) + + # Calculate ESS + ess_value = utils.ess(key, mock_info) + + # ESS should be positive and reasonable + self.assertIsInstance(ess_value, (float, jax.Array)) + self.assertGreater(ess_value, 0.0, "ESS should be positive") + self.assertLessEqual( + ess_value, n_dead, "ESS should not exceed number of samples" + ) + self.assertFalse(jnp.isnan(ess_value), "ESS should not be NaN") + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 7d6190af5..d7daaf839 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -15,9 +15,10 @@ from blackjax.mcmc.random_walk import build_irmh from blackjax.smc import extend_params from blackjax.smc.inner_kernel_tuning import as_top_level_api as inner_kernel_tuning +from blackjax.smc.pretuning import build_pretune from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate from blackjax.smc.tuning.from_particles import ( - mass_matrix_from_particles, + inverse_mass_matrix_from_particles, particles_as_rows, particles_covariance_matrix, particles_means, @@ -93,7 +94,7 @@ def smc_inner_kernel_tuning_test_case( proposal_factory = MagicMock() proposal_factory.return_value = 100 - def mcmc_parameter_update_fn(state, info): + def mcmc_parameter_update_fn(key, state, info): return extend_params({"mean": 100}) prior = lambda x: stats.norm.logpdf(x) @@ -186,30 +187,30 @@ def setUp(self): self.key = jax.random.key(42) def test_inverse_mass_matrix_from_particles(self): - inverse_mass_matrix = mass_matrix_from_particles( + inverse_mass_matrix = inverse_mass_matrix_from_particles( np.array([np.array(10.0), np.array(3.0)]) ) np.testing.assert_allclose( - inverse_mass_matrix, np.diag(np.array([0.08163])), rtol=1e-4 + inverse_mass_matrix, np.diag(np.array([12.25])), rtol=1e-4 ) def test_inverse_mass_matrix_from_multivariate_particles(self): - inverse_mass_matrix = mass_matrix_from_particles( + inverse_mass_matrix = inverse_mass_matrix_from_particles( np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) ) np.testing.assert_allclose( - inverse_mass_matrix, np.diag(np.array([0.081633, 0.033058])), rtol=1e-4 + inverse_mass_matrix, np.diag(np.array([12.25, 30.25])), rtol=1e-4 ) def test_inverse_mass_matrix_from_multivariable_particles(self): var1 = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) var2 = np.array([jnp.array([10.0]), jnp.array([3.0])]) init_particles = {"var1": var1, "var2": var2} - mass_matrix = mass_matrix_from_particles(init_particles) + mass_matrix = inverse_mass_matrix_from_particles(init_particles) assert mass_matrix.shape == (3, 3) np.testing.assert_allclose( np.diag(mass_matrix), - np.array([0.081633, 0.033058, 0.081633], dtype="float32"), + np.array([12.25, 30.25, 12.25], dtype="float32"), rtol=1e-4, ) @@ -217,10 +218,10 @@ def test_inverse_mass_matrix_from_multivariable_univariate_particles(self): var1 = np.array([3.0, 2.0]) var2 = np.array([10.0, 3.0]) init_particles = {"var1": var1, "var2": var2} - mass_matrix = mass_matrix_from_particles(init_particles) + mass_matrix 
= inverse_mass_matrix_from_particles(init_particles) assert mass_matrix.shape == (2, 2) np.testing.assert_allclose( - np.diag(mass_matrix), np.array([4, 0.081633], dtype="float32"), rtol=1e-4 + np.diag(mass_matrix), np.array([0.25, 12.25], dtype="float32"), rtol=1e-4 ) @@ -279,10 +280,12 @@ def test_with_adaptive_tempered(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() - def parameter_update(state, info): + def parameter_update(key, state, info): return extend_params( { - "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "inverse_mass_matrix": inverse_mass_matrix_from_particles( + state.particles + ), "step_size": 10e-2, "num_integration_steps": 50, }, @@ -308,21 +311,7 @@ def parameter_update(state, info): ) init_state = init(init_particles) smc_kernel = self.variant(step) - - def inference_loop(kernel, rng_key, initial_state): - def cond(carry): - _, state = carry - return state.sampler_state.lmbda < 1 - - def body(carry): - i, state = carry - subkey = jax.random.fold_in(rng_key, i) - state, _ = kernel(subkey, state) - return i + 1, state - - return jax.lax.while_loop(cond, body, (0, initial_state)) - - _, state = inference_loop(smc_kernel, self.key, init_state) + _, state = adaptive_tempered_loop(smc_kernel, self.key, init_state) assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2) self.assert_linear_regression_test_case(state.sampler_state) @@ -336,10 +325,12 @@ def test_with_tempered_smc(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() - def parameter_update(state, info): + def parameter_update(key, state, info): return extend_params( { - "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "inverse_mass_matrix": inverse_mass_matrix_from_particles( + state.particles + ), "step_size": 10e-2, "num_integration_steps": 50, }, @@ -393,5 +384,128 @@ def test_particles_as_rows(self): np.testing.assert_array_equal(np.arange(3 * 5 + 2), flatten_particles[0]) +def adaptive_tempered_loop(kernel, rng_key, initial_state): + def cond(carry): + _, state = carry + return state.sampler_state.lmbda < 1 + + def body(carry): + i, state = carry + subkey = jax.random.fold_in(rng_key, i) + state, _ = kernel(subkey, state) + return i + 1, state + + return jax.lax.while_loop(cond, body, (0, initial_state)) + + +class MultipleTuningTest(SMCLinearRegressionTestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + @chex.all_variants(with_pmap=False) + def test_tuning_pretuning(self): + """ + Tests that we can apply tuning to some parameters + and pretuning to others at the same time. + """ + + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + n_particles = 100 + dimensions = 2 + + step_size_key, integration_steps_key = jax.random.split(self.key, 2) + + # Set initial samples for integration steps and step sizes. + integration_steps_distribution = jnp.round( + jax.random.uniform( + integration_steps_key, (n_particles,), minval=1, maxval=50 + ) + ).astype(int) + + step_sizes_distribution = jax.random.uniform( + step_size_key, (n_particles,), minval=1e-1 / 2, maxval=1e-1 * 2 + ) + + # Fix inverse_mass_matrix; the other two parameters get per-particle distributions.
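+ # Note on shapes (mirroring the assertions used elsewhere in these tests, not enforced here): the shared + # inverse_mass_matrix wrapped in extend_params carries a leading broadcast axis, e.g. (1, dimensions, dimensions), + # while step_size and num_integration_steps hold one value per particle, i.e. shape (n_particles,).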
+ initial_parameters = dict( + inverse_mass_matrix=extend_params(jnp.eye(dimensions)), + step_size=step_sizes_distribution, + num_integration_steps=integration_steps_distribution, + ) + + pretune = build_pretune( + blackjax.hmc.init, + blackjax.hmc.build_kernel(), + alpha=2, + n_particles=n_particles, + sigma_parameters={ + "step_size": jnp.array(0.1), + "num_integration_steps": jnp.array(2.0), + }, + natural_parameters=["num_integration_steps"], + positive_parameters=["step_size"], + ) + + def pretuning_factory( + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + mcmc_parameters, + resampling_fn, + num_mcmc_steps, + initial_parameter_value, + target_ess, + ): + # we need to wrap the pretuning into a factory, which is what + # the inner_kernel_tuning expects + return blackjax.pretuning( + blackjax.adaptive_tempered_smc, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + num_mcmc_steps, + initial_parameter_value, + pretune, + target_ess=target_ess, + ) + + def mcmc_parameter_update_fn(key, state, info): + imm = inverse_mass_matrix_from_particles(state.sampler_state.particles) + return {"inverse_mass_matrix": extend_params(imm)} + + step = blackjax.smc.inner_kernel_tuning.build_kernel( + pretuning_factory, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + mcmc_parameter_update_fn=mcmc_parameter_update_fn, + initial_parameter_value=initial_parameters, + num_mcmc_steps=10, + target_ess=0.5, + smc_returns_state_with_parameter_override=True, + ) + + def init(position): + return blackjax.smc.inner_kernel_tuning.init( + blackjax.adaptive_tempered_smc.init, position, initial_parameters + ) + + init_state = init(init_particles) + smc_kernel = self.variant(step) + _, state = adaptive_tempered_loop(smc_kernel, self.key, init_state) + self.assert_linear_regression_test_case(state.sampler_state) + + if __name__ == "__main__": absltest.main() diff --git a/tests/smc/test_pretuning.py b/tests/smc/test_pretuning.py new file mode 100644 index 000000000..d24996eaf --- /dev/null +++ b/tests/smc/test_pretuning.py @@ -0,0 +1,286 @@ +import unittest + +import chex +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest + +import blackjax +from blackjax.smc import extend_params, resampling +from blackjax.smc.pretuning import ( + build_pretune, + esjd, + init, + update_parameter_distribution, +) +from tests.smc import SMCLinearRegressionTestCase + + +class TestMeasureOfChainMixing(unittest.TestCase): + previous_position = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) + + next_position = np.array([jnp.array([20.0, 30.0]), jnp.array([9.0, 12.0])]) + + def test_measure_of_chain_mixing_identity(self): + """ + Given the identity matrix and an acceptance probability of 1, + the mixing measure is the squared Euclidean norm of the jump. + """ + m = np.eye(2) + + acceptance_probabilities = np.array([1.0, 1.0]) + chain_mixing = esjd(m)( + self.previous_position, self.next_position, acceptance_probabilities + ) + np.testing.assert_allclose(chain_mixing[0], 325) + np.testing.assert_allclose(chain_mixing[1], 100) + + def test_measure_of_chain_mixing_with_non_1_acceptance_rate(self): + """ + Given the identity matrix, + the mixing measure is the squared Euclidean norm of the jump
multiplied by the acceptance rate. + """ + m = np.eye(2) + + acceptance_probabilities = np.array([0.5, 0.2]) + chain_mixing = esjd(m)( + self.previous_position, self.next_position, acceptance_probabilities + ) + np.testing.assert_allclose(chain_mixing[0], 162.5) + np.testing.assert_allclose(chain_mixing[1], 20) + + def test_measure_of_chain_mixing(self): + m = np.array([[3, 0], [0, 5]]) + + previous_position = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) + + next_position = np.array([jnp.array([20.0, 30.0]), jnp.array([9.0, 12.0])]) + + acceptance_probabilities = np.array([1.0, 1.0]) + + chain_mixing = esjd(m)( + previous_position, next_position, acceptance_probabilities + ) + + assert chain_mixing.shape == (2,) + np.testing.assert_allclose(chain_mixing[0], 10 * 10 * 3 + 15 * 15 * 5) + np.testing.assert_allclose(chain_mixing[1], 6 * 6 * 3 + 8 * 8 * 5) + + +class TestUpdateParameterDistribution(chex.TestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + self.previous_position = np.array( + [jnp.array([10.0, 15.0]), jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])] + ) + self.next_position = np.array( + [jnp.array([20.0, 30.0]), jnp.array([10.0, 15.0]), jnp.array([9.0, 12.0])] + ) + + def test_update_param_distribution(self): + """ + Given extremely good mixing on one chain and alpha equal to 0, + the parameters of that chain are reused for every chain, + up to a slight mutation due to noise. + """ + + ( + new_parameter_distribution, + chain_mixing_measurement, + ) = update_parameter_distribution( + self.key, + jnp.array([1.0, 2.0, 3.0]), + self.previous_position, + self.next_position, + measure_of_chain_mixing=lambda x, y, z: jnp.array([1.0, 0.0, 0.0]), + alpha=0, + sigma_parameters=0.0001, + acceptance_probability=None, + ) + + np.testing.assert_allclose( + new_parameter_distribution, + np.array([1, 1, 1], dtype="float32"), + rtol=1e-3, + ) + np.testing.assert_allclose( + chain_mixing_measurement, + np.array([1, 0, 0], dtype="float32"), + rtol=1e-6, + ) + + def test_update_multi_sigmas(self): + """ + When there are multiple parameters, performance is tied to their combination, + so they must be resampled jointly rather than independently.
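+ Here only the first chain mixes (measure [1.0, 0.0, 0.0]), so its combination + param_a=1.0 and param_b=[5.0, 6.0] should be propagated to all three slots, + up to the small sigma_parameters noise.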
+ """ + ( + new_parameter_distribution, + chain_mixing_measurement, + ) = update_parameter_distribution( + self.key, + { + "param_a": jnp.array([1.0, 2.0, 3.0]), + "param_b": jnp.array([[5.0, 6.0], [6.0, 7.0], [4.0, 5.0]]), + }, + self.previous_position, + self.next_position, + measure_of_chain_mixing=lambda x, y, z: jnp.array([1.0, 0.0, 0.0]), + alpha=0, + sigma_parameters={"param_a": 0.0001, "param_b": 0.00001}, + acceptance_probability=None, + ) + print(chain_mixing_measurement) + np.testing.assert_allclose(chain_mixing_measurement, np.array([1.0, 0, 0])) + + np.testing.assert_allclose( + new_parameter_distribution["param_a"], jnp.array([1.0, 1.0, 1.0]), atol=0.1 + ) + np.testing.assert_allclose( + new_parameter_distribution["param_b"], + jnp.array([[5.0, 6.0], [5.0, 6.0], [5.0, 6.0]]), + atol=0.1, + ) + + +def tuned_adaptive_tempered_inference_loop(kernel, rng_key, initial_state): + def cond(carry): + _, state, *_ = carry + return state.sampler_state.lmbda < 1 + + def body(carry): + i, state, curr_loglikelihood = carry + subkey = jax.random.fold_in(rng_key, i) + state, info = kernel(subkey, state) + return i + 1, state, curr_loglikelihood + info.log_likelihood_increment + + total_iter, final_state, log_likelihood = jax.lax.while_loop( + cond, body, (0, initial_state, 0.0) + ) + return final_state + + +class PretuningSMCTest(SMCLinearRegressionTestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + @chex.variants(with_jit=True) + def test_tempered(self): + step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel( + blackjax.tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + pretune_fn=pretune, + ) + + def loop(smc_kernel, init_particles, initial_parameters): + initial_state = init( + blackjax.tempered_smc.init, init_particles, initial_parameters + ) + + def body_fn(carry, lmbda): + i, state = carry + subkey = jax.random.fold_in(self.key, i) + new_state, info = smc_kernel(subkey, state, lmbda=lmbda) + return (i + 1, new_state), (new_state, info) + + num_tempering_steps = 10 + lambda_schedule = np.logspace(-5, 0, num_tempering_steps) + + (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule) + return result + + self.linear_regression_test_case(step_provider, loop) + + @chex.variants(with_jit=True) + def test_adaptive_tempered(self): + step_provider = lambda logprior_fn, loglikelihood_fn, pretune: blackjax.smc.pretuning.build_kernel( + blackjax.adaptive_tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + pretune_fn=pretune, + target_ess=0.5, + ) + + def loop(smc_kernel, init_particles, initial_parameters): + initial_state = init( + blackjax.tempered_smc.init, init_particles, initial_parameters + ) + return tuned_adaptive_tempered_inference_loop( + smc_kernel, self.key, initial_state + ) + + self.linear_regression_test_case(step_provider, loop) + + def linear_regression_test_case(self, step_provider, loop): + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + num_particles = 100 + sampling_key, step_size_key, integration_steps_key = jax.random.split( + self.key, 3 + ) + integration_steps_distribution = jnp.round( + jax.random.uniform( + integration_steps_key, (num_particles,), minval=1, maxval=100 + ) + ).astype(int) + + step_sizes_distribution = 
jax.random.uniform( + step_size_key, (num_particles,), minval=0, maxval=0.1 + ) + + # Fixes inverse_mass_matrix and distribution for the other two parameters. + initial_parameters = dict( + inverse_mass_matrix=extend_params(jnp.eye(2)), + step_size=step_sizes_distribution, + num_integration_steps=integration_steps_distribution, + ) + assert initial_parameters["step_size"].shape == (num_particles,) + assert initial_parameters["num_integration_steps"].shape == (num_particles,) + + pretune = build_pretune( + blackjax.hmc.init, + blackjax.hmc.build_kernel(), + alpha=1, + n_particles=num_particles, + sigma_parameters={"step_size": 0.01, "num_integration_steps": 2}, + natural_parameters=["num_integration_steps"], + positive_parameters=["step_size"], + ) + + step = step_provider(logprior_fn, loglikelihood_fn, pretune) + + smc_kernel = self.variant(step) + + result = loop(smc_kernel, init_particles, initial_parameters) + self.assert_linear_regression_test_case(result.sampler_state) + assert set(result.parameter_override.keys()) == { + "step_size", + "num_integration_steps", + "inverse_mass_matrix", + } + assert result.parameter_override["step_size"].shape == (num_particles,) + assert result.parameter_override["num_integration_steps"].shape == ( + num_particles, + ) + assert all(result.parameter_override["step_size"] > 0) + assert all(result.parameter_override["num_integration_steps"] > 0) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index b0e86e0b0..769078c8d 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -79,7 +79,7 @@ def test_smc_waste_free(self): {}, ) same_for_all_params = dict( - step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50 + step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=100 ) hmc_kernel = functools.partial( blackjax.hmc.build_kernel(), **same_for_all_params