diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py index 9e773e9b6..13764d544 100644 --- a/blackjax/smc/adaptive_tempered.py +++ b/blackjax/smc/adaptive_tempered.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools from typing import Callable import jax @@ -21,37 +22,14 @@ import blackjax.smc.solver as solver import blackjax.smc.tempered as tempered from blackjax.base import SamplingAlgorithm +from blackjax.smc import from_mcmc as smc_from_mcmc from blackjax.types import ArrayLikeTree, PRNGKey __all__ = ["build_kernel", "init", "as_top_level_api"] -def build_kernel( - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - resampling_fn: Callable, - target_ess: float, - root_solver: Callable = solver.dichotomy, - **extra_parameters, -) -> Callable: - r"""Build a Tempered SMC step using an adaptive schedule. - - Parameters - ---------- - logprior_fn: Callable - A function that computes the log-prior density. - loglikelihood_fn: Callable - A function that returns the log-likelihood density. - mcmc_kernel_factory: Callable - A callable function that creates a mcmc kernel from a log-probability - density function. - make_mcmc_state: Callable - A function that creates a new mcmc state from a position and a - log-probability density function. - resampling_fn: Callable - A random function that resamples generated particles based of weights +def build_kernel(loglikelihood_fn, target_ess, root_solver, tempered_kernel): + """ target_ess: float The target ESS for the adaptive MCMC tempering root_solver: Callable, optional @@ -60,13 +38,6 @@ def build_kernel( use_log_ess: bool, optional Use ESS in log space to solve for delta, default is `True`. This is usually more stable when using gradient based solvers. 
- - Returns - ------- - A callable that takes a rng_key and a TemperedSMCState that contains the current state - of the chain and that returns a new state of the chain along with - information about the transition. - """ def compute_delta(state: tempered.TemperedSMCState) -> float: @@ -83,24 +54,13 @@ def compute_delta(state: tempered.TemperedSMCState) -> float: return delta - tempered_kernel = tempered.build_kernel( - logprior_fn, - loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, - **extra_parameters, - ) - def kernel( rng_key: PRNGKey, state: tempered.TemperedSMCState, - num_mcmc_steps: int, - mcmc_parameters: dict, ) -> tuple[tempered.TemperedSMCState, base.SMCInfo]: delta = compute_delta(state) lmbda = delta + state.lmbda - return tempered_kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters) + return tempered_kernel(rng_key, state, lmbda) return kernel @@ -118,6 +78,8 @@ def as_top_level_api( target_ess: float, root_solver: Callable = solver.dichotomy, num_mcmc_steps: int = 10, + mcmc_run_strategy=None, + mutation_step=None, **extra_parameters, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. @@ -148,29 +110,48 @@ def as_top_level_api( Returns ------- A ``SamplingAlgorithm``. 
- """ - kernel = build_kernel( + if num_mcmc_steps is not None: + mcmc_run_strategy = functools.partial( + base.update_and_take_last, num_mcmc_steps=num_mcmc_steps + ) + mutation_step = smc_from_mcmc.build_kernel( + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + mcmc_parameters, + mcmc_run_strategy, + ) + + elif mcmc_run_strategy is not None: + mutation_step = smc_from_mcmc.build_kernel( + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + mcmc_parameters, + mcmc_run_strategy, + ) + + elif mutation_step is not None: + mutation_step = mutation_step + else: + raise ValueError( + "You must either supply num_mcmc_steps, or mcmc_run_strategy or mutation_step" + ) + + tempered_kernel = tempered.build_kernel( logprior_fn, loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, - target_ess, - root_solver, - **extra_parameters, + mutation_step, ) + kernel = build_kernel(loglikelihood_fn, target_ess, root_solver, tempered_kernel) + def init_fn(position: ArrayLikeTree, rng_key=None): del rng_key return init(position) def step_fn(rng_key: PRNGKey, state): - return kernel( - rng_key, - state, - num_mcmc_steps, - mcmc_parameters, - ) + return kernel(rng_key, state) return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/smc/builder_api.py b/blackjax/smc/builder_api.py new file mode 100644 index 000000000..e5be0c1a6 --- /dev/null +++ b/blackjax/smc/builder_api.py @@ -0,0 +1,305 @@ +import functools + +import blackjax.smc.adaptive_tempered +from blackjax import SamplingAlgorithm, inner_kernel_tuning +from blackjax.smc import ( + adaptive_tempered, + from_mcmc, + partial_posteriors_path, + pretuning, + resampling, + solver, + tempered, +) +from blackjax.smc.base import update_and_take_last +from blackjax.smc.from_mcmc import build_kernel as smc_from_mcmc +from blackjax.smc.pretuning import build_kernel +from blackjax.smc.waste_free import waste_free_smc + + +class SMCSamplerBuilder: + """ + SMC is a meta-algorithm in the sense that can be constructed in + 
different ways by composing inner components. The aim of this API + is to foster modifying such compositions easily. + """ + + def __init__(self, resampling_fn=resampling.systematic): + self.step_structure = None + self.update_strategy = None + self.mcmc_parameter_update_fn = None + self.pretune_fn = None + self.resampling_fn = resampling_fn + + # Different ways of building the sequence of distributions + def adaptive_tempering_sequence( + self, target_ess, logprior_fn, loglikelihood_fn, root_solver=solver.dichotomy + ): + self.step_structure = "adaptive_tempering" + self.step_structure_algorithm = blackjax.smc.adaptive_tempered.build_kernel + self.step_structure_init = blackjax.smc.tempered.init + self.logprior_fn = logprior_fn + self.loglikelihood_fn = loglikelihood_fn + self.root_solver = root_solver + self.target_ess = target_ess + return self + + def tempering_from_sequence(self, logprior_fn, loglikelihood_fn): + self.step_structure = "tempering" + self.step_structure_algorithm = blackjax.smc.tempered.build_kernel + self.step_structure_init = blackjax.smc.tempered.init + self.logprior_fn = logprior_fn + self.loglikelihood_fn = loglikelihood_fn + return self + + def partial_posteriors_sequence(self, partial_logposterior_factory): + self.step_structure = "partial_posteriors" + self.step_structure_algorithm = ( + blackjax.smc.partial_posteriors_path.build_kernel + ) + self.partial_logposterior_factory = partial_logposterior_factory + return self + + # Inner kernel construction + def inner_kernel(self, init, step, inner_kernel_params): + self._inner_kernel_step = step + self._inner_kernel_init = init + self._inner_kernel_params = inner_kernel_params + return self + + # Ways of updating the particles + def mutate_waste_free(self, n_particles, p): + if self.update_strategy is not None: + raise ValueError("Can't use two update strategies at the same time") + self.update_strategy = waste_free_smc(n_particles, p) + return self + + def mutate_and_take_last(self, 
mcmc_steps): + if self.update_strategy is not None: + raise ValueError("Can't use two update strategies at the same time") + self.update_strategy = functools.partial( + update_and_take_last, num_mcmc_steps=mcmc_steps + ) + return self + + # Ways of tuning or pre-tuning the inner kernel parameters + def with_inner_kernel_tuning(self, mcmc_parameter_update_fn): + if self.mcmc_parameter_update_fn is not None: + raise ValueError( + "Can't call inner_kernel_tuning twice, consider merging all calls into one" + ) + + self.mcmc_parameter_update_fn = mcmc_parameter_update_fn + return self + + def with_pretuning(self, pretune_fn): + if self.pretune_fn is not None: + raise ValueError( + "Can't call pretune twice, consider merging all calls into one" + ) + self.pretune_fn = pretune_fn + return self + + def build(self): + if self.update_strategy is None: + raise ValueError( + "You must choose an update strategy, either waste_free() or mutate_and_take_last()" + ) + + if self.step_structure is None: + raise ValueError( + "You must either call adaptive_tempering(), " + "tempering_sequence()" + " or partial_posteriors_path()" + ) + + if self.step_structure == "adaptive_tempering": + init, step = self._adaptive_tempered_from_parameters() + elif self.step_structure == "tempering": + init, step = self._tempered_from_parameters() + elif self.step_structure == "partial_posteriors": + init, step = self._partial_posterior_from_parameters() + else: + raise NotImplementedError( + "The SMCBuilder API supports three ways of structuring SMC" + "steps: adaptive tempering, fixed-sequence tempering or " + "partial posteriors (data tempering). 
" + ) + + if self.mcmc_parameter_update_fn is None and self.pretune_fn is None: + # no tuning or pretuning is used + return SamplingAlgorithm(init, step(self._inner_kernel_params)) + + if self.mcmc_parameter_update_fn is not None and self.pretune_fn is None: + # only tuning + def new_init(position): + return inner_kernel_tuning.init( + init, position, self._inner_kernel_params + ) + + return SamplingAlgorithm( + new_init, + inner_kernel_tuning.build_kernel(step, self.mcmc_parameter_update_fn), + ) + if self.mcmc_parameter_update_fn is None and self.pretune_fn is not None: + # only pretune + return self._build_pretuning() + + # Both Pretune and Tune + raise NotImplementedError("Tuning and pretuning used together hasn't been implemented yet") + + def _adaptive_tempered_from_parameters(self): + def from_parameteres(inner_kernel_params): + mutation_step = from_mcmc.build_kernel( + self._inner_kernel_step, + self._inner_kernel_init, + self.resampling_fn, + inner_kernel_params, + self.update_strategy, + ) + + tempered_kernel = tempered.build_kernel( + self.logprior_fn, self.loglikelihood_fn, mutation_step + ) + + step = adaptive_tempered.build_kernel( + self.loglikelihood_fn, + self.target_ess, + self.root_solver, + tempered_kernel, + ) + return step + + init = tempered.init + return init, from_parameteres + + def _tempered_from_parameters(self): + def from_parameteres(inner_kernel_params): + mutation_step = from_mcmc.build_kernel( + self._inner_kernel_step, + self._inner_kernel_init, + self.resampling_fn, + inner_kernel_params, + self.update_strategy, + ) + + tempered_kernel = tempered.build_kernel( + self.logprior_fn, self.loglikelihood_fn, mutation_step + ) + return tempered_kernel + + init = tempered.init + return init, from_parameteres + + def _partial_posterior_from_parameters(self): + def from_parameters(params): + update_particles = from_mcmc.build_kernel( + self._inner_kernel_step, + self._inner_kernel_init, + self.resampling_fn, + params, + 
self.update_strategy, + ) + return partial_posteriors_path.build_kernel( + self.partial_logposterior_factory, update_particles + ) + + return (partial_posteriors_path.init, from_parameters) + + def _build_pretuning(self): + def delegate(rng_key, state, logposterior_fn, log_weights_fn, mcmc_parameteres): + return smc_from_mcmc( + self._inner_kernel_step, + self._inner_kernel_init, + self.resampling_fn, + mcmc_parameteres, + self.update_strategy, + )( + rng_key, + state, + logposterior_fn, + log_weights_fn, + ) + + if self.step_structure == "adaptive_tempering": + + def smc_algorithm_from_params(mcmc_parameters, pretuned_step): + tempered_kernel = blackjax.smc.tempered.build_kernel( + logprior_fn=self.logprior_fn, + loglikelihood_fn=self.loglikelihood_fn, + update_particles=functools.partial( + pretuned_step, mcmc_parameters=mcmc_parameters + ), + ) + + return blackjax.smc.adaptive_tempered.build_kernel( + loglikelihood_fn=self.loglikelihood_fn, + target_ess=self.target_ess, + root_solver=self.root_solver, + tempered_kernel=tempered_kernel, + ) + + elif self.step_structure == "tempering": + + def smc_algorithm_from_params(mcmc_parameters, pretuned_step): + return blackjax.smc.tempered.build_kernel( + logprior_fn=self.logprior_fn, + loglikelihood_fn=self.loglikelihood_fn, + update_particles=functools.partial( + pretuned_step, mcmc_parameters=mcmc_parameters + ), + ) + + kernel = build_kernel(smc_algorithm_from_params, self.pretune_fn, delegate) + + def init_fn(position, rng_key=None): + del rng_key + return pretuning.init( + blackjax.smc.tempered.init, position, self._inner_kernel_params + ) + + return SamplingAlgorithm(init_fn, kernel) + + def _tune_and_pretune(self): + def pt( + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + mcmc_parameters, + resampling_fn, + num_mcmc_steps, + initial_parameter_value, + target_ess, + ): + return blackjax.pretuning( + blackjax.adaptive_tempered_smc, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + 
mcmc_init_fn, + resampling_fn, + num_mcmc_steps, + target_ess=self.target_ess, + pretune_fn=self.pretune_fn, + ) + + kernel = blackjax.smc.inner_kernel_tuning.build_kernel( + pt, + self.logprior_fn, + self.loglikelihood_fn, + self._inner_kernel_step, + self._inner_kernel_init, + self.resampling_fn, + self.mcmc_parameter_update_fn, + initial_parameter_value=self._inner_kernel_params, + target_ess=self.target_ess, + smc_returns_state_with_parameter_override=True, + ) + + def init2(position): + return blackjax.smc.inner_kernel_tuning.init( + blackjax.adaptive_tempered_smc.init, position, self._inner_kernel_params + ) + + return SamplingAlgorithm(init2, kernel) diff --git a/blackjax/smc/from_mcmc.py b/blackjax/smc/from_mcmc.py index 75e5c34a6..53798e23e 100644 --- a/blackjax/smc/from_mcmc.py +++ b/blackjax/smc/from_mcmc.py @@ -29,7 +29,8 @@ def build_kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, - update_strategy: Callable = update_and_take_last, + mcmc_parameters, + particle_mutation_fn: Callable = update_and_take_last, ): """SMC step from MCMC kernels. Builds MCMC kernels from the input parameters, which may change across iterations. 
@@ -46,8 +47,6 @@ def build_kernel( def step( rng_key: PRNGKey, state, - num_mcmc_steps: int, - mcmc_parameters: dict, logposterior_fn: Callable, log_weights_fn: Callable, ) -> tuple[smc.base.SMCState, smc.base.SMCInfo]: @@ -55,12 +54,11 @@ def step( mcmc_parameters, mcmc_step_fn ) - update_fn, num_resampled = update_strategy( + update_fn, num_resampled = particle_mutation_fn( mcmc_init_fn, logposterior_fn, shared_mcmc_step_fn, n_particles=state.weights.shape[0], - num_mcmc_steps=num_mcmc_steps, ) return smc.base.step( diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 2a63fd1ce..4bfebef91 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -22,15 +22,8 @@ def init(alg_init_fn, position, initial_parameter_value): def build_kernel( - smc_algorithm, - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - resampling_fn: Callable, + smc_step_from_mcmc_parameters: Callable, mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], - num_mcmc_steps: int = 10, - **extra_parameters, ) -> Callable: """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC, @@ -38,37 +31,17 @@ def build_kernel( Parameters ---------- - smc_algorithm - Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of - a sampling algorithm that returns an SMCState and SMCInfo pair). - logprior_fn - A function that computes the log density of the prior distribution - loglikelihood_fn - A function that returns the probability at a given position. - mcmc_step_fn: - The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. 
- mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()) - mcmc_init_fn - A callable that initializes the inner kernel + smc_step_from_mcmc_parameters + A Callable that can return either blackjax.adaptive_tempered_smc.step or blackjax.tempered_smc.step (or any other implementation of + a sampling algorithm step that returns an SMCState and SMCInfo pair), out of a dictionary of parameters for MCMC inner chains. mcmc_parameter_update_fn A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. - extra_parameters: - parameters to be used for the creation of the smc_algorithm. """ def kernel( rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters ) -> Tuple[StateWithParameterOverride, SMCInfo]: - step_fn = smc_algorithm( - logprior_fn=logprior_fn, - loglikelihood_fn=loglikelihood_fn, - mcmc_step_fn=mcmc_step_fn, - mcmc_init_fn=mcmc_init_fn, - mcmc_parameters=state.parameter_override, - resampling_fn=resampling_fn, - num_mcmc_steps=num_mcmc_steps, - **extra_parameters, - ).step + step_fn = smc_step_from_mcmc_parameters(state.parameter_override) new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters) new_parameter_override = mcmc_parameter_update_fn(new_state, info) return StateWithParameterOverride(new_state, new_parameter_override), info @@ -121,16 +94,21 @@ def as_top_level_api( """ + def smc_step_from_mcmc_parameters(parameters): + return smc_algorithm( + logprior_fn=logprior_fn, + loglikelihood_fn=loglikelihood_fn, + mcmc_step_fn=mcmc_step_fn, + mcmc_init_fn=mcmc_init_fn, + mcmc_parameters=parameters, + resampling_fn=resampling_fn, + num_mcmc_steps=num_mcmc_steps, + **extra_parameters, + ).step + kernel = build_kernel( - smc_algorithm, - logprior_fn, - loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, + smc_step_from_mcmc_parameters, mcmc_parameter_update_fn, - num_mcmc_steps, - **extra_parameters, ) 
def init_fn(position, rng_key=None): diff --git a/blackjax/smc/partial_posteriors_path.py b/blackjax/smc/partial_posteriors_path.py index 81f19716d..0281b7d26 100644 --- a/blackjax/smc/partial_posteriors_path.py +++ b/blackjax/smc/partial_posteriors_path.py @@ -1,3 +1,4 @@ +import functools from typing import Callable, NamedTuple, Optional, Tuple import jax @@ -37,13 +38,8 @@ def init(particles: ArrayLikeTree, num_datapoints: int) -> PartialPosteriorsSMCS def build_kernel( - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - resampling_fn: Callable, - num_mcmc_steps: Optional[int], - mcmc_parameters: ArrayTree, partial_logposterior_factory: Callable[[Array], Callable], - update_strategy=update_and_take_last, + update_particles: Callable, ) -> Callable: """Build the Partial Posteriors (data tempering) SMC kernel. The distribution's trajectory includes increasingly adding more @@ -70,7 +66,6 @@ def build_kernel( A callable that takes a rng_key and PartialPosteriorsSMCState and selectors for the current and previous posteriors, and takes a data-tempered SMC state. """ - delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy) def step( key, state: PartialPosteriorsSMCState, data_mask: Array @@ -82,9 +77,7 @@ def step( def log_weights_fn(x): return logposterior_fn(x) - previous_logposterior_fn(x) - state, info = delegate( - key, state, num_mcmc_steps, mcmc_parameters, logposterior_fn, log_weights_fn - ) + state, info = update_particles(key, state, logposterior_fn, log_weights_fn) return ( PartialPosteriorsSMCState(state.particles, state.weights, data_mask), @@ -101,21 +94,20 @@ def as_top_level_api( resampling_fn: Callable, num_mcmc_steps, partial_logposterior_factory: Callable, - update_strategy=update_and_take_last, + mcmc_run_strategy=update_and_take_last, ) -> SamplingAlgorithm: """A factory that wraps the kernel into a SamplingAlgorithm object. See build_kernel for full documentation on the parameters. 
""" + if num_mcmc_steps is not None: + mcmc_run_strategy = functools.partial( + update_and_take_last, num_mcmc_steps=num_mcmc_steps + ) - kernel = build_kernel( - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, - num_mcmc_steps, - mcmc_parameters, - partial_logposterior_factory, - update_strategy, + update_particles = smc_from_mcmc( + mcmc_step_fn, mcmc_init_fn, resampling_fn, mcmc_parameters, mcmc_run_strategy ) + kernel = build_kernel(partial_logposterior_factory, update_particles) def init_fn(position: ArrayLikeTree, num_observations, rng_key=None): del rng_key diff --git a/blackjax/smc/pretuning.py b/blackjax/smc/pretuning.py index f489a0dc2..2b4193fa5 100644 --- a/blackjax/smc/pretuning.py +++ b/blackjax/smc/pretuning.py @@ -1,3 +1,4 @@ +import functools from typing import Callable, Dict, List, NamedTuple, Optional, Tuple import jax @@ -194,16 +195,9 @@ def pretune_and_update(key, state: StateWithParameterOverride, logposterior): def build_kernel( - smc_algorithm, - logprior_fn: Callable, - loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - resampling_fn: Callable, + smc_algorithm_from_params, pretune_fn: Callable, - num_mcmc_steps: int = 10, - update_strategy=update_and_take_last, - **extra_parameters, + delegate, ) -> Callable: """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner MCMC that is used to perturbate/update each of the particles. This adaptation tunes some parameter of that MCMC, @@ -228,15 +222,13 @@ def build_kernel( extra_parameters: parameters to be used for the creation of the smc_algorithm. 
""" - delegate = smc_from_mcmc(mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy) def pretuned_step( rng_key: PRNGKey, state, - num_mcmc_steps: int, - mcmc_parameters: dict, logposterior_fn: Callable, log_weights_fn: Callable, + mcmc_parameters, ) -> tuple[smc.base.SMCState, SMCInfoWithParameterDistribution]: """Wraps the output of smc.from_mcmc.build_kernel into a pretuning + step method. This one should be a subtype of the former, in the sense that a usage of the former @@ -249,30 +241,16 @@ def pretuned_step( StateWithParameterOverride(state, mcmc_parameters), logposterior_fn, ) + state, info = delegate( - rng_key, - state, - num_mcmc_steps, - pretuned_parameters, - logposterior_fn, - log_weights_fn, + rng_key, state, logposterior_fn, log_weights_fn, pretuned_parameters ) return state, SMCInfoWithParameterDistribution(info, pretuned_parameters) def kernel( rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters ) -> Tuple[StateWithParameterOverride, SMCInfo]: - extra_parameters["update_particles_fn"] = pretuned_step - step_fn = smc_algorithm( - logprior_fn=logprior_fn, - loglikelihood_fn=loglikelihood_fn, - mcmc_step_fn=mcmc_step_fn, - mcmc_init_fn=mcmc_init_fn, - mcmc_parameters=state.parameter_override, - resampling_fn=resampling_fn, - num_mcmc_steps=num_mcmc_steps, - **extra_parameters, - ).step + step_fn = smc_algorithm_from_params(state.parameter_override, pretuned_step) new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters) return ( StateWithParameterOverride(new_state, info.parameter_override), @@ -293,9 +271,10 @@ def as_top_level_api( mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, - num_mcmc_steps: int, + num_mcmc_steps: Optional[int], initial_parameter_value: ArrayLikeTree, pretune_fn: Callable, + mcmc_run_strategy: Callable = None, **extra_parameters, ): """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner @@ 
-321,18 +300,39 @@ extra_parameters: parameters to be used for the creation of the smc_algorithm. """ + if num_mcmc_steps is not None: + us = functools.partial(update_and_take_last, num_mcmc_steps=num_mcmc_steps) + elif mcmc_run_strategy is not None: + us = mcmc_run_strategy + else: + raise ValueError("You must supply either num_mcmc_steps or mcmc_run_strategy") - kernel = build_kernel( - smc_algorithm, - logprior_fn, - loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, - pretune_fn, - num_mcmc_steps, - **extra_parameters, - ) + def delegate(rng_key, state, logposterior_fn, log_weights_fn, mcmc_parameters): + return smc_from_mcmc( + mcmc_step_fn, mcmc_init_fn, resampling_fn, mcmc_parameters, us + )( + rng_key, + state, + logposterior_fn, + log_weights_fn, + ) + + def smc_algorithm_from_params(mcmc_parameters, pretuned_step): + return smc_algorithm( + logprior_fn=logprior_fn, + loglikelihood_fn=loglikelihood_fn, + mcmc_step_fn=mcmc_step_fn, + mcmc_init_fn=mcmc_init_fn, + mcmc_parameters=mcmc_parameters, + resampling_fn=resampling_fn, + num_mcmc_steps=None, + mutation_step=functools.partial( + pretuned_step, mcmc_parameters=mcmc_parameters + ), + **extra_parameters, + ).step + + kernel = build_kernel(smc_algorithm_from_params, pretune_fn, delegate) def init_fn(position, rng_key=None): del rng_key diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 350037f9c..7ba1f8b28 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import functools from typing import Callable, NamedTuple, Optional import jax @@ -51,11 +52,7 @@ def init(particles: ArrayLikeTree): def build_kernel( logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_step_fn: Callable, - mcmc_init_fn: Callable, - resampling_fn: Callable, - update_strategy: Callable = update_and_take_last, - update_particles_fn: Optional[Callable] = None, + update_particles: Callable, ) -> Callable: """Build the base Tempered SMC kernel. @@ -93,20 +90,11 @@ def build_kernel( information about the transition. """ - update_particles = ( - smc_from_mcmc.build_kernel( - mcmc_step_fn, mcmc_init_fn, resampling_fn, update_strategy - ) - if update_particles_fn is None - else update_particles_fn - ) def kernel( rng_key: PRNGKey, state: TemperedSMCState, - num_mcmc_steps: int, lmbda: float, - mcmc_parameters: dict, ) -> tuple[TemperedSMCState, smc.base.SMCInfo]: """Move the particles one step using the Tempered SMC algorithm. @@ -143,8 +131,6 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: smc_state, info = update_particles( rng_key, state, - num_mcmc_steps, - mcmc_parameters, tempered_logposterior_fn, log_weights_fn, ) @@ -166,8 +152,8 @@ def as_top_level_api( mcmc_parameters: dict, resampling_fn: Callable, num_mcmc_steps: Optional[int] = 10, - update_strategy=update_and_take_last, - update_particles_fn=None, + mcmc_run_strategy=None, + mutation_step=None, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the Adaptive Tempered SMC kernel. 
@@ -195,15 +181,35 @@ def as_top_level_api( """ - kernel = build_kernel( - logprior_fn, - loglikelihood_fn, - mcmc_step_fn, - mcmc_init_fn, - resampling_fn, - update_strategy, - update_particles_fn, - ) + if num_mcmc_steps is not None: + mcmc_run_strategy = functools.partial( + update_and_take_last, num_mcmc_steps=num_mcmc_steps + ) + mutation_step = smc_from_mcmc.build_kernel( + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + mcmc_parameters, + mcmc_run_strategy, + ) + + elif mcmc_run_strategy is not None: + mutation_step = smc_from_mcmc.build_kernel( + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + mcmc_parameters, + mcmc_run_strategy, + ) + + elif mutation_step is not None: + mutation_step = mutation_step + else: + raise ValueError( + "You must either supply num_mcmc_steps, or mcmc_run_strategy or mutation_step" + ) + + kernel = build_kernel(logprior_fn, loglikelihood_fn, mutation_step) def init_fn(position: ArrayLikeTree, rng_key=None): del rng_key @@ -213,9 +219,7 @@ def step_fn(rng_key: PRNGKey, state, lmbda): return kernel( rng_key, state, - num_mcmc_steps, lmbda, - mcmc_parameters, ) return SamplingAlgorithm(init_fn, step_fn) # type: ignore[arg-type] diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 7d6190af5..bb5b091a1 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -14,6 +14,7 @@ from blackjax import adaptive_tempered_smc, tempered_smc from blackjax.mcmc.random_walk import build_irmh from blackjax.smc import extend_params +from blackjax.smc.builder_api import SMCSamplerBuilder from blackjax.smc.inner_kernel_tuning import as_top_level_api as inner_kernel_tuning from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate from blackjax.smc.tuning.from_particles import ( @@ -71,20 +72,106 @@ def logdensity_fn(self, log_scale, coefs, preds, x): logpdf = stats.norm.logpdf(preds, y, scale) return jnp.sum(logpdf) + def 
top_level_api_sampler_provider( + self, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + smc_algorithm, + mcmc_parameter_update_fn, + initial_parameter_value, + smc_parameters, + ): + return inner_kernel_tuning( + logprior_fn=logprior_fn, + loglikelihood_fn=loglikelihood_fn, + mcmc_step_fn=mcmc_step_fn, + mcmc_init_fn=mcmc_init_fn, + resampling_fn=resampling_fn, + smc_algorithm=smc_algorithm, + mcmc_parameter_update_fn=mcmc_parameter_update_fn, + initial_parameter_value=initial_parameter_value, + **smc_parameters, + ) + + def builder_api_adaptive_provider( + self, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + smc_algorithm, + mcmc_parameter_update_fn, + initial_parameter_value, + smc_parameters, + ): + return ( + SMCSamplerBuilder(resampling_fn) + .adaptive_tempering_sequence(0.5, logprior_fn, loglikelihood_fn) + .inner_kernel(mcmc_init_fn, mcmc_step_fn, initial_parameter_value) + .with_inner_kernel_tuning(mcmc_parameter_update_fn) + .mutate_and_take_last(10) + .build() + ) + + def builder_api_tempered_provider( + self, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + smc_algorithm, + mcmc_parameter_update_fn, + initial_parameter_value, + smc_parameters, + ): + return ( + SMCSamplerBuilder(resampling_fn) + .tempering_from_sequence(logprior_fn, loglikelihood_fn) + .inner_kernel(mcmc_init_fn, mcmc_step_fn, initial_parameter_value) + .with_inner_kernel_tuning(mcmc_parameter_update_fn) + .mutate_and_take_last(10) + .build() + ) + + def test_smc_inner_kernel_adaptive_tempered_builder_api(self): + self.smc_inner_kernel_tuning_test_case( + blackjax.adaptive_tempered_smc, + smc_parameters={"target_ess": 0.5}, + step_parameters={}, + sampler_provider=self.builder_api_adaptive_provider, + ) + + def test_smc_inner_kernel_tempered_builder_api(self): + self.smc_inner_kernel_tuning_test_case( + blackjax.tempered_smc, + smc_parameters={}, + step_parameters={"lmbda": 
0.75}, + sampler_provider=self.builder_api_tempered_provider, + ) + def test_smc_inner_kernel_adaptive_tempered(self): self.smc_inner_kernel_tuning_test_case( blackjax.adaptive_tempered_smc, smc_parameters={"target_ess": 0.5}, step_parameters={}, + sampler_provider=self.top_level_api_sampler_provider, ) def test_smc_inner_kernel_tempered(self): self.smc_inner_kernel_tuning_test_case( - blackjax.tempered_smc, smc_parameters={}, step_parameters={"lmbda": 0.75} + blackjax.tempered_smc, + smc_parameters={}, + step_parameters={"lmbda": 0.75}, + sampler_provider=self.top_level_api_sampler_provider, ) def smc_inner_kernel_tuning_test_case( - self, smc_algorithm, smc_parameters, step_parameters + self, smc_algorithm, smc_parameters, step_parameters, sampler_provider ): specialized_log_weights_fn = lambda tree: log_weights_fn(tree, 1.0) # Don't use exactly the invariant distribution for the MCMC kernel @@ -106,7 +193,7 @@ def wrapped_kernel(rng_key, state, logdensity, mean): functools.partial(irmh_proposal_distribution, mean=mean), ) - kernel = inner_kernel_tuning( + kernel = sampler_provider( logprior_fn=prior, loglikelihood_fn=specialized_log_weights_fn, mcmc_step_fn=wrapped_kernel, @@ -115,7 +202,7 @@ def wrapped_kernel(rng_key, state, logdensity, mean): smc_algorithm=smc_algorithm, mcmc_parameter_update_fn=mcmc_parameter_update_fn, initial_parameter_value=extend_params({"mean": 1.0}), - **smc_parameters, + smc_parameters=smc_parameters, ) new_state, new_info = kernel.step( diff --git a/tests/smc/test_partial_posteriors_smc.py b/tests/smc/test_partial_posteriors_smc.py index 78d57a934..4cf5d83d0 100644 --- a/tests/smc/test_partial_posteriors_smc.py +++ b/tests/smc/test_partial_posteriors_smc.py @@ -7,6 +7,7 @@ import blackjax import blackjax.smc.resampling as resampling from blackjax.smc import extend_params +from blackjax.smc.builder_api import SMCSamplerBuilder from tests.smc import SMCLinearRegressionTestCase @@ -18,7 +19,37 @@ def setUp(self): self.key = 
jax.random.key(42) @chex.variants(with_jit=True) - def test_partial_posteriors(self): + def test_partial_posteriors_top_level_api(self): + def sampler_provider( + kernel, init, parameters, steps, partial_logposterior_factory + ): + return blackjax.partial_posteriors_smc( + kernel, + init, + parameters, + resampling.systematic, + steps, + partial_logposterior_factory, + ) + + self.partial_posteriors_test_case(sampler_provider) + + @chex.variants(with_jit=True) + def test_partial_posteriors_builder_api(self): + def sampler_provider( + kernel, init, parameters, steps, partial_logposterior_factory + ): + return ( + SMCSamplerBuilder() + .partial_posteriors_sequence(partial_logposterior_factory) + .inner_kernel(init, kernel, parameters) + .mutate_and_take_last(steps) + .build() + ) + + self.partial_posteriors_test_case(sampler_provider) + + def partial_posteriors_test_case(self, sampler_provider): ( init_particles, logprior_fn, @@ -48,13 +79,12 @@ def partial_logposterior(x): return jax.jit(partial_logposterior) - init, kernel = blackjax.partial_posteriors_smc( + init, kernel = sampler_provider( hmc_kernel, hmc_init, hmc_parameters, - resampling.systematic, 50, - partial_logposterior_factory=partial_logposterior_factory, + partial_logposterior_factory, ) init_state = init(init_particles, 1000) diff --git a/tests/smc/test_pretuning.py b/tests/smc/test_pretuning.py index a677c99ae..d197af760 100644 --- a/tests/smc/test_pretuning.py +++ b/tests/smc/test_pretuning.py @@ -8,6 +8,8 @@ import blackjax from blackjax.smc import extend_params, resampling +from blackjax.smc.base import update_and_take_last +from blackjax.smc.builder_api import SMCSamplerBuilder from blackjax.smc.pretuning import ( build_pretune, esjd, @@ -145,13 +147,128 @@ def test_update_multi_sigmas(self): ) +def tuned_adaptive_tempered_inference_loop(kernel, rng_key, initial_state): + def cond(carry): + _, state, *_ = carry + return state.sampler_state.lmbda < 1 + + def body(carry): + i, state, 
curr_loglikelihood = carry + subkey = jax.random.fold_in(rng_key, i) + state, info = kernel(subkey, state) + return i + 1, state, curr_loglikelihood + info.log_likelihood_increment + + total_iter, final_state, log_likelihood = jax.lax.while_loop( + cond, body, (0, initial_state, 0.0) + ) + return final_state + + class PretuningSMCTest(SMCLinearRegressionTestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) + def loop(self, init, smc_kernel, init_particles): + initial_state = init(init_particles) + + def body_fn(carry, lmbda): + i, state = carry + subkey = jax.random.fold_in(self.key, i) + new_state, info = smc_kernel(subkey, state, lmbda=lmbda) + return (i + 1, new_state), (new_state, info) + + num_tempering_steps = 10 + lambda_schedule = np.logspace(-5, 0, num_tempering_steps) + + (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule) + return result + + @chex.variants(with_jit=True) + def test_tempered_builder_api(self): + step_provider = ( + lambda logprior_fn, loglikelihood_fn, pretune, initial_parameter_value: ( + SMCSamplerBuilder(resampling.systematic) + .tempering_from_sequence(logprior_fn, loglikelihood_fn) + .inner_kernel( + blackjax.hmc.init, + blackjax.hmc.build_kernel(), + initial_parameter_value, + ) + .mutate_and_take_last(10) + .with_pretuning(pretune) + .build() + ) + ) + + self.linear_regression_test_case(step_provider, self.loop) + + @chex.variants(with_jit=True) + def test_tempered_top_level_api(self): + def step_provider( + logprior_fn, loglikelihood_fn, pretune, initial_parameter_value + ): + return blackjax.pretuning( + blackjax.tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + initial_parameter_value=initial_parameter_value, + pretune_fn=pretune, + ) + + self.linear_regression_test_case(step_provider, self.loop) + + @chex.variants(with_jit=True) + def test_adaptive_tempered_builder_api(self): + 
step_provider = ( + lambda logprior_fn, loglikelihood_fn, pretune, initial_parameters: ( + SMCSamplerBuilder() + .adaptive_tempering_sequence(0.5, logprior_fn, loglikelihood_fn) + .inner_kernel( + blackjax.hmc.init, blackjax.hmc.build_kernel(), initial_parameters + ) + .with_pretuning(pretune) + .mutate_and_take_last(10) + .build() + ) + ) + + def loop(init, smc_kernel, init_particles): + initial_state = init(init_particles) + return tuned_adaptive_tempered_inference_loop( + smc_kernel, self.key, initial_state + ) + + self.linear_regression_test_case(step_provider, loop) + @chex.variants(with_jit=True) - def test_linear_regression(self): + def test_adaptive_tempered_top_level_api(self): + step_provider = lambda logprior_fn, loglikelihood_fn, pretune, initial_parameters: blackjax.pretuning( + blackjax.adaptive_tempered_smc, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + num_mcmc_steps=10, + pretune_fn=pretune, + initial_parameter_value=initial_parameters, + target_ess=0.5, + ) + + def loop(init, smc_kernel, init_particles): + initial_state = init(init_particles) + return tuned_adaptive_tempered_inference_loop( + smc_kernel, self.key, initial_state + ) + + self.linear_regression_test_case(step_provider, loop) + + def linear_regression_test_case(self, step_provider, loop): ( init_particles, logprior_fn, @@ -191,32 +308,13 @@ def test_linear_regression(self): positive_parameters=["step_size"], ) - step = blackjax.smc.pretuning.build_kernel( - blackjax.tempered_smc, - logprior_fn, - loglikelihood_fn, - blackjax.hmc.build_kernel(), - blackjax.hmc.init, - resampling.systematic, - num_mcmc_steps=10, - pretune_fn=pretune, + init, step = step_provider( + logprior_fn, loglikelihood_fn, pretune, initial_parameters ) - initial_state = init( - blackjax.tempered_smc.init, init_particles, initial_parameters - ) smc_kernel = self.variant(step) - def body_fn(carry, lmbda): - i, state = carry - subkey = 
jax.random.fold_in(self.key, i) - new_state, info = smc_kernel(subkey, state, lmbda=lmbda) - return (i + 1, new_state), (new_state, info) - - num_tempering_steps = 10 - lambda_schedule = np.logspace(-5, 0, num_tempering_steps) - - (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule) + result = loop(init, smc_kernel, init_particles) self.assert_linear_regression_test_case(result.sampler_state) assert set(result.parameter_override.keys()) == { "step_size", diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index 527457d62..8a72c8c53 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -13,6 +13,7 @@ import blackjax.smc.solver as solver from blackjax import adaptive_tempered_smc, tempered_smc from blackjax.smc import extend_params +from blackjax.smc.builder_api import SMCSamplerBuilder from tests.smc import SMCLinearRegressionTestCase @@ -42,7 +43,58 @@ def setUp(self): self.key = jax.random.key(42) @chex.variants(with_jit=True) - def test_adaptive_tempered_smc(self): + def test_adaptive_tempered_smc_builder_api(self): + def sampler_provider( + target_ess, + logprior_fn, + loglikelihood_fn, + hmc_init, + hmc_kernel, + hmc_parameters, + ): + return ( + SMCSamplerBuilder(resampling.systematic) + .adaptive_tempering_sequence( + target_ess, logprior_fn, loglikelihood_fn, solver.dichotomy + ) + .inner_kernel(hmc_init, hmc_kernel, hmc_parameters) + .mutate_and_take_last(5) + .build() + ) + + self.adaptive_tempered_test_case(sampler_provider) + + @chex.variants(with_jit=True) + def test_adaptive_tempered_smc_top_level_api(self): + def sampler_provider( + target_ess, + logprior_fn, + loglikelihood_fn, + hmc_init, + hmc_kernel, + hmc_parameters, + ): + return adaptive_tempered_smc( + logprior_fn, + loglikelihood_fn, + hmc_kernel, + hmc_init, + hmc_parameters, + resampling.systematic, + target_ess, + solver.dichotomy, + 5, + ) + + self.adaptive_tempered_test_case(sampler_provider) + + def 
adaptive_tempered_test_case(self, sampler_provider): + """ + extracting to a function instead of parametrizing + because the parallel testing framework + doesn't allow parametrization on functions + """ + num_particles = 100 x_data = np.random.normal(0, 1, size=(1000, 1)) @@ -87,17 +139,15 @@ def logprior_fn(x): ] for target_ess, hmc_parameters in zip([0.5, 0.5, 0.75], hmc_parameters_list): - tempering = adaptive_tempered_smc( + tempering = sampler_provider( + target_ess, logprior_fn, loglikelihood_fn, - hmc_kernel, hmc_init, + hmc_kernel, hmc_parameters, - resampling.systematic, - target_ess, - solver.dichotomy, - 5, ) + init_state = tempering.init(smc_state_init) n_iter, result, log_likelihood = self.variant( @@ -114,7 +164,31 @@ def logprior_fn(x): assert iterates[1] >= iterates[0] @chex.variants(with_jit=True) - def test_fixed_schedule_tempered_smc(self): + def test_fixed_schedule_tempered_smc_top_level_api(self): + self.fixed_schedule_tempered_smc_test_case(tempered_smc) + + @chex.variants(with_jit=True) + def test_fixed_schedule_tempered_smc_builder_api(self): + def sampler_provider( + logprior_fn, + loglikelihood_fn, + hmc_kernel, + hmc_init, + hmc_parameters, + resampling_fn, + num_mcmc_steps, + ): + return ( + SMCSamplerBuilder(resampling_fn) + .tempering_from_sequence(logprior_fn, loglikelihood_fn) + .inner_kernel(hmc_init, hmc_kernel, hmc_parameters) + .mutate_and_take_last(num_mcmc_steps) + .build() + ) + + self.fixed_schedule_tempered_smc_test_case(sampler_provider) + + def fixed_schedule_tempered_smc_test_case(self, sampler_provider): ( init_particles, logprior_fn, @@ -134,7 +208,7 @@ def test_fixed_schedule_tempered_smc(self): }, ) - tempering = tempered_smc( + tempering = sampler_provider( logprior_fn, loglikelihood_fn, hmc_kernel, diff --git a/tests/smc/test_tune_with_pretune.py b/tests/smc/test_tune_with_pretune.py new file mode 100644 index 000000000..cf088bc59 --- /dev/null +++ b/tests/smc/test_tune_with_pretune.py @@ -0,0 +1,151 @@ +from 
blackjax.smc.builder_api import SMCSamplerBuilder
+from tests.smc import SMCLinearRegressionTestCase
+from tests.smc.test_pretuning import tuned_adaptive_tempered_inference_loop
+import jax
+import jax.numpy as jnp
+import numpy as np
+import chex
+import blackjax
+from absl.testing import absltest
+from blackjax.smc import extend_params, resampling
+from blackjax.smc.pretuning import build_pretune
+
+
+class PretuningWithTuningSMCTest(SMCLinearRegressionTestCase):
+    def setUp(self):
+        super().setUp()
+        self.key = jax.random.key(42)
+
+    def loop(self, init, smc_kernel, init_particles):
+        initial_state = init(init_particles)
+
+        def body_fn(carry, lmbda):
+            i, state = carry
+            subkey = jax.random.fold_in(self.key, i)
+            new_state, info = smc_kernel(subkey, state, lmbda=lmbda)
+            return (i + 1, new_state), (new_state, info)
+
+        num_tempering_steps = 10
+        lambda_schedule = np.logspace(-5, 0, num_tempering_steps)
+
+        (_, result), _ = jax.lax.scan(body_fn, (0, initial_state), lambda_schedule)
+        return result
+
+    @chex.variants(with_jit=True)
+    def test_tempered_top_level_api(self):
+        def step_provider(
+            logprior_fn, loglikelihood_fn, pretune, initial_parameter_value
+        ):
+            return blackjax.pretuning(
+                blackjax.tempered_smc,
+                logprior_fn,
+                loglikelihood_fn,
+                blackjax.hmc.build_kernel(),
+                blackjax.hmc.init,
+                resampling.systematic,
+                num_mcmc_steps=10,
+                initial_parameter_value=initial_parameter_value,
+                pretune_fn=pretune,
+            )
+
+        self.linear_regression_test_case(step_provider, self.loop)
+
+    @chex.variants(with_jit=True)
+    def test_adaptive_tempered_builder_api(self):
+        step_provider = (
+            lambda logprior_fn, loglikelihood_fn, pretune, initial_parameters: (
+                SMCSamplerBuilder()
+                .adaptive_tempering_sequence(0.5, logprior_fn, loglikelihood_fn)
+                .inner_kernel(
+                    blackjax.hmc.init, blackjax.hmc.build_kernel(), initial_parameters
+                )
+                .with_pretuning(pretune)
+                .mutate_and_take_last(10)
+                .build()
+            )
+        )
+
+        def loop(init, smc_kernel, init_particles):
+            initial_state = init(init_particles)
+            return tuned_adaptive_tempered_inference_loop(
+                smc_kernel, self.key, initial_state
+            )
+
+        self.linear_regression_test_case(step_provider, loop)
+
+    @chex.variants(with_jit=True)
+    def test_adaptive_tempered_top_level_api(self):
+        step_provider = lambda logprior_fn, loglikelihood_fn, pretune, initial_parameters: blackjax.pretuning(
+            blackjax.adaptive_tempered_smc,
+            logprior_fn,
+            loglikelihood_fn,
+            blackjax.hmc.build_kernel(),
+            blackjax.hmc.init,
+            resampling.systematic,
+            num_mcmc_steps=10,
+            pretune_fn=pretune,
+            initial_parameter_value=initial_parameters,
+            target_ess=0.5,
+        )
+
+        def loop(init, smc_kernel, init_particles):
+            initial_state = init(init_particles)
+            return tuned_adaptive_tempered_inference_loop(
+                smc_kernel, self.key, initial_state
+            )
+
+        self.linear_regression_test_case(step_provider, loop)
+
+    def linear_regression_test_case(self, step_provider, loop):
+        (
+            init_particles,
+            logprior_fn,
+            loglikelihood_fn,
+        ) = self.particles_prior_loglikelihood()
+
+        num_particles = 100
+        sampling_key, step_size_key, integration_steps_key = jax.random.split(
+            self.key, 3
+        )
+        integration_steps_distribution = jnp.round(
+            jax.random.uniform(
+                integration_steps_key, (num_particles,), minval=1, maxval=100
+            )
+        ).astype(int)
+
+        step_sizes_distribution = jax.random.uniform(
+            step_size_key, (num_particles,), minval=0, maxval=0.1
+        )
+
+        # Fixes inverse_mass_matrix and distribution for the other two parameters.
+ initial_parameters = dict( + inverse_mass_matrix=extend_params(jnp.eye(2)), + step_size=step_sizes_distribution, + num_integration_steps=integration_steps_distribution, + ) + assert initial_parameters["step_size"].shape == (num_particles,) + assert initial_parameters["num_integration_steps"].shape == (num_particles,) + + pretune = build_pretune( + blackjax.hmc.init, + blackjax.hmc.build_kernel(), + alpha=1, + n_particles=num_particles, + sigma_parameters={"step_size": 0.01, "num_integration_steps": 2}, + natural_parameters=["num_integration_steps"], + positive_parameters=["step_size"], + ) + + init, step = step_provider( + logprior_fn, loglikelihood_fn, pretune, initial_parameters + ) + + smc_kernel = self.variant(step) + + result = loop(init, smc_kernel, init_particles) + self.assert_linear_regression_test_case(result.sampler_state) + assert set(result.parameter_override.keys()) == { + "step_size", + "num_integration_steps", + "inverse_mass_matrix", + } + assert result.parameter_override["step_size"].shape == (num_particles,) + assert result.parameter_override["num_integration_steps"].shape == ( + num_particles, + ) + assert all(result.parameter_override["step_size"] > 0) + assert all(result.parameter_override["num_integration_steps"] > 0) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/smc/test_waste_free_smc.py b/tests/smc/test_waste_free_smc.py index a5eeef135..b6a098c1c 100644 --- a/tests/smc/test_waste_free_smc.py +++ b/tests/smc/test_waste_free_smc.py @@ -12,7 +12,8 @@ import blackjax import blackjax.smc.resampling as resampling from blackjax import adaptive_tempered_smc, tempered_smc -from blackjax.smc import extend_params +from blackjax.smc import extend_params, solver +from blackjax.smc.builder_api import SMCSamplerBuilder from blackjax.smc.waste_free import update_waste_free, waste_free_smc from tests.smc import SMCLinearRegressionTestCase from tests.smc.test_tempered_smc import inference_loop @@ -26,7 +27,51 @@ def setUp(self): 
self.key = jax.random.key(42) @chex.variants(with_jit=True) - def test_fixed_schedule_tempered_smc(self): + def test_fixed_schedule_tempered_smc_top_level_api(self): + def sampler_provider( + logprior_fn, + loglikelihood_fn, + hmc_kernel, + hmc_init, + hmc_parameters, + n_particles, + p, + ): + return tempered_smc( + logprior_fn, + loglikelihood_fn, + hmc_kernel, + hmc_init, + hmc_parameters, + resampling.systematic, + None, + waste_free_smc(n_particles, p), + ) + + self.fixed_schedule_tempered_smc_test_case(sampler_provider) + + @chex.variants(with_jit=True) + def test_fixed_schedule_tempered_smc_builder_api(self): + def sampler_provider( + logprior_fn, + loglikelihood_fn, + hmc_kernel, + hmc_init, + hmc_parameters, + n_particles, + p, + ): + return ( + SMCSamplerBuilder(resampling.systematic) + .tempering_from_sequence(logprior_fn, loglikelihood_fn) + .inner_kernel(hmc_init, hmc_kernel, hmc_parameters) + .mutate_waste_free(n_particles, p) + .build() + ) + + self.fixed_schedule_tempered_smc_test_case(sampler_provider) + + def fixed_schedule_tempered_smc_test_case(self, sampler_provider): ( init_particles, logprior_fn, @@ -46,16 +91,10 @@ def test_fixed_schedule_tempered_smc(self): }, ) - tempering = tempered_smc( - logprior_fn, - loglikelihood_fn, - hmc_kernel, - hmc_init, - hmc_parameters, - resampling.systematic, - None, - waste_free_smc(100, 4), + tempering = sampler_provider( + logprior_fn, loglikelihood_fn, hmc_kernel, hmc_init, hmc_parameters, 100, 4 ) + init_state = tempering.init(init_particles) smc_kernel = self.variant(tempering.step) @@ -69,7 +108,54 @@ def body_fn(carry, lmbda): self.assert_linear_regression_test_case(result) @chex.variants(with_jit=True) - def test_adaptive_tempered_smc(self): + def test_adaptive_tempered_smc_top_level_api(self): + def sampler_provider( + logprior_fn, + loglikelihood_fn, + hmc_kernel, + hmc_init, + hmc_parameters, + target_ess, + n_particles, + p, + ): + return adaptive_tempered_smc( + logprior_fn, + loglikelihood_fn, 
+ hmc_kernel, + hmc_init, + hmc_parameters, + resampling.systematic, + target_ess, + mcmc_run_strategy=waste_free_smc(n_particles, p), + num_mcmc_steps=None, + ) + + self.adaptive_tempered_smc_test_case(sampler_provider) + + @chex.variants(with_jit=True) + def test_adaptive_tempered_smc_builder_api(self): + def sampler_provider( + logprior_fn, + loglikelihood_fn, + hmc_kernel, + hmc_init, + hmc_parameters, + target_ess, + n_particles, + p, + ): + return ( + SMCSamplerBuilder() + .adaptive_tempering_sequence(target_ess, logprior_fn, loglikelihood_fn) + .inner_kernel(hmc_init, hmc_kernel, hmc_parameters) + .mutate_waste_free(n_particles, p) + .build() + ) + + self.adaptive_tempered_smc_test_case(sampler_provider) + + def adaptive_tempered_smc_test_case(self, sampler_provider): ( init_particles, logprior_fn, @@ -86,16 +172,15 @@ def test_adaptive_tempered_smc(self): }, ) - tempering = adaptive_tempered_smc( + tempering = sampler_provider( logprior_fn, loglikelihood_fn, hmc_kernel, hmc_init, hmc_parameters, - resampling.systematic, 0.5, - update_strategy=waste_free_smc(100, 4), - num_mcmc_steps=None, + 100, + 4, ) init_state = tempering.init(init_particles)