From fa158ea1b16bfec8a0872b736a37e6961c8c1cd5 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 11 Dec 2024 09:32:14 -0800 Subject: [PATCH 01/63] langevin structs --- blackjax/__init__.py | 2 + blackjax/mcmc/__init__.py | 2 + blackjax/mcmc/underdamped_langevin.py | 163 ++++++++++++++++++++++++++ 3 files changed, 167 insertions(+) create mode 100644 blackjax/mcmc/underdamped_langevin.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index dfdcfc545..092a7e3bd 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -19,6 +19,7 @@ from .mcmc import mala as _mala from .mcmc import marginal_latent_gaussian from .mcmc import mclmc as _mclmc +from .mcmc import underdamped_langevin as _langevin from .mcmc import nuts as _nuts from .mcmc import periodic_orbital, random_walk from .mcmc import rmhmc as _rmhmc @@ -109,6 +110,7 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) mclmc = generate_top_level_api_from(_mclmc) +langevin = generate_top_level_api_from(_langevin) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 6e207741d..8179156f4 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -10,6 +10,7 @@ periodic_orbital, random_walk, rmhmc, + underdamped_langevin ) __all__ = [ @@ -24,4 +25,5 @@ "marginal_latent_gaussian", "random_walk", "mclmc", + "underdamped_langevin" ] diff --git a/blackjax/mcmc/underdamped_langevin.py b/blackjax/mcmc/underdamped_langevin.py new file mode 100644 index 000000000..2bd2f1af1 --- /dev/null +++ b/blackjax/mcmc/underdamped_langevin.py @@ -0,0 +1,163 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Underdamped Langevin Kernel""" +from typing import Callable, NamedTuple + +import jax + +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.integrators import ( + IntegratorState, + isokinetic_mclachlan, + with_isokinetic_maruyama, +) +from blackjax.types import ArrayLike, PRNGKey +from blackjax.util import generate_unit_vector, pytree_size + +__all__ = ["LangevinInfo", "init", "build_kernel", "as_top_level_api"] + + +class LangevinInfo(NamedTuple): + """ + Additional information on the Langevin transition. + + logdensity + The log-density of the distribution at the current step of the Langevin chain. + kinetic_change + The difference in kinetic energy between the current and previous step. + energy_change + The difference in energy between the current and previous step. + """ + + logdensity: float + kinetic_change: float + energy_change: float + + +def init(position: ArrayLike, logdensity_fn, rng_key): + + l, g = jax.value_and_grad(logdensity_fn)(position) + + return IntegratorState( + position=position, + momentum=generate_unit_vector(rng_key, position), + logdensity=l, + logdensity_grad=g, + ) + + +def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): + """Build a HMC kernel. + + Parameters + ---------- + integrator + The symplectic integrator to use to integrate the Langevin dynamics. + L + the momentum decoherence rate. + step_size + step size of the integrator. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + + """ + + step = with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ) + + def kernel( + rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float + ) -> tuple[IntegratorState, LangevinInfo]: + (position, momentum, logdensity, logdensitygrad), kinetic_change = step( + state, step_size, L, rng_key + ) + + return IntegratorState( + position, momentum, logdensity, logdensitygrad + ), LangevinInfo( + logdensity=logdensity, + energy_change=kinetic_change - logdensity + state.logdensity, + kinetic_change=kinetic_change, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + L, + step_size, + integrator=isokinetic_mclachlan, + sqrt_diag_cov=1.0, +) -> SamplingAlgorithm: + """The general Langevin kernel builder (:meth:`blackjax.mcmc.langevin.build_kernel`, alias `blackjax.langevin.build_kernel`) can be + cumbersome to manipulate. Since most users only need to specify the kernel + parameters at initialization time, we provide a helper function that + specializes the general kernel. + + We also add the general kernel and state generator as an attribute to this class so + users only need to pass `blackjax.langevin` to SMC, adaptation, etc. algorithms. + + Examples + -------- + + A new langevin kernel can be initialized and used with the following code: + + .. code:: + + langevin = blackjax.mcmc.langevin.langevin( + logdensity_fn=logdensity_fn, + L=L, + step_size=step_size + ) + state = langevin.init(position) + new_state, info = langevin.step(rng_key, state) + + Kernels are not jit-compiled by default so you will need to do it manually: + + .. code:: + + step = jax.jit(langevin.step) + new_state, info = step(rng_key, state) + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + L + the momentum decoherence rate + step_size + step size of the integrator + integrator + an integrator. We recommend using the default here. + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) + + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key, state): + return kernel(rng_key, state, L, step_size) + + return SamplingAlgorithm(init_fn, update_fn) From 8eed424c181f0c5c0b5cf5e50780115d14704fc6 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 12:42:00 -0500 Subject: [PATCH 02/63] add static adjusted mclmc --- blackjax/__init__.py | 2 + blackjax/mcmc/__init__.py | 4 +- blackjax/mcmc/adjusted_mclmc.py | 57 ++---- blackjax/mcmc/adjusted_mclmc_dynamic.py | 257 ++++++++++++++++++++++++ tests/mcmc/test_sampling.py | 110 +++++++++- 5 files changed, 384 insertions(+), 46 deletions(-) create mode 100644 blackjax/mcmc/adjusted_mclmc_dynamic.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 6a0de3809..35c9e3b58 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -13,6 +13,7 @@ from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat from .mcmc import adjusted_mclmc as _adjusted_mclmc +from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc from .mcmc import elliptical_slice as _elliptical_slice @@ -112,6 +113,7 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) mclmc = generate_top_level_api_from(_mclmc) +adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic) adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 1e1317684..fad5dcb97 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -1,5 +1,5 @@ from . import ( - adjusted_mclmc, + adjusted_mclmc_dynamic, barker, elliptical_slice, ghmc, @@ -25,5 +25,5 @@ "marginal_latent_gaussian", "random_walk", "mclmc", - "adjusted_mclmc", + "adjusted_mclmc_dynamic", ] diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 81fbc2835..8288772a3 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -11,7 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin". + +NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module, which is primarily intended for use in parallelized versions of the algorithm. + +""" from typing import Callable, Union import jax @@ -19,28 +23,26 @@ import blackjax.mcmc.integrators as integrators from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence -from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.hmc import HMCInfo, HMCState from blackjax.mcmc.proposal import static_binomial_sampling -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_unit_vector __all__ = ["init", "build_kernel", "as_top_level_api"] -def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): +def init(position: ArrayLikeTree, logdensity_fn: Callable): logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) - return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + return HMCState(position, logdensity, logdensity_grad) def build_kernel( - integration_steps_fn, + num_integration_steps: int, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], sqrt_diag_cov=1.0, ): - """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. Parameters ---------- @@ -63,15 +65,13 @@ def build_kernel( def kernel( rng_key: PRNGKey, - state: DynamicHMCState, + state: HMCState, logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - ) -> tuple[DynamicHMCState, HMCInfo]: + ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" - num_integration_steps = integration_steps_fn(state.random_generator_arg) - key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( @@ -90,11 +90,10 @@ def kernel( ) return ( - DynamicHMCState( + HMCState( proposal.position, proposal.logdensity, proposal.logdensity_grad, - next_random_arg_fn(state.random_generator_arg), ), info, ) @@ -110,10 +109,9 @@ def as_top_level_api( *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, - next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), + num_integration_steps, ) -> SamplingAlgorithm: - """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + """Implements the (basic) user interface for the MHMCHMC kernel. Parameters ---------- @@ -140,15 +138,15 @@ def as_top_level_api( """ kernel = build_kernel( - integration_steps_fn=integration_steps_fn, + num_integration_steps, integrator=integrator, - next_random_arg_fn=next_random_arg_fn, sqrt_diag_cov=sqrt_diag_cov, divergence_threshold=divergence_threshold, ) - def init_fn(position: ArrayLikeTree, rng_key: Array): - return init(position, logdensity_fn, rng_key) + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) def update_fn(rng_key: PRNGKey, state): return kernel( @@ -240,18 +238,3 @@ def generate( return sampled_state, info, other_proposal_info return generate - - -def rescale(mu): - """returns s, such that - round(U(0, 1) * s + 0.5) - has expected value mu. - """ - k = jnp.floor(2 * mu - 1) - x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) - return k + x - - -def trajectory_length(t, mu): - s = rescale(mu) - return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py new file mode 100644 index 000000000..81fbc2835 --- /dev/null +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -0,0 +1,257 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence +from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = ["init", "build_kernel", "as_top_level_api"] + + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + + +def build_kernel( + integration_steps_fn, + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + sqrt_diag_cov=1.0, +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + """ + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + num_integration_steps = integration_steps_fn(state.random_generator_arg) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + DynamicHMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + sqrt_diag_cov=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + integration_steps_fn=integration_steps_fn, + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + sqrt_diag_cov=sqrt_diag_cov, + divergence_threshold=divergence_threshold, + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. + + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal_factor, rng_key + ) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. + """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 474f67293..45d60f84a 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -14,7 +14,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info -from blackjax.mcmc.adjusted_mclmc import rescale +from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -146,7 +146,7 @@ def run_mclmc( return samples - def run_adjusted_mclmc( + def run_adjusted_mclmc_dynamic( self, logdensity_fn, num_steps, @@ -158,13 +158,13 @@ def run_adjusted_mclmc( init_key, tune_key, run_key = jax.random.split(key, 3) - initial_state = blackjax.mcmc.adjusted_mclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc_dynamic.init( position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) @@ -177,7 +177,7 @@ def run_adjusted_mclmc( logdensity_fn=logdensity_fn, ) - target_acc_rate = 0.65 + target_acc_rate = 0.9 ( blackjax_state_after_tuning, @@ -197,7 +197,7 @@ def run_adjusted_mclmc( step_size = blackjax_mclmc_sampler_params.step_size L = blackjax_mclmc_sampler_params.L - alg = blackjax.adjusted_mclmc( + alg = blackjax.adjusted_mclmc_dynamic( logdensity_fn=logdensity_fn, step_size=step_size, integration_steps_fn=lambda key: jnp.ceil( @@ -218,6 +218,73 @@ def run_adjusted_mclmc( return out + def run_adjusted_mclmc( + self, + logdensity_fn, + num_steps, + initial_position, + key, + diagonal_preconditioning=False, + ): + integrator = isokinetic_mclachlan + + init_key, tune_key, run_key = jax.random.split(key, 3) + + initial_state = blackjax.mcmc.adjusted_mclmc.init( + position=initial_position, + logdensity_fn=logdensity_fn, + ) + + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + integrator=integrator, + num_integration_steps=avg_num_integration_steps, + sqrt_diag_cov=sqrt_diag_cov, + )( + rng_key=rng_key, + state=state, + step_size=step_size, + logdensity_fn=logdensity_fn, + ) + + target_acc_rate = 0.9 + + ( + blackjax_state_after_tuning, + blackjax_mclmc_sampler_params, + ) = blackjax.adjusted_mclmc_find_L_and_step_size( + mclmc_kernel=kernel, + num_steps=num_steps, + state=initial_state, + rng_key=tune_key, + target=target_acc_rate, + frac_tune1=0.1, + frac_tune2=0.1, + frac_tune3=0.1, + diagonal_preconditioning=diagonal_preconditioning, + ) + + step_size = blackjax_mclmc_sampler_params.step_size + L = blackjax_mclmc_sampler_params.L + + alg = blackjax.adjusted_mclmc( + logdensity_fn=logdensity_fn, + step_size=step_size, + num_integration_steps=L / step_size, + integrator=integrator, + sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + ) + + _, out = run_inference_algorithm( + rng_key=run_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=lambda state, _: state.position, + progress_bar=False, + ) + + return out + @parameterized.parameters( itertools.product( regression_test_cases, [True, False], window_adaptation_filters @@ -334,7 +401,35 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) - def test_adjusted_mclmc(self): + @parameterized.parameters([True, False]) + def test_adjusted_mclmc_dynamic(self, diagonal_preconditioning): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + states = self.run_adjusted_mclmc_dynamic( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=logdensity_fn, + key=inference_key, + num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, + ) + + coefs_samples = states["coefs"][3000:] + scale_samples = np.exp(states["log_scale"][3000:]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + + @parameterized.parameters([True, False]) + def test_adjusted_mclmc(self, diagonal_preconditioning): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -351,6 +446,7 @@ def test_adjusted_mclmc(self): logdensity_fn=logdensity_fn, key=inference_key, num_steps=10000, + diagonal_preconditioning=diagonal_preconditioning, ) coefs_samples = states["coefs"][3000:] From 9dd6bdba039003538eaf635fde7678defdbc7350 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 13:27:08 -0500 Subject: [PATCH 03/63] add static adjusted mclmc --- .../adaptation/adjusted_mclmc_adaptation.py | 10 ++--- blackjax/adaptation/mclmc_adaptation.py | 22 +++++----- blackjax/mcmc/adjusted_mclmc.py | 10 +++-- blackjax/mcmc/adjusted_mclmc_dynamic.py | 10 +++-- blackjax/mcmc/integrators.py | 12 +++--- blackjax/mcmc/mclmc.py | 8 ++-- tests/mcmc/test_integrators.py | 4 +- tests/mcmc/test_sampling.py | 40 +++++++++++-------- 8 files changed, 64 insertions(+), 52 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index f5d54e5c9..eabb642a3 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -74,7 +74,7 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -152,7 +152,7 @@ def step(iteration_state, weight_and_key): state=previous_state, avg_num_integration_steps=avg_num_integration_steps, step_size=params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) # step updating @@ -283,9 +283,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L=params.L * change, step_size=params.step_size * change ) if diagonal_preconditioning: - params = params._replace( - sqrt_diag_cov=jnp.sqrt(variances), L=jnp.sqrt(dim) - ) + params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( @@ -323,7 +321,7 @@ def step(state, key): state=state, step_size=params.step_size, avg_num_integration_steps=params.L / params.step_size, - sqrt_diag_cov=params.sqrt_diag_cov, + inverse_mass_matrix=params.inverse_mass_matrix, ) return next_state, next_state.position diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 8452b6171..aa192b964 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -30,13 +30,13 @@ class MCLMCAdaptationState(NamedTuple): The momentum decoherent rate for the MCLMC algorithm. step_size The step size used for the MCLMC algorithm. - sqrt_diag_cov + inverse_mass_matrix A matrix used for preconditioning. """ L: float step_size: float - sqrt_diag_cov: float + inverse_mass_matrix: float def mclmc_find_L_and_step_size( @@ -87,10 +87,10 @@ def mclmc_find_L_and_step_size( Example ------- .. code:: - kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -106,7 +106,7 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, sqrt_diag_cov=jnp.ones((dim,)) + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -123,7 +123,7 @@ def mclmc_find_L_and_step_size( if frac_tune3 != 0: state, params = make_adaptation_L( - mclmc_kernel(params.sqrt_diag_cov), frac=frac_tune3, Lfactor=0.4 + mclmc_kernel(params.inverse_mass_matrix), frac=frac_tune3, Lfactor=0.4 )(state, params, num_steps, part2_key) return state, params @@ -152,7 +152,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): rng_key, nan_key = jax.random.split(rng_key) # dynamics - next_state, info = kernel(params.sqrt_diag_cov)( + next_state, info = kernel(params.inverse_mass_matrix)( rng_key=rng_key, state=previous_state, L=params.L, @@ -247,15 +247,15 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = params.L # determine L - sqrt_diag_cov = params.sqrt_diag_cov + inverse_mass_matrix = params.inverse_mass_matrix if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) L = jnp.sqrt(jnp.sum(variances)) if diagonal_preconditioning: - sqrt_diag_cov = jnp.sqrt(variances) - params = params._replace(sqrt_diag_cov=sqrt_diag_cov) + inverse_mass_matrix = variances + params = params._replace(inverse_mass_matrix=inverse_mass_matrix) L = jnp.sqrt(dim) # readjust the stepsize @@ -265,7 +265,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) - return state, MCLMCAdaptationState(L, params.step_size, sqrt_diag_cov) + return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) return L_step_size_adaptation diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 8288772a3..9b868562c 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -40,7 +40,7 @@ def build_kernel( num_integration_steps: int, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ): """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -76,7 +76,9 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) ), step_size=step_size, L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), @@ -105,7 +107,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -140,7 +142,7 @@ def as_top_level_api( kernel = build_kernel( num_integration_steps, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index 81fbc2835..1a69e1a28 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -38,7 +38,7 @@ def build_kernel( integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -76,7 +76,9 @@ def kernel( momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( integrator=integrators.with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) ), step_size=step_size, L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), @@ -106,7 +108,7 @@ def as_top_level_api( logdensity_fn: Callable, step_size: float, L_proposal_factor: float = jnp.inf, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, @@ -143,7 +145,7 @@ def as_top_level_api( integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 593683ca4..733e7e960 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -311,7 +311,9 @@ def _normalized_flatten_array(x, tol=1e-13): return jnp.where(norm > tol, x / norm, x), norm -def esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0): +def esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0): + sqrt_inverse_mass_matrix = jax.tree_util.tree_map(jnp.sqrt, inverse_mass_matrix) + def update( momentum: ArrayTree, logdensity_grad: ArrayTree, @@ -330,7 +332,7 @@ def update( logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) - flatten_grads = flatten_grads * sqrt_diag_cov + flatten_grads = flatten_grads * sqrt_inverse_mass_matrix flatten_momentum, _ = ravel_pytree(momentum) dims = flatten_momentum.shape[0] normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) @@ -342,7 +344,7 @@ def update( + 2 * zeta * flatten_momentum ) new_momentum_normalized, _ = _normalized_flatten_array(new_momentum_raw) - gr = unravel_fn(new_momentum_normalized * sqrt_diag_cov) + gr = unravel_fn(new_momentum_normalized * sqrt_inverse_mass_matrix) next_momentum = unravel_fn(new_momentum_normalized) kinetic_energy_change = ( delta @@ -374,11 +376,11 @@ def format_isokinetic_state_output( def generate_isokinetic_integrator(coefficients): def isokinetic_integrator( - logdensity_fn: Callable, sqrt_diag_cov: ArrayTree = 1.0 + logdensity_fn: Callable, inverse_mass_matrix: ArrayTree = 1.0 ) -> GeneralIntegrator: position_update_fn = euclidean_position_update_fn(logdensity_fn) one_step = generalized_two_stage_integrator( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov), + esh_dynamics_momentum_update_one_step(inverse_mass_matrix), position_update_fn, coefficients, format_output_fn=format_isokinetic_state_output, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e7a69849b..ff9638a1f 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): +def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): """Build a HMC kernel. Parameters @@ -81,7 +81,7 @@ def build_kernel(logdensity_fn, sqrt_diag_cov, integrator): """ step = with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, sqrt_diag_cov=sqrt_diag_cov) + integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) ) def kernel( @@ -107,7 +107,7 @@ def as_top_level_api( L, step_size, integrator=isokinetic_mclachlan, - sqrt_diag_cov=1.0, + inverse_mass_matrix=1.0, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,7 +155,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(logdensity_fn, sqrt_diag_cov, integrator) + kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) diff --git a/tests/mcmc/test_integrators.py b/tests/mcmc/test_integrators.py index c38009e5e..c37c0ede6 100644 --- a/tests/mcmc/test_integrators.py +++ b/tests/mcmc/test_integrators.py @@ -238,7 +238,7 @@ def test_esh_momentum_update(self, dims): # Efficient implementation update_stable = self.variant( - esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) ) next_momentum1, *_ = update_stable(momentum, gradient, step_size, 1.0) np.testing.assert_array_almost_equal(next_momentum, next_momentum1) @@ -263,7 +263,7 @@ def test_isokinetic_velocity_verlet(self): next_state, kinetic_energy_change = step(initial_state, step_size) # explicit integration - op1 = esh_dynamics_momentum_update_one_step(sqrt_diag_cov=1.0) + op1 = esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0) op2 = integrators.euclidean_position_update_fn(logdensity_fn) position, momentum, _, logdensity_grad = initial_state momentum, kinetic_grad, kinetic_energy_change0 = op1( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 45d60f84a..a4ea66a9b 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -112,10 +112,10 @@ def run_mclmc( position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.mclmc.isokinetic_mclachlan, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -133,7 +133,7 @@ def run_mclmc( logdensity_fn, L=blackjax_mclmc_sampler_params.L, step_size=blackjax_mclmc_sampler_params.step_size, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, samples = run_inference_algorithm( @@ -144,6 +144,8 @@ def run_mclmc( transform=lambda state, info: state.position, ) + print(samples["coefs"][0].item()) + return samples def run_adjusted_mclmc_dynamic( @@ -164,12 +166,12 @@ def run_adjusted_mclmc_dynamic( random_generator_arg=init_key, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integrator=integrator, integration_steps_fn=lambda k: jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) ), - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -204,7 +206,7 @@ def run_adjusted_mclmc_dynamic( jax.random.uniform(key) * rescale(L / step_size) ), integrator=integrator, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, out = run_inference_algorithm( @@ -216,6 +218,8 @@ def run_adjusted_mclmc_dynamic( progress_bar=False, ) + print(blackjax_mclmc_sampler_params.inverse_mass_matrix[1].item()) + return out def run_adjusted_mclmc( @@ -235,10 +239,10 @@ def run_adjusted_mclmc( logdensity_fn=logdensity_fn, ) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, num_integration_steps=avg_num_integration_steps, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -271,7 +275,7 @@ def run_adjusted_mclmc( step_size=step_size, num_integration_steps=L / step_size, integrator=integrator, - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, ) _, out = run_inference_algorithm( @@ -402,7 +406,10 @@ def test_mclmc(self): np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @parameterized.parameters([True, False]) - def test_adjusted_mclmc_dynamic(self, diagonal_preconditioning): + def test_adjusted_mclmc_dynamic( + self, + diagonal_preconditioning, + ): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -495,7 +502,7 @@ def __init__(self, d, condition_number): integrator = isokinetic_mclachlan - def get_sqrt_diag_cov(): + def get_inverse_mass_matrix(): init_key, tune_key = jax.random.split(key) initial_position = model.sample_init(init_key) @@ -506,10 +513,10 @@ def get_sqrt_diag_cov(): rng_key=init_key, ) - kernel = lambda sqrt_diag_cov: blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix: blackjax.mcmc.mclmc.build_kernel( logdensity_fn=model.logdensity_fn, integrator=integrator, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) ( @@ -523,13 +530,14 @@ def get_sqrt_diag_cov(): diagonal_preconditioning=True, ) - return blackjax_mclmc_sampler_params.sqrt_diag_cov + return blackjax_mclmc_sampler_params.inverse_mass_matrix - sqrt_diag_cov = get_sqrt_diag_cov() + inverse_mass_matrix = get_inverse_mass_matrix() assert ( jnp.abs( jnp.dot( - (sqrt_diag_cov**2) / jnp.linalg.norm(sqrt_diag_cov**2), + (inverse_mass_matrix**2) + / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 From a49bb35f37293f0033ea4c9c5b8daf7ff62c1461 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 13:54:36 -0500 Subject: [PATCH 04/63] add static adjusted mclmc --- tests/mcmc/test_sampling.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index a4ea66a9b..d788696f8 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -144,8 +144,6 @@ def run_mclmc( transform=lambda state, info: state.position, ) - print(samples["coefs"][0].item()) - return samples def run_adjusted_mclmc_dynamic( @@ -218,8 +216,6 @@ def run_adjusted_mclmc_dynamic( progress_bar=False, ) - print(blackjax_mclmc_sampler_params.inverse_mass_matrix[1].item()) - return out def run_adjusted_mclmc( From 35d71ffb9ece4940b113ed2caedf4f302b20ddfd Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:02:35 -0500 Subject: [PATCH 05/63] draft --- blackjax/adaptation/ensemble_mclmc.py | 224 +++++++++++++++++++++ blackjax/adaptation/ensemble_umclmc.py | 265 +++++++++++++++++++++++++ blackjax/adaptation/step_size.py | 43 ++++ blackjax/mcmc/mclmc.py | 2 +- blackjax/util.py | 123 +++++++++++- 5 files changed, 655 insertions(+), 2 deletions(-) create mode 100644 blackjax/adaptation/ensemble_mclmc.py create mode 100644 blackjax/adaptation/ensemble_umclmc.py diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py new file mode 100644 index 000000000..dabf5a3be --- /dev/null +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -0,0 +1,224 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#"""Public API for the MCLMC Kernel""" + +from typing import Callable, NamedTuple, Any + +import jax +import jax.numpy as jnp + +from blackjax.util import run_eca +from blackjax.mcmc.integrators import generate_isokinetic_integrator, velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt +import blackjax.adaptation.ensemble_umclmc as umclmc +from blackjax.adaptation.ensemble_umclmc import equipartition_diagonal, equipartition_diagonal_loss, equipartition_fullrank, equipartition_fullrank_loss + +from blackjax.adaptation.step_size import dual_averaging_adaptation, bisection_monotonic_fn + + + +class AdaptationState(NamedTuple): + steps_per_sample: float + step_size: float + epsadap_state: Any + sample_count: int + + + +def build_kernel(logdensity_fn, integrator, sqrt_diag_cov): + """MCLMC kernel""" + + kernel = build_kernel_malt(logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov, L_proposal_factor = 1.25) + + def sequential_kernel(key, state, adap): + return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample) + + return sequential_kernel + + + +class Adaptation: + + def __init__(self, adap_state, num_adaptation_samples, + steps_per_sample, acc_prob_target= 0.8, + observables = lambda x: 0., + observables_for_bias = lambda x: 0., contract= lambda x: 0.): + + self.num_adaptation_samples= num_adaptation_samples + self.observables = observables + self.observables_for_bias = observables_for_bias + self.contract = contract + + ### Determine the initial hyperparameters ### + + ## stepsize ## + #if we switched to the more accurate integrator we can use longer step size + #integrator_factor = jnp.sqrt(10.) if mclachlan else 1. + # Let's use the stepsize which will be optimal for the adjusted method. The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 + # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 + # With the current eps, we had sigma^2 = EEVPD * d for N = 1. + # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 + #adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adap_state.step_size #* integrator_factor * adjustment_factor + + #steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) + + ### Initialize the dual averaging adaptation ### + #da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) + #epsadap_state = da_init_fn(step_size) + + ### Initialize the bisection for finding the step size + epsadap_state, self.epsadap_update = bisection_monotonic_fn(acc_prob_target) + + self.initial_state = AdaptationState(steps_per_sample, step_size, epsadap_state, 0) + + + def summary_statistics_fn(self, state, info, rng_key): + + return {'acceptance_probability': info.acceptance_rate, + #'inv_acceptance_probability': 1./info.acceptance_rate, + 'equipartition_diagonal': equipartition_diagonal(state), + 'equipartition_fullrank': equipartition_fullrank(state, rng_key), + 'observables': self.observables(state.position), + 'observables_for_bias': self.observables_for_bias(state.position) + } + + + def update(self, adaptation_state, Etheta): + + # combine the expectation values to get useful scalars + acc_prob = Etheta['acceptance_probability'] + #acc_prob = 1./Etheta['inv_acceptance_probability'] + equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) + equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) + true_bias = self.contract(Etheta['observables_for_bias']) + + + info_to_be_stored = {'L': adaptation_state.step_size * adaptation_state.steps_per_sample, 'steps_per_sample': adaptation_state.steps_per_sample, 'step_size': adaptation_state.step_size, + 'acc_prob': acc_prob, + 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, + 'observables': Etheta['observables'] + } + + # hyperparameter adaptation + + # Dual Averaging + # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples + + # def update(_): + # da_state = self.epsadap_update(adaptation_state.epsadap_state, acc_prob) + # step_size = jnp.exp(da_state.log_step_size) + # return da_state, step_size + + # def dont_update(_): + # da_state = adaptation_state.epsadap_state + # return da_state, jnp.exp(da_state.log_step_size_avg) + + # epsadap_state, step_size = jax.lax.cond(adaptation_phase, update, dont_update, operand=None) + + # Bisection + epsadap_state, step_size = self.epsadap_update(adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob) + + return AdaptationState(adaptation_state.steps_per_sample, step_size, epsadap_state, adaptation_state.sample_count + 1), info_to_be_stored + + + +def bias(model): + """should be transfered to benchmarks/""" + + def observables(position): + return jnp.square(model.transform(position)) + + def contract(sampler_E_x2): + bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 + return jnp.array([jnp.max(bsq), jnp.average(bsq)]) + + return observables, contract + + + +def while_steps_num(cond): + if jnp.all(cond): + return len(cond) + else: + return jnp.argmin(cond) + 1 + + +def emaus(model, num_steps1, num_steps2, num_chains, mesh, rng_key, + alpha= 1.9, bias_type= 0, save_frac= 0.2, C= 0.1, power= 3./8., early_stop= True, r_end= 5e-3,# stage1 parameters + diagonal_preconditioning= True, integrator_coefficients= None, steps_per_sample= 10, acc_prob= None, + observables = lambda x: None, + ensemble_observables= None + ): + + observables_for_bias, contract = bias(model) + key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) + + # initialize the chains + initial_state = umclmc.initialize(key_init, model.logdensity_fn, model.sample_init, num_chains, mesh) + + ### burn-in with the unadjusted method ### + kernel = umclmc.build_kernel(model.logdensity_fn) + save_num= (int)(jnp.rint(save_frac * num_steps1)) + adap = umclmc.Adaptation(model.ndims, alpha= alpha, bias_type= bias_type, save_num= save_num, C=C, power= power, r_end = r_end, + observables= observables, observables_for_bias= observables_for_bias, contract= contract) + final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps1, num_chains, mesh, ensemble_observables) + + if early_stop: # here I am cheating a bit, because I am not sure if it is possible to do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. + + num_steps_while = while_steps_num((info1[0] if ensemble_observables != None else info1)['while_cond']) + #print(num_steps_while, save_num) + final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps_while, num_chains, mesh, ensemble_observables) + + ### refine the results with the adjusted method ### + _acc_prob = acc_prob + if integrator_coefficients == None: + high_dims = model.ndims > 200 + _integrator_coefficients = omelyan_coefficients if high_dims else mclachlan_coefficients + if acc_prob == None: + _acc_prob = 0.9 if high_dims else 0.7 + + else: + _integrator_coefficients = integrator_coefficients + if acc_prob == None: + _acc_prob = 0.9 + + + integrator = generate_isokinetic_integrator(_integrator_coefficients) + gradient_calls_per_step= len(_integrator_coefficients) // 2 #scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. + + if diagonal_preconditioning: + sqrt_diag_cov= final_adaptation_state.sqrt_diag_cov + + # scale the stepsize so that it reflects averag scale change of the preconditioning + average_scale_change = jnp.sqrt(jnp.average(jnp.square(sqrt_diag_cov))) + final_adaptation_state = final_adaptation_state._replace(step_size= final_adaptation_state.step_size / average_scale_change) + + else: + sqrt_diag_cov= 1. + + kernel = build_kernel(model.logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov) + initial_state= HMCState(final_state.position, final_state.logdensity, final_state.logdensity_grad) + num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) + num_adaptation_samples = num_samples//2 # number of samples after which the stepsize is fixed. + + adap = Adaptation(final_adaptation_state, num_adaptation_samples, steps_per_sample, _acc_prob, + observables= observables, observables_for_bias= observables_for_bias, contract= contract) + + final_state, final_adaptation_state, info2 = run_eca(key_mclmc, initial_state, kernel, adap, num_samples, num_chains, mesh, ensemble_observables) + + return info1, info2, gradient_calls_per_step, _acc_prob + + \ No newline at end of file diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py new file mode 100644 index 000000000..73b9ae4f7 --- /dev/null +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -0,0 +1,265 @@ + +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#"""Public API for the MCLMC Kernel""" + +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree +from typing import Callable, NamedTuple, Any + +from blackjax.mcmc.integrators import IntegratorState, isokinetic_velocity_verlet +from blackjax.types import Array, ArrayLike +from blackjax.util import pytree_size +from blackjax.mcmc import mclmc +from blackjax.mcmc.integrators import _normalized_flatten_array +from blackjax.util import ensemble_execute_fn + + + +def no_nans(a): + flat_a, unravel_fn = ravel_pytree(a) + return jnp.all(jnp.isfinite(flat_a)) + + +def nan_reject(nonans, old, new): + """Equivalent to + return new if nonans else old""" + + return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None) + + + +def build_kernel(logdensity_fn): + """MCLMC kernel (with nan rejection)""" + + kernel = mclmc.build_kernel(logdensity_fn= logdensity_fn, integrator= isokinetic_velocity_verlet) + + + def sequential_kernel(key, state, adap): + + new_state, info = kernel(key, state, adap.L, adap.step_size) + + # reject the new state if there were nans + nonans = no_nans(new_state) + new_state = nan_reject(nonans, state, new_state) + + return new_state, {'nans': 1-nonans, 'energy_change': info.energy_change * nonans, 'logdensity': info.logdensity * nonans} + + + return sequential_kernel + + + +def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): + """initialize the chains based on the equipartition of the initial condition. + We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. + """ + + def sequential_init(key, x, args): + """initialize the position using sample_init and the velocity along the gradient""" + position = sample_init(key) + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + flat_g, unravel_fn = ravel_pytree(logdensity_grad) + velocity = unravel_fn(_normalized_flatten_array(flat_g)[0]) # = grad logp/ |grad logp| + return IntegratorState(position, velocity, logdensity, logdensity_grad), None + + def summary_statistics_fn(state): + """compute the diagonal elements of the equipartition matrix""" + return -state.position * state.logdensity_grad + + def ensemble_init(key, state, signs): + """flip the velocity, depending on the equipartition condition""" + velocity = jax.tree_util.tree_map(lambda sign, u: sign * u, signs, state.momentum) + return IntegratorState(state.position, velocity, state.logdensity, state.logdensity_grad), None + + key1, key2= jax.random.split(rng_key) + initial_state, equipartition = ensemble_execute_fn(sequential_init, key1, num_chains, mesh, summary_statistics_fn= summary_statistics_fn) + signs = -2. * (equipartition < 1.) + 1. + initial_state, _ = ensemble_execute_fn(ensemble_init, key2, num_chains, mesh, x= initial_state, args= signs) + + return initial_state + + +def update_history(new_vals, history): + return jnp.concatenate((new_vals[None, :], history[:-1])) + +def update_history_scalar(new_val, history): + return jnp.concatenate((new_val * jnp.ones(1), history[:-1])) + +def contract_history(theta, weights): + + square_average = jnp.square(jnp.average(theta, weights= weights, axis= 0)) + average_square = jnp.average(jnp.square(theta), weights= weights, axis= 0) + + r = (average_square - square_average) / square_average + + return jnp.array([jnp.max(r), + jnp.average(r) + ]) + + +class History(NamedTuple): + observables: Array + stopping: Array + weights: Array + + +class AdaptationState(NamedTuple): + + L: float + sqrt_diag_cov: Any + step_size: float + + step_count: int + EEVPD: float + EEVPD_wanted: float + history: Any + + +def equipartition_diagonal(state): + """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged. + equipartition_loss = average over parameters (Ei)""" + return jax.tree_util.tree_map(lambda x, g: -x * g, state.position, state.logdensity_grad) + + + +def equipartition_fullrank(state, rng_key): + """loss = Tr[(1 - E)^T (1 - E)] / d^2 + where Eij = is the equipartition patrix. + Loss is computed with the Hutchinson's trick.""" + + x, unravel_fn = ravel_pytree(state.position) + g, unravel_fn = ravel_pytree(state.logdensity_grad) + d = len(x) + + def func(z): + """z here has the same shape as position""" + return (z + jnp.dot(z, g) * x) + + z = jax.random.rademacher(rng_key, (100, d)) # = delta_ij + return jax.vmap(func)(z) + + +def equipartition_diagonal_loss(Eii): + Eii_flat, unravel_fn = ravel_pytree(Eii) + return jnp.average(jnp.square(1.- Eii_flat)) + + +def equipartition_fullrank_loss(delta_z): + d = delta_z.shape[-1] + return jnp.average(jnp.square(delta_z)) / d + + +class Adaptation: + + def __init__(self, num_dims, + alpha= 1., C= 0.1, power = 3./8., r_end= 0.01, + bias_type= 0, save_num = 10, + observables= lambda x: 0., observables_for_bias= lambda x: 0., contract= lambda x: 0. + ): + + self.num_dims = num_dims + self.alpha = alpha + self.C = C + self.power = power + self.r_end = r_end + self.observables = observables + self.observables_for_bias = observables_for_bias + self.contract = contract + self.bias_type = bias_type + self.save_num = save_num + #sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) + + r_save_num = save_num + + history = History(observables= jnp.zeros((r_save_num, num_dims)), + stopping= jnp.full((save_num,), jnp.nan), + weights= jnp.zeros(r_save_num)) + + self.initial_state = AdaptationState(L= jnp.inf, # do not add noise for the first step + sqrt_diag_cov= jnp.ones(num_dims), + step_size= 0.01 * jnp.sqrt(num_dims), + step_count= 0, + EEVPD=1e-3, EEVPD_wanted=1e-3, + history=history) + + + def summary_statistics_fn(self, state, info, rng_key): + + position_flat, unravel_fn = ravel_pytree(state.position) + + return {'equipartition_diagonal': equipartition_diagonal(state), + 'equipartition_fullrank': equipartition_fullrank(state, rng_key), + 'x': position_flat, 'xsq': jnp.square(position_flat), + 'E': info['energy_change'], 'Esq': jnp.square(info['energy_change']), + 'rejection_rate_nans': info['nans'], + 'observables_for_bias': self.observables_for_bias(state.position), + 'observables': self.observables(state.position), + 'entropy': - info['logdensity'] + } + + + def update(self, adaptation_state, Etheta): + + # combine the expectation values to get useful scalars + equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) + equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) + + history_observables = update_history(Etheta['observables_for_bias'], adaptation_state.history.observables) + history_weights = update_history_scalar(1., adaptation_state.history.weights) + fluctuations = contract_history(history_observables, history_weights) + history_stopping = update_history_scalar(jax.lax.cond(adaptation_state.step_count > len(history_weights), lambda _: fluctuations[0], lambda _: jnp.nan, operand=None), + adaptation_state.history.stopping) + history = History(history_observables, history_stopping, history_weights) + + L = self.alpha * jnp.sqrt(jnp.sum(Etheta['xsq'] - jnp.square(Etheta['x']))) # average over the ensemble, sum over parameters (to get sqrt(d)) + sqrt_diag_cov = jnp.sqrt(Etheta['xsq'] - jnp.square(Etheta['x'])) + EEVPD = (Etheta['Esq'] - jnp.square(Etheta['E'])) / self.num_dims + true_bias = self.contract(Etheta['observables_for_bias']) + nans = (Etheta['rejection_rate_nans'] > 0.) #| (~jnp.isfinite(eps_factor)) + + # hyperparameter adaptation + # estimate bias + bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[self.bias_type] # r_max, r_avg, equi_full, equi_diag + EEVPD_wanted = self.C * jnp.power(bias, self.power) + + + eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1./6.) + eps_factor = jnp.clip(eps_factor, 0.3, 3.) + + eps_factor = nan_reject(1-nans, 0.5, eps_factor) # reduce the stepsize if there were nans + + # determine if we want to finish this stage (i.e. if loss is no longer decreassing) + #increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). Do not be tempted to simply change to while_cond = history[0] < history[-1] + #while_cond = ~increasing + while_cond = (fluctuations[0] > self.r_end) | (adaptation_state.step_count < self.save_num) + + info_to_be_stored = {'L': adaptation_state.L, 'step_size': adaptation_state.step_size, + 'EEVPD_wanted': EEVPD_wanted, 'EEVPD': EEVPD, + 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, + 'r_max': fluctuations[0], 'r_avg': fluctuations[1], + 'while_cond': while_cond, 'entropy': Etheta['entropy'], + 'observables': Etheta['observables']} + + adaptation_state_new = AdaptationState(L, + sqrt_diag_cov, + adaptation_state.step_size * eps_factor, + adaptation_state.step_count + 1, + EEVPD, + EEVPD_wanted, + history) + + return adaptation_state_new, info_to_be_stored + diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 2b06172c0..7a076b962 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -257,3 +257,46 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: rss_state = jax.lax.while_loop(do_continue, update, rss_state) return rss_state.step_size + +def bisection_monotonic_fn(acc_prob_wanted, reduce_shift = jnp.log(2.), tolerance= 0.03): + """Bisection of a monotonically decreassing function, that doesn't require an initially bracketing interval.""" + + def update(state, exp_x, acc_rate_new): + + bounds, terminated = state + + # update the bounds + acc_high = acc_rate_new > acc_prob_wanted + x = jnp.log(exp_x) + + def on_true(bounds): + bounds0 = jnp.max(jnp.array([bounds[0], x])) + return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift + + def on_false(bounds): + bounds1 = jnp.min(jnp.array([bounds[1], x])) + return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift + + + bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) + + + # if we have already found a bracketing interval, do bisection, otherwise further reduce or increase the bounds + bracketing = jnp.all(jnp.isfinite(bounds_new)) + + def reduce(bounds): + return x_new + + def bisect(bounds): + return jnp.average(bounds) + + x_new = jax.lax.cond(bracketing, bisect, reduce, bounds_new) + + stepsize = terminated * exp_x + (1-terminated) * jnp.exp(x_new) + + terminated_new = (jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance) | terminated + + return (bounds_new, terminated_new), stepsize + + + return (jnp.array([-jnp.inf, jnp.inf]), False), update \ No newline at end of file diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index ff9638a1f..e5cc46213 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): +def build_kernel(logdensity_fn, inverse_mass_matrix=1.0, integrator=isokinetic_mclachlan): """Build a HMC kernel. Parameters diff --git a/blackjax/util.py b/blackjax/util.py index 8cdcd45ee..8f6494167 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -4,11 +4,14 @@ from typing import Callable, Union import jax.numpy as jnp -from jax import jit, lax +from jax import jit, lax, device_put, vmap from jax.flatten_util import ravel_pytree from jax.random import normal, split from jax.tree_util import tree_leaves, tree_map +from jax.sharding import Mesh, PartitionSpec, NamedSharding +from jax.experimental.shard_map import shard_map + from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -314,3 +317,121 @@ def incremental_value_update( ) total += weight return total, average + +def eca_step(kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info= None): + + def _step(state_all, xs): + """This function operates on a single device.""" + state, adaptation_state = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + _, keys_sampling, key_adaptation = xs # keys_sampling.shape = (chains_per_device, ) + + # update the state of all chains on this device + state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + + # combine all the chains to compute expectation values + theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) + Etheta = tree_map(lambda theta: lax.psum(jnp.sum(theta, axis= 0), axis_name= 'chains') / num_chains, theta) + + # use these to adapt the hyperparameters of the dynamics + adaptation_state, info_to_be_stored = adaptation_update(adaptation_state, Etheta) + + return (state, adaptation_state), info_to_be_stored + + + if ensemble_info != None: + + def step(state_all, xs): + (state, adaptation_state), info_to_be_stored = _step(state_all, xs) + return (state, adaptation_state), (info_to_be_stored, vmap(ensemble_info)(state.position)) + + return step + + else: + return _step + + +def run_eca(rng_key, initial_state, kernel, adaptation, num_steps, num_chains, mesh, ensemble_info= None): + + step = eca_step(kernel, adaptation.summary_statistics_fn, adaptation.update, num_chains, ensemble_info) + + + def all_steps(initial_state, keys_sampling, keys_adaptation): + """This function operates on a single device. key is a random key for this device.""" + + initial_state_all = (initial_state, adaptation.initial_state) + + # run sampling + xs = (jnp.arange(num_steps), keys_sampling.T, keys_adaptation) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + final_state, final_adaptation_state = final_state_all + return final_state, final_adaptation_state, info_history # info history is composed of averages over all chains, so it is a couple of scalars + + + p, pscalar = PartitionSpec('chains'), PartitionSpec() + parallel_execute = shard_map(all_steps, + mesh= mesh, + in_specs= (p, p, pscalar), + out_specs= (p, pscalar, pscalar), + check_rep=False + ) + + # produce all random keys that will be needed + key_sampling, key_adaptation = split(rng_key) + keys_adaptation = split(key_adaptation, num_steps) + distribute_keys = lambda key, shape: device_put(split(key, shape), NamedSharding(mesh, p)) # random keys, distributed across devices + keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + + + # run sampling in parallel + final_state, final_adaptation_state, info_history = parallel_execute(initial_state, keys_sampling, keys_adaptation) + + return final_state, final_adaptation_state, info_history + + + + +def ensemble_execute_fn(func, rng_key, num_chains, mesh, + x= None, + args= None, + summary_statistics_fn= lambda y: 0., + ): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics + """ + p, pscalar = PartitionSpec('chains'), PartitionSpec() + + if x == None: + X= device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X= x + + adaptation_update= lambda _, Etheta: (Etheta, None) + + _F = eca_step(func, lambda y, info, key: summary_statistics_fn(y), adaptation_update, num_chains) + + def F(x, keys): + """This function operates on a single device. key is a random key for this device.""" + y, summary_statistics = _F((x, args), (None, keys, None))[0] + return y, summary_statistics + + parallel_execute = shard_map(F, + mesh= mesh, + in_specs= (p, p), + out_specs= (p, pscalar), + check_rep=False + ) + + keys = device_put(split(rng_key, num_chains), NamedSharding(mesh, p)) # random keys, distributed across devices + # apply F in parallel + return parallel_execute(X, keys) \ No newline at end of file From 0f3df53011f00ecec9d879f23c01f08588608241 Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:13:11 -0500 Subject: [PATCH 06/63] draft --- blackjax/adaptation/ensemble_mclmc.py | 14 +++++++------- blackjax/adaptation/ensemble_umclmc.py | 8 ++++---- blackjax/mcmc/adjusted_mclmc.py | 4 ++-- blackjax/util.py | 3 +++ 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index dabf5a3be..b303f8ae7 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -37,13 +37,13 @@ class AdaptationState(NamedTuple): -def build_kernel(logdensity_fn, integrator, sqrt_diag_cov): +def build_kernel(logdensity_fn, integrator, inverse_mass_matrix): """MCLMC kernel""" - kernel = build_kernel_malt(logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov, L_proposal_factor = 1.25) + kernel = build_kernel_malt(logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix= inverse_mass_matrix,) def sequential_kernel(key, state, adap): - return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample) + return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample, L_proposal_factor = 1.25,) return sequential_kernel @@ -200,16 +200,16 @@ def emaus(model, num_steps1, num_steps2, num_chains, mesh, rng_key, gradient_calls_per_step= len(_integrator_coefficients) // 2 #scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. if diagonal_preconditioning: - sqrt_diag_cov= final_adaptation_state.sqrt_diag_cov + inverse_mass_matrix= jnp.sqrt(final_adaptation_state.inverse_mass_matrix) # scale the stepsize so that it reflects averag scale change of the preconditioning - average_scale_change = jnp.sqrt(jnp.average(jnp.square(sqrt_diag_cov))) + average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) final_adaptation_state = final_adaptation_state._replace(step_size= final_adaptation_state.step_size / average_scale_change) else: - sqrt_diag_cov= 1. + inverse_mass_matrix= 1. - kernel = build_kernel(model.logdensity_fn, integrator, sqrt_diag_cov= sqrt_diag_cov) + kernel = build_kernel(model.logdensity_fn, integrator, inverse_mass_matrix= inverse_mass_matrix) initial_state= HMCState(final_state.position, final_state.logdensity, final_state.logdensity_grad) num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) num_adaptation_samples = num_samples//2 # number of samples after which the stepsize is fixed. diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 73b9ae4f7..458361103 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -119,7 +119,7 @@ class History(NamedTuple): class AdaptationState(NamedTuple): L: float - sqrt_diag_cov: Any + inverse_mass_matrix: Any step_size: float step_count: int @@ -189,7 +189,7 @@ def __init__(self, num_dims, weights= jnp.zeros(r_save_num)) self.initial_state = AdaptationState(L= jnp.inf, # do not add noise for the first step - sqrt_diag_cov= jnp.ones(num_dims), + inverse_mass_matrix= jnp.ones(num_dims), step_size= 0.01 * jnp.sqrt(num_dims), step_count= 0, EEVPD=1e-3, EEVPD_wanted=1e-3, @@ -225,7 +225,7 @@ def update(self, adaptation_state, Etheta): history = History(history_observables, history_stopping, history_weights) L = self.alpha * jnp.sqrt(jnp.sum(Etheta['xsq'] - jnp.square(Etheta['x']))) # average over the ensemble, sum over parameters (to get sqrt(d)) - sqrt_diag_cov = jnp.sqrt(Etheta['xsq'] - jnp.square(Etheta['x'])) + inverse_mass_matrix = Etheta['xsq'] - jnp.square(Etheta['x']) EEVPD = (Etheta['Esq'] - jnp.square(Etheta['E'])) / self.num_dims true_bias = self.contract(Etheta['observables_for_bias']) nans = (Etheta['rejection_rate_nans'] > 0.) #| (~jnp.isfinite(eps_factor)) @@ -254,7 +254,7 @@ def update(self, adaptation_state, Etheta): 'observables': Etheta['observables']} adaptation_state_new = AdaptationState(L, - sqrt_diag_cov, + inverse_mass_matrix, adaptation_state.step_size * eps_factor, adaptation_state.step_count + 1, EEVPD, diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 9b868562c..8a5a37e55 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -37,7 +37,7 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable): def build_kernel( - num_integration_steps: int, + logdensity_fn: Callable, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, inverse_mass_matrix=1.0, @@ -66,8 +66,8 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: HMCState, - logdensity_fn: Callable, step_size: float, + num_integration_steps: int, L_proposal_factor: float = jnp.inf, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" diff --git a/blackjax/util.py b/blackjax/util.py index 8f6494167..92eb77f40 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -377,7 +377,10 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): ) # produce all random keys that will be needed + # rng_key = rng_key if not isinstance(rng_key, jnp.ndarray) else rng_key[0] + key_sampling, key_adaptation = split(rng_key) + num_steps = jnp.array(num_steps).item() keys_adaptation = split(key_adaptation, num_steps) distribute_keys = lambda key, shape: device_put(split(key, shape), NamedSharding(mesh, p)) # random keys, distributed across devices keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) From 04522f57c94cf6d5aa2bbb25f9cd14450be645da Mon Sep 17 00:00:00 2001 From: = Date: Wed, 15 Jan 2025 14:18:55 -0500 Subject: [PATCH 07/63] change order of parameters --- blackjax/mcmc/adjusted_mclmc.py | 16 ++++++++-------- tests/mcmc/test_sampling.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 9b868562c..f390402f2 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -37,7 +37,7 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable): def build_kernel( - num_integration_steps: int, + logdensity_fn: Callable, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, inverse_mass_matrix=1.0, @@ -66,8 +66,8 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: HMCState, - logdensity_fn: Callable, step_size: float, + num_integration_steps: int, L_proposal_factor: float = jnp.inf, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" @@ -140,7 +140,7 @@ def as_top_level_api( """ kernel = build_kernel( - num_integration_steps, + logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, @@ -152,11 +152,11 @@ def init_fn(position: ArrayLikeTree, rng_key=None): def update_fn(rng_key: PRNGKey, state): return kernel( - rng_key, - state, - logdensity_fn, - step_size, - L_proposal_factor, + rng_key=rng_key, + state=state, + step_size=step_size, + num_integration_steps=num_integration_steps, + L_proposal_factor=L_proposal_factor, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index d788696f8..e9068326e 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -237,13 +237,13 @@ def run_adjusted_mclmc( kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc.build_kernel( integrator=integrator, - num_integration_steps=avg_num_integration_steps, inverse_mass_matrix=inverse_mass_matrix, + logdensity_fn=logdensity_fn, )( rng_key=rng_key, state=state, step_size=step_size, - logdensity_fn=logdensity_fn, + num_integration_steps=avg_num_integration_steps, ) target_acc_rate = 0.9 From a7c99b92ee59b82db855d4b919a6254aa979a97d Mon Sep 17 00:00:00 2001 From: = Date: Thu, 16 Jan 2025 21:00:11 +0000 Subject: [PATCH 08/63] draft --- blackjax/adaptation/ensemble_mclmc.py | 352 +++++++++++++++---------- blackjax/adaptation/ensemble_umclmc.py | 319 ++++++++++++---------- blackjax/adaptation/step_size.py | 43 ++- blackjax/mcmc/mclmc.py | 4 +- blackjax/util.py | 200 ++++++++------ tests/mcmc/test_sampling.py | 102 +++++++ 6 files changed, 655 insertions(+), 365 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index b303f8ae7..c01791daa 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -11,142 +11,158 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -#"""Public API for the MCLMC Kernel""" +# """Public API for the MCLMC Kernel""" -from typing import Callable, NamedTuple, Any +from typing import Any, NamedTuple import jax import jax.numpy as jnp -from blackjax.util import run_eca -from blackjax.mcmc.integrators import generate_isokinetic_integrator, velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients -from blackjax.mcmc.hmc import HMCState -from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt import blackjax.adaptation.ensemble_umclmc as umclmc -from blackjax.adaptation.ensemble_umclmc import equipartition_diagonal, equipartition_diagonal_loss, equipartition_fullrank, equipartition_fullrank_loss - -from blackjax.adaptation.step_size import dual_averaging_adaptation, bisection_monotonic_fn +from blackjax.adaptation.ensemble_umclmc import ( + equipartition_diagonal, + equipartition_diagonal_loss, + equipartition_fullrank, + equipartition_fullrank_loss, +) +from blackjax.adaptation.step_size import bisection_monotonic_fn +from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.integrators import ( + generate_isokinetic_integrator, + mclachlan_coefficients, + omelyan_coefficients, +) +from blackjax.util import run_eca - class AdaptationState(NamedTuple): steps_per_sample: float step_size: float epsadap_state: Any sample_count: int - - -def build_kernel(logdensity_fn, integrator, inverse_mass_matrix): - """MCLMC kernel""" - - kernel = build_kernel_malt(logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix= inverse_mass_matrix,) - - def sequential_kernel(key, state, adap): - return kernel(key, state, step_size= adap.step_size, num_integration_steps= adap.steps_per_sample, L_proposal_factor = 1.25,) - - return sequential_kernel +# put the arguments of build_kernel in a suitable order +build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( + logdensity_fn=logdensity_fn, + integrator=integrator, + inverse_mass_matrix=inverse_mass_matrix, +)(rng_key=key, state=state, step_size=adap.step_size, num_integration_steps=adap.steps_per_sample, L_proposal_factor=1.25) class Adaptation: - - def __init__(self, adap_state, num_adaptation_samples, - steps_per_sample, acc_prob_target= 0.8, - observables = lambda x: 0., - observables_for_bias = lambda x: 0., contract= lambda x: 0.): - - self.num_adaptation_samples= num_adaptation_samples + def __init__( + self, + adap_state, + num_adaptation_samples, + steps_per_sample, + acc_prob_target=0.8, + observables=lambda x: 0.0, + observables_for_bias=lambda x: 0.0, + contract=lambda x: 0.0, + ): + self.num_adaptation_samples = num_adaptation_samples self.observables = observables self.observables_for_bias = observables_for_bias self.contract = contract - - ### Determine the initial hyperparameters ### - - ## stepsize ## - #if we switched to the more accurate integrator we can use longer step size - #integrator_factor = jnp.sqrt(10.) if mclachlan else 1. + + # Determine the initial hyperparameters # + + # stepsize # + # if we switched to the more accurate integrator we can use longer step size + # integrator_factor = jnp.sqrt(10.) if mclachlan else 1. # Let's use the stepsize which will be optimal for the adjusted method. The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 - # With the current eps, we had sigma^2 = EEVPD * d for N = 1. + # With the current eps, we had sigma^2 = EEVPD * d for N = 1. # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 - #adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = adap_state.step_size #* integrator_factor * adjustment_factor - - #steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) - - ### Initialize the dual averaging adaptation ### - #da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) - #epsadap_state = da_init_fn(step_size) - - ### Initialize the bisection for finding the step size + # adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adap_state.step_size # * integrator_factor * adjustment_factor + + # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) + + # Initialize the dual averaging adaptation # + # da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) + # epsadap_state = da_init_fn(step_size) + + # Initialize the bisection for finding the step size epsadap_state, self.epsadap_update = bisection_monotonic_fn(acc_prob_target) - - self.initial_state = AdaptationState(steps_per_sample, step_size, epsadap_state, 0) - - + + self.initial_state = AdaptationState( + steps_per_sample, step_size, epsadap_state, 0 + ) + def summary_statistics_fn(self, state, info, rng_key): - - return {'acceptance_probability': info.acceptance_rate, - #'inv_acceptance_probability': 1./info.acceptance_rate, - 'equipartition_diagonal': equipartition_diagonal(state), - 'equipartition_fullrank': equipartition_fullrank(state, rng_key), - 'observables': self.observables(state.position), - 'observables_for_bias': self.observables_for_bias(state.position) - } - + return { + "acceptance_probability": info.acceptance_rate, + "equipartition_diagonal": equipartition_diagonal(state), + "equipartition_fullrank": equipartition_fullrank(state, rng_key), + "observables": self.observables(state.position), + "observables_for_bias": self.observables_for_bias(state.position), + } def update(self, adaptation_state, Etheta): - # combine the expectation values to get useful scalars - acc_prob = Etheta['acceptance_probability'] - #acc_prob = 1./Etheta['inv_acceptance_probability'] - equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) - equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) - true_bias = self.contract(Etheta['observables_for_bias']) - - - info_to_be_stored = {'L': adaptation_state.step_size * adaptation_state.steps_per_sample, 'steps_per_sample': adaptation_state.steps_per_sample, 'step_size': adaptation_state.step_size, - 'acc_prob': acc_prob, - 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, - 'observables': Etheta['observables'] - } + acc_prob = Etheta["acceptance_probability"] + # acc_prob = 1./Etheta['inv_acceptance_probability'] + equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) + equi_full = equipartition_fullrank_loss(Etheta["equipartition_fullrank"]) + true_bias = self.contract(Etheta["observables_for_bias"]) + + info_to_be_stored = { + "L": adaptation_state.step_size * adaptation_state.steps_per_sample, + "steps_per_sample": adaptation_state.steps_per_sample, + "step_size": adaptation_state.step_size, + "acc_prob": acc_prob, + "equi_diag": equi_diag, + "equi_full": equi_full, + "bias": true_bias, + "observables": Etheta["observables"], + } # hyperparameter adaptation - + # Dual Averaging - # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples - + # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples + # def update(_): # da_state = self.epsadap_update(adaptation_state.epsadap_state, acc_prob) # step_size = jnp.exp(da_state.log_step_size) # return da_state, step_size - + # def dont_update(_): # da_state = adaptation_state.epsadap_state # return da_state, jnp.exp(da_state.log_step_size_avg) - + # epsadap_state, step_size = jax.lax.cond(adaptation_phase, update, dont_update, operand=None) - - # Bisection - epsadap_state, step_size = self.epsadap_update(adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob) - - return AdaptationState(adaptation_state.steps_per_sample, step_size, epsadap_state, adaptation_state.sample_count + 1), info_to_be_stored + # Bisection + epsadap_state, step_size = self.epsadap_update( + adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob + ) + + return ( + AdaptationState( + adaptation_state.steps_per_sample, + step_size, + epsadap_state, + adaptation_state.sample_count + 1, + ), + info_to_be_stored, + ) def bias(model): """should be transfered to benchmarks/""" - + def observables(position): return jnp.square(model.transform(position)) - + def contract(sampler_E_x2): bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 return jnp.array([jnp.max(bsq), jnp.average(bsq)]) - - return observables, contract + return observables, contract def while_steps_num(cond): @@ -156,69 +172,141 @@ def while_steps_num(cond): return jnp.argmin(cond) + 1 -def emaus(model, num_steps1, num_steps2, num_chains, mesh, rng_key, - alpha= 1.9, bias_type= 0, save_frac= 0.2, C= 0.1, power= 3./8., early_stop= True, r_end= 5e-3,# stage1 parameters - diagonal_preconditioning= True, integrator_coefficients= None, steps_per_sample= 10, acc_prob= None, - observables = lambda x: None, - ensemble_observables= None - ): - +def emaus( + model, + num_steps1, + num_steps2, + num_chains, + mesh, + rng_key, + alpha=1.9, + bias_type=0, + save_frac=0.2, + C=0.1, + power=3.0 / 8.0, + early_stop=True, + r_end=5e-3, # stage1 parameters + diagonal_preconditioning=True, + integrator_coefficients=None, + steps_per_sample=10, + acc_prob=None, + observables=lambda x: None, + ensemble_observables=None, +): observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) - + # initialize the chains - initial_state = umclmc.initialize(key_init, model.logdensity_fn, model.sample_init, num_chains, mesh) - - ### burn-in with the unadjusted method ### + initial_state = umclmc.initialize( + key_init, model.logdensity_fn, model.sample_init, num_chains, mesh + ) + + # burn-in with the unadjusted method # kernel = umclmc.build_kernel(model.logdensity_fn) - save_num= (int)(jnp.rint(save_frac * num_steps1)) - adap = umclmc.Adaptation(model.ndims, alpha= alpha, bias_type= bias_type, save_num= save_num, C=C, power= power, r_end = r_end, - observables= observables, observables_for_bias= observables_for_bias, contract= contract) - final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps1, num_chains, mesh, ensemble_observables) - - if early_stop: # here I am cheating a bit, because I am not sure if it is possible to do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. - - num_steps_while = while_steps_num((info1[0] if ensemble_observables != None else info1)['while_cond']) - #print(num_steps_while, save_num) - final_state, final_adaptation_state, info1 = run_eca(key_umclmc, initial_state, kernel, adap, num_steps_while, num_chains, mesh, ensemble_observables) - - ### refine the results with the adjusted method ### + save_num = (int)(jnp.rint(save_frac * num_steps1)) + adap = umclmc.Adaptation( + model.ndims, + alpha=alpha, + bias_type=bias_type, + save_num=save_num, + C=C, + power=power, + r_end=r_end, + observables=observables, + observables_for_bias=observables_for_bias, + contract=contract, + ) + final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + num_steps1, + num_chains, + mesh, + ensemble_observables, + ) + + if ( + early_stop + ): # here I am cheating a bit, because I am not sure if it is possible to do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. + num_steps_while = while_steps_num( + (info1[0] if ensemble_observables is not None else info1)["while_cond"] + ) + # print(num_steps_while, save_num) + final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + num_steps_while, + num_chains, + mesh, + ensemble_observables, + ) + + # refine the results with the adjusted method # _acc_prob = acc_prob - if integrator_coefficients == None: + if integrator_coefficients is None: high_dims = model.ndims > 200 - _integrator_coefficients = omelyan_coefficients if high_dims else mclachlan_coefficients - if acc_prob == None: + _integrator_coefficients = ( + omelyan_coefficients if high_dims else mclachlan_coefficients + ) + if acc_prob is None: _acc_prob = 0.9 if high_dims else 0.7 - + else: _integrator_coefficients = integrator_coefficients - if acc_prob == None: + if acc_prob is None: _acc_prob = 0.9 - - + integrator = generate_isokinetic_integrator(_integrator_coefficients) - gradient_calls_per_step= len(_integrator_coefficients) // 2 #scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. + gradient_calls_per_step = ( + len(_integrator_coefficients) // 2 + ) # scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. if diagonal_preconditioning: - inverse_mass_matrix= jnp.sqrt(final_adaptation_state.inverse_mass_matrix) - + inverse_mass_matrix = jnp.sqrt(final_adaptation_state.inverse_mass_matrix) + # scale the stepsize so that it reflects averag scale change of the preconditioning average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) - final_adaptation_state = final_adaptation_state._replace(step_size= final_adaptation_state.step_size / average_scale_change) + final_adaptation_state = final_adaptation_state._replace( + step_size=final_adaptation_state.step_size / average_scale_change + ) else: - inverse_mass_matrix= 1. - - kernel = build_kernel(model.logdensity_fn, integrator, inverse_mass_matrix= inverse_mass_matrix) - initial_state= HMCState(final_state.position, final_state.logdensity, final_state.logdensity_grad) + inverse_mass_matrix = 1.0 + + kernel = build_kernel( + model.logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix + ) + initial_state = HMCState( + final_state.position, final_state.logdensity, final_state.logdensity_grad + ) num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) - num_adaptation_samples = num_samples//2 # number of samples after which the stepsize is fixed. - - adap = Adaptation(final_adaptation_state, num_adaptation_samples, steps_per_sample, _acc_prob, - observables= observables, observables_for_bias= observables_for_bias, contract= contract) - - final_state, final_adaptation_state, info2 = run_eca(key_mclmc, initial_state, kernel, adap, num_samples, num_chains, mesh, ensemble_observables) - + num_adaptation_samples = ( + num_samples // 2 + ) # number of samples after which the stepsize is fixed. + + adap = Adaptation( + final_adaptation_state, + num_adaptation_samples, + steps_per_sample, + _acc_prob, + observables=observables, + observables_for_bias=observables_for_bias, + contract=contract, + ) + + final_state, final_adaptation_state, info2 = run_eca( + key_mclmc, + initial_state, + kernel, + adap, + num_samples, + num_chains, + mesh, + ensemble_observables, + ) + return info1, info2, gradient_calls_per_step, _acc_prob - - \ No newline at end of file diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 458361103..7e2390def 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -1,4 +1,3 @@ - # Copyright 2020- The Blackjax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,22 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -#"""Public API for the MCLMC Kernel""" +# """Public API for the MCLMC Kernel""" + +from typing import Any, NamedTuple import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree -from typing import Callable, NamedTuple, Any -from blackjax.mcmc.integrators import IntegratorState, isokinetic_velocity_verlet -from blackjax.types import Array, ArrayLike -from blackjax.util import pytree_size from blackjax.mcmc import mclmc -from blackjax.mcmc.integrators import _normalized_flatten_array +from blackjax.mcmc.integrators import ( + IntegratorState, + _normalized_flatten_array, + isokinetic_velocity_verlet, +) +from blackjax.types import Array from blackjax.util import ensemble_execute_fn - def no_nans(a): flat_a, unravel_fn = ravel_pytree(a) return jnp.all(jnp.isfinite(flat_a)) @@ -35,79 +36,96 @@ def no_nans(a): def nan_reject(nonans, old, new): """Equivalent to - return new if nonans else old""" - - return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None) + return new if nonans else old""" + return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None) def build_kernel(logdensity_fn): """MCLMC kernel (with nan rejection)""" - - kernel = mclmc.build_kernel(logdensity_fn= logdensity_fn, integrator= isokinetic_velocity_verlet) - - + + kernel = mclmc.build_kernel( + logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet + ) + def sequential_kernel(key, state, adap): - new_state, info = kernel(key, state, adap.L, adap.step_size) - + # reject the new state if there were nans nonans = no_nans(new_state) new_state = nan_reject(nonans, state, new_state) - - return new_state, {'nans': 1-nonans, 'energy_change': info.energy_change * nonans, 'logdensity': info.logdensity * nonans} - + return new_state, { + "nans": 1 - nonans, + "energy_change": info.energy_change * nonans, + "logdensity": info.logdensity * nonans, + } + return sequential_kernel - def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): """initialize the chains based on the equipartition of the initial condition. - We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. + We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. """ - + def sequential_init(key, x, args): """initialize the position using sample_init and the velocity along the gradient""" position = sample_init(key) logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) flat_g, unravel_fn = ravel_pytree(logdensity_grad) - velocity = unravel_fn(_normalized_flatten_array(flat_g)[0]) # = grad logp/ |grad logp| + velocity = unravel_fn( + _normalized_flatten_array(flat_g)[0] + ) # = grad logp/ |grad logp| return IntegratorState(position, velocity, logdensity, logdensity_grad), None - + def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" return -state.position * state.logdensity_grad - + def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" - velocity = jax.tree_util.tree_map(lambda sign, u: sign * u, signs, state.momentum) - return IntegratorState(state.position, velocity, state.logdensity, state.logdensity_grad), None - - key1, key2= jax.random.split(rng_key) - initial_state, equipartition = ensemble_execute_fn(sequential_init, key1, num_chains, mesh, summary_statistics_fn= summary_statistics_fn) - signs = -2. * (equipartition < 1.) + 1. - initial_state, _ = ensemble_execute_fn(ensemble_init, key2, num_chains, mesh, x= initial_state, args= signs) - + velocity = jax.tree_util.tree_map( + lambda sign, u: sign * u, signs, state.momentum + ) + return ( + IntegratorState( + state.position, velocity, state.logdensity, state.logdensity_grad + ), + None, + ) + + key1, key2 = jax.random.split(rng_key) + initial_state, equipartition = ensemble_execute_fn( + sequential_init, + key1, + num_chains, + mesh, + summary_statistics_fn=summary_statistics_fn, + ) + signs = -2.0 * (equipartition < 1.0) + 1.0 + initial_state, _ = ensemble_execute_fn( + ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs + ) + return initial_state - - + + def update_history(new_vals, history): return jnp.concatenate((new_vals[None, :], history[:-1])) + def update_history_scalar(new_val, history): return jnp.concatenate((new_val * jnp.ones(1), history[:-1])) + def contract_history(theta, weights): - - square_average = jnp.square(jnp.average(theta, weights= weights, axis= 0)) - average_square = jnp.average(jnp.square(theta), weights= weights, axis= 0) - + square_average = jnp.square(jnp.average(theta, weights=weights, axis=0)) + average_square = jnp.average(jnp.square(theta), weights=weights, axis=0) + r = (average_square - square_average) / square_average - - return jnp.array([jnp.max(r), - jnp.average(r) - ]) + + return jnp.array([jnp.max(r), jnp.average(r)]) class History(NamedTuple): @@ -117,44 +135,44 @@ class History(NamedTuple): class AdaptationState(NamedTuple): - L: float inverse_mass_matrix: Any step_size: float - + step_count: int EEVPD: float EEVPD_wanted: float - history: Any - + history: Any + def equipartition_diagonal(state): - """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged. + """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged. equipartition_loss = average over parameters (Ei)""" - return jax.tree_util.tree_map(lambda x, g: -x * g, state.position, state.logdensity_grad) - + return jax.tree_util.tree_map( + lambda x, g: -x * g, state.position, state.logdensity_grad + ) def equipartition_fullrank(state, rng_key): """loss = Tr[(1 - E)^T (1 - E)] / d^2 - where Eij = is the equipartition patrix. - Loss is computed with the Hutchinson's trick.""" + where Eij = is the equipartition patrix. + Loss is computed with the Hutchinson's trick.""" x, unravel_fn = ravel_pytree(state.position) g, unravel_fn = ravel_pytree(state.logdensity_grad) d = len(x) - + def func(z): """z here has the same shape as position""" - return (z + jnp.dot(z, g) * x) + return z + jnp.dot(z, g) * x - z = jax.random.rademacher(rng_key, (100, d)) # = delta_ij + z = jax.random.rademacher(rng_key, (100, d)) # = delta_ij return jax.vmap(func)(z) def equipartition_diagonal_loss(Eii): Eii_flat, unravel_fn = ravel_pytree(Eii) - return jnp.average(jnp.square(1.- Eii_flat)) + return jnp.average(jnp.square(1.0 - Eii_flat)) def equipartition_fullrank_loss(delta_z): @@ -163,103 +181,138 @@ def equipartition_fullrank_loss(delta_z): class Adaptation: - - def __init__(self, num_dims, - alpha= 1., C= 0.1, power = 3./8., r_end= 0.01, - bias_type= 0, save_num = 10, - observables= lambda x: 0., observables_for_bias= lambda x: 0., contract= lambda x: 0. - ): - + def __init__( + self, + num_dims, + alpha=1.0, + C=0.1, + power=3.0 / 8.0, + r_end=0.01, + bias_type=0, + save_num=10, + observables=lambda x: 0.0, + observables_for_bias=lambda x: 0.0, + contract=lambda x: 0.0, + ): self.num_dims = num_dims self.alpha = alpha self.C = C self.power = power self.r_end = r_end self.observables = observables - self.observables_for_bias = observables_for_bias + self.observables_for_bias = observables_for_bias self.contract = contract self.bias_type = bias_type self.save_num = save_num - #sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) - + # sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) + r_save_num = save_num - - history = History(observables= jnp.zeros((r_save_num, num_dims)), - stopping= jnp.full((save_num,), jnp.nan), - weights= jnp.zeros(r_save_num)) - - self.initial_state = AdaptationState(L= jnp.inf, # do not add noise for the first step - inverse_mass_matrix= jnp.ones(num_dims), - step_size= 0.01 * jnp.sqrt(num_dims), - step_count= 0, - EEVPD=1e-3, EEVPD_wanted=1e-3, - history=history) - - + + history = History( + observables=jnp.zeros((r_save_num, num_dims)), + stopping=jnp.full((save_num,), jnp.nan), + weights=jnp.zeros(r_save_num), + ) + + self.initial_state = AdaptationState( + L=jnp.inf, # do not add noise for the first step + inverse_mass_matrix=jnp.ones(num_dims), + step_size=0.01 * jnp.sqrt(num_dims), + step_count=0, + EEVPD=1e-3, + EEVPD_wanted=1e-3, + history=history, + ) + def summary_statistics_fn(self, state, info, rng_key): - position_flat, unravel_fn = ravel_pytree(state.position) - - return {'equipartition_diagonal': equipartition_diagonal(state), - 'equipartition_fullrank': equipartition_fullrank(state, rng_key), - 'x': position_flat, 'xsq': jnp.square(position_flat), - 'E': info['energy_change'], 'Esq': jnp.square(info['energy_change']), - 'rejection_rate_nans': info['nans'], - 'observables_for_bias': self.observables_for_bias(state.position), - 'observables': self.observables(state.position), - 'entropy': - info['logdensity'] - } - - + + return { + "equipartition_diagonal": equipartition_diagonal(state), + "equipartition_fullrank": equipartition_fullrank(state, rng_key), + "x": position_flat, + "xsq": jnp.square(position_flat), + "E": info["energy_change"], + "Esq": jnp.square(info["energy_change"]), + "rejection_rate_nans": info["nans"], + "observables_for_bias": self.observables_for_bias(state.position), + "observables": self.observables(state.position), + "entropy": -info["logdensity"], + } + def update(self, adaptation_state, Etheta): - # combine the expectation values to get useful scalars - equi_diag = equipartition_diagonal_loss(Etheta['equipartition_diagonal']) - equi_full = equipartition_fullrank_loss(Etheta['equipartition_fullrank']) - - history_observables = update_history(Etheta['observables_for_bias'], adaptation_state.history.observables) - history_weights = update_history_scalar(1., adaptation_state.history.weights) + equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) + equi_full = equipartition_fullrank_loss(Etheta["equipartition_fullrank"]) + + history_observables = update_history( + Etheta["observables_for_bias"], adaptation_state.history.observables + ) + history_weights = update_history_scalar(1.0, adaptation_state.history.weights) fluctuations = contract_history(history_observables, history_weights) - history_stopping = update_history_scalar(jax.lax.cond(adaptation_state.step_count > len(history_weights), lambda _: fluctuations[0], lambda _: jnp.nan, operand=None), - adaptation_state.history.stopping) + history_stopping = update_history_scalar( + jax.lax.cond( + adaptation_state.step_count > len(history_weights), + lambda _: fluctuations[0], + lambda _: jnp.nan, + operand=None, + ), + adaptation_state.history.stopping, + ) history = History(history_observables, history_stopping, history_weights) - - L = self.alpha * jnp.sqrt(jnp.sum(Etheta['xsq'] - jnp.square(Etheta['x']))) # average over the ensemble, sum over parameters (to get sqrt(d)) - inverse_mass_matrix = Etheta['xsq'] - jnp.square(Etheta['x']) - EEVPD = (Etheta['Esq'] - jnp.square(Etheta['E'])) / self.num_dims - true_bias = self.contract(Etheta['observables_for_bias']) - nans = (Etheta['rejection_rate_nans'] > 0.) #| (~jnp.isfinite(eps_factor)) + + L = self.alpha * jnp.sqrt( + jnp.sum(Etheta["xsq"] - jnp.square(Etheta["x"])) + ) # average over the ensemble, sum over parameters (to get sqrt(d)) + inverse_mass_matrix = Etheta["xsq"] - jnp.square(Etheta["x"]) + EEVPD = (Etheta["Esq"] - jnp.square(Etheta["E"])) / self.num_dims + true_bias = self.contract(Etheta["observables_for_bias"]) + nans = Etheta["rejection_rate_nans"] > 0.0 # | (~jnp.isfinite(eps_factor)) # hyperparameter adaptation # estimate bias - bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[self.bias_type] # r_max, r_avg, equi_full, equi_diag + bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[ + self.bias_type + ] # r_max, r_avg, equi_full, equi_diag EEVPD_wanted = self.C * jnp.power(bias, self.power) - - eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1./6.) - eps_factor = jnp.clip(eps_factor, 0.3, 3.) - - eps_factor = nan_reject(1-nans, 0.5, eps_factor) # reduce the stepsize if there were nans + eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1.0 / 6.0) + eps_factor = jnp.clip(eps_factor, 0.3, 3.0) + + eps_factor = nan_reject( + 1 - nans, 0.5, eps_factor + ) # reduce the stepsize if there were nans # determine if we want to finish this stage (i.e. if loss is no longer decreassing) - #increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). Do not be tempted to simply change to while_cond = history[0] < history[-1] - #while_cond = ~increasing - while_cond = (fluctuations[0] > self.r_end) | (adaptation_state.step_count < self.save_num) - - info_to_be_stored = {'L': adaptation_state.L, 'step_size': adaptation_state.step_size, - 'EEVPD_wanted': EEVPD_wanted, 'EEVPD': EEVPD, - 'equi_diag': equi_diag, 'equi_full': equi_full, 'bias': true_bias, - 'r_max': fluctuations[0], 'r_avg': fluctuations[1], - 'while_cond': while_cond, 'entropy': Etheta['entropy'], - 'observables': Etheta['observables']} - - adaptation_state_new = AdaptationState(L, - inverse_mass_matrix, - adaptation_state.step_size * eps_factor, - adaptation_state.step_count + 1, - EEVPD, - EEVPD_wanted, - history) - + # increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). Do not be tempted to simply change to while_cond = history[0] < history[-1] + # while_cond = ~increasing + while_cond = (fluctuations[0] > self.r_end) | ( + adaptation_state.step_count < self.save_num + ) + + info_to_be_stored = { + "L": adaptation_state.L, + "step_size": adaptation_state.step_size, + "EEVPD_wanted": EEVPD_wanted, + "EEVPD": EEVPD, + "equi_diag": equi_diag, + "equi_full": equi_full, + "bias": true_bias, + "r_max": fluctuations[0], + "r_avg": fluctuations[1], + "while_cond": while_cond, + "entropy": Etheta["entropy"], + "observables": Etheta["observables"], + } + + adaptation_state_new = AdaptationState( + L, + inverse_mass_matrix, + adaptation_state.step_size * eps_factor, + adaptation_state.step_count + 1, + EEVPD, + EEVPD_wanted, + history, + ) + return adaptation_state_new, info_to_be_stored - diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 7a076b962..94c634ce3 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -258,45 +258,44 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: return rss_state.step_size -def bisection_monotonic_fn(acc_prob_wanted, reduce_shift = jnp.log(2.), tolerance= 0.03): + +def bisection_monotonic_fn(acc_prob_wanted, reduce_shift=jnp.log(2.0), tolerance=0.03): """Bisection of a monotonically decreassing function, that doesn't require an initially bracketing interval.""" - + def update(state, exp_x, acc_rate_new): - bounds, terminated = state - + # update the bounds acc_high = acc_rate_new > acc_prob_wanted x = jnp.log(exp_x) - + def on_true(bounds): bounds0 = jnp.max(jnp.array([bounds[0], x])) return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift - + def on_false(bounds): bounds1 = jnp.min(jnp.array([bounds[1], x])) return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift - - - bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) - - + + bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) + # if we have already found a bracketing interval, do bisection, otherwise further reduce or increase the bounds bracketing = jnp.all(jnp.isfinite(bounds_new)) - - def reduce(bounds): + + def reduce(bounds): return x_new def bisect(bounds): return jnp.average(bounds) - + x_new = jax.lax.cond(bracketing, bisect, reduce, bounds_new) - - stepsize = terminated * exp_x + (1-terminated) * jnp.exp(x_new) - - terminated_new = (jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance) | terminated - + + stepsize = terminated * exp_x + (1 - terminated) * jnp.exp(x_new) + + terminated_new = ( + jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance + ) | terminated + return (bounds_new, terminated_new), stepsize - - - return (jnp.array([-jnp.inf, jnp.inf]), False), update \ No newline at end of file + + return (jnp.array([-jnp.inf, jnp.inf]), False), update diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e5cc46213..2299dc68e 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -60,7 +60,9 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, inverse_mass_matrix=1.0, integrator=isokinetic_mclachlan): +def build_kernel( + logdensity_fn, inverse_mass_matrix=1.0, integrator=isokinetic_mclachlan +): """Build a HMC kernel. Parameters diff --git a/blackjax/util.py b/blackjax/util.py index 92eb77f40..e8c42f11f 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -4,14 +4,13 @@ from typing import Callable, Union import jax.numpy as jnp -from jax import jit, lax, device_put, vmap +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map from jax.flatten_util import ravel_pytree from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map -from jax.sharding import Mesh, PartitionSpec, NamedSharding -from jax.experimental.shard_map import shard_map - from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -318,123 +317,170 @@ def incremental_value_update( total += weight return total, average -def eca_step(kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info= None): +def eca_step( + kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None +): def _step(state_all, xs): """This function operates on a single device.""" - state, adaptation_state = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. - _, keys_sampling, key_adaptation = xs # keys_sampling.shape = (chains_per_device, ) - + ( + state, + adaptation_state, + ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + ( + _, + keys_sampling, + key_adaptation, + ) = xs # keys_sampling.shape = (chains_per_device, ) + # update the state of all chains on this device state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - + # combine all the chains to compute expectation values theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) - Etheta = tree_map(lambda theta: lax.psum(jnp.sum(theta, axis= 0), axis_name= 'chains') / num_chains, theta) + Etheta = tree_map( + lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") + / num_chains, + theta, + ) # use these to adapt the hyperparameters of the dynamics - adaptation_state, info_to_be_stored = adaptation_update(adaptation_state, Etheta) - + adaptation_state, info_to_be_stored = adaptation_update( + adaptation_state, Etheta + ) + return (state, adaptation_state), info_to_be_stored - - - if ensemble_info != None: - + + if ensemble_info is not None: + def step(state_all, xs): (state, adaptation_state), info_to_be_stored = _step(state_all, xs) - return (state, adaptation_state), (info_to_be_stored, vmap(ensemble_info)(state.position)) - + return (state, adaptation_state), ( + info_to_be_stored, + vmap(ensemble_info)(state.position), + ) + return step else: return _step -def run_eca(rng_key, initial_state, kernel, adaptation, num_steps, num_chains, mesh, ensemble_info= None): - - step = eca_step(kernel, adaptation.summary_statistics_fn, adaptation.update, num_chains, ensemble_info) - +def run_eca( + rng_key, + initial_state, + kernel, + adaptation, + num_steps, + num_chains, + mesh, + ensemble_info=None, +): + step = eca_step( + kernel, + adaptation.summary_statistics_fn, + adaptation.update, + num_chains, + ensemble_info, + ) def all_steps(initial_state, keys_sampling, keys_adaptation): """This function operates on a single device. key is a random key for this device.""" - + initial_state_all = (initial_state, adaptation.initial_state) - + # run sampling - xs = (jnp.arange(num_steps), keys_sampling.T, keys_adaptation) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - + xs = ( + jnp.arange(num_steps), + keys_sampling.T, + keys_adaptation, + ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + final_state_all, info_history = lax.scan(step, initial_state_all, xs) final_state, final_adaptation_state = final_state_all - return final_state, final_adaptation_state, info_history # info history is composed of averages over all chains, so it is a couple of scalars - + return ( + final_state, + final_adaptation_state, + info_history, + ) # info history is composed of averages over all chains, so it is a couple of scalars + + p, pscalar = PartitionSpec("chains"), PartitionSpec() + parallel_execute = shard_map( + all_steps, + mesh=mesh, + in_specs=(p, p, pscalar), + out_specs=(p, pscalar, pscalar), + check_rep=False, + ) - p, pscalar = PartitionSpec('chains'), PartitionSpec() - parallel_execute = shard_map(all_steps, - mesh= mesh, - in_specs= (p, p, pscalar), - out_specs= (p, pscalar, pscalar), - check_rep=False - ) - # produce all random keys that will be needed # rng_key = rng_key if not isinstance(rng_key, jnp.ndarray) else rng_key[0] key_sampling, key_adaptation = split(rng_key) num_steps = jnp.array(num_steps).item() keys_adaptation = split(key_adaptation, num_steps) - distribute_keys = lambda key, shape: device_put(split(key, shape), NamedSharding(mesh, p)) # random keys, distributed across devices + distribute_keys = lambda key, shape: device_put( + split(key, shape), NamedSharding(mesh, p) + ) # random keys, distributed across devices keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) - # run sampling in parallel - final_state, final_adaptation_state, info_history = parallel_execute(initial_state, keys_sampling, keys_adaptation) - - return final_state, final_adaptation_state, info_history + final_state, final_adaptation_state, info_history = parallel_execute( + initial_state, keys_sampling, keys_adaptation + ) + return final_state, final_adaptation_state, info_history +def ensemble_execute_fn( + func, + rng_key, + num_chains, + mesh, + x=None, + args=None, + summary_statistics_fn=lambda y: 0.0, +): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. -def ensemble_execute_fn(func, rng_key, num_chains, mesh, - x= None, - args= None, - summary_statistics_fn= lambda y: 0., - ): - """Given a sequential function - func(rng_key, x, args) = y, - evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. - Args: - x: array distributed over all decvices - args: additional arguments for func, not distributed. - summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. - rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. - - Returns: - y: array distributed over all decvices. Need not be of the same shape as x. - Etheta: expected values of the summary statistics + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics """ - p, pscalar = PartitionSpec('chains'), PartitionSpec() - - if x == None: - X= device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) - else: - X= x - - adaptation_update= lambda _, Etheta: (Etheta, None) - - _F = eca_step(func, lambda y, info, key: summary_statistics_fn(y), adaptation_update, num_chains) + p, pscalar = PartitionSpec("chains"), PartitionSpec() + + if x is None: + X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X = x + + adaptation_update = lambda _, Etheta: (Etheta, None) + + _F = eca_step( + func, + lambda y, info, key: summary_statistics_fn(y), + adaptation_update, + num_chains, + ) def F(x, keys): """This function operates on a single device. key is a random key for this device.""" y, summary_statistics = _F((x, args), (None, keys, None))[0] return y, summary_statistics - parallel_execute = shard_map(F, - mesh= mesh, - in_specs= (p, p), - out_specs= (p, pscalar), - check_rep=False - ) - - keys = device_put(split(rng_key, num_chains), NamedSharding(mesh, p)) # random keys, distributed across devices + parallel_execute = shard_map( + F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False + ) + + keys = device_put( + split(rng_key, num_chains), NamedSharding(mesh, p) + ) # random keys, distributed across devices # apply F in parallel - return parallel_execute(X, keys) \ No newline at end of file + return parallel_execute(X, keys) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index e9068326e..5d3dece82 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -2,6 +2,7 @@ import functools import itertools +from blackjax.adaptation.ensemble_mclmc import emaus import chex import jax import jax.numpy as jnp @@ -284,6 +285,31 @@ def run_adjusted_mclmc( ) return out + + def run_emaus( + self, + initial_position, + logdensity_fn, + key, + num_steps, + diagonal_preconditioning, + ): + + mesh = jax.sharding.Mesh(jax.devices(), 'chains') + + from blackjax.mcmc.integrators import velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients + + + integrator_coefficients = mclachlan_coefficients + + info1, info2, grads_per_step, _acc_prob = emaus(logdensity_fn, num_steps1=1000, num_steps2=3000, num_chains=4000, mesh=mesh, rng_key=key, + alpha = 1.9, bias_type= 3, C= 0.1, power= 3./8., + early_stop=1, r_end= 1e-2, diagonal_preconditioning= diagonal_preconditioning, integrator_coefficients= integrator_coefficients, + steps_per_sample= 15, acc_prob= None, ensemble_observables= lambda x: x + #ensemble_observables = lambda x: vec @ x + ) # run the algorithm + + return info2[1].reshape(info2[1].shape[0]*info2[1].shape[1], info2[1].shape[2]) @parameterized.parameters( itertools.product( @@ -457,6 +483,41 @@ def test_adjusted_mclmc(self, diagonal_preconditioning): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + + # TODO: add preconditioning + def test_emaus(self,): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_(**x) + + model = Banana() + + states = self.run_emaus( + initial_position={"coefs": 1.0, "log_scale": 1.0}, + logdensity_fn=model, + key=inference_key, + num_steps=10000, + diagonal_preconditioning=True, + ) + + # coefs_samples = states["coefs"][3000:] + # scale_samples = np.exp(states["log_scale"][3000:]) + + # samples = states[3000:] + + print((states**2).mean(axis=0), Banana().E_x2) + + np.testing.assert_allclose((states**2).mean(axis=0), Banana().E_x2, atol=1e-2) + + # np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + # np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) def test_mclmc_preconditioning(self): class IllConditionedGaussian: @@ -1223,5 +1284,46 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): ) + +#TODO: remove +class Banana(): + """Banana target fromm the Inference Gym""" + + def __init__(self, initialization= 'wide'): + self.name = 'Banana' + self.ndims = 2 + self.curvature = 0.03 + + self.transform = lambda x: x + self.E_x2 = jnp.array([100.0, 19.0]) #the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. + self.Var_x2 = jnp.array([20000.0, 4600.898]) + + if initialization == 'map': + self.sample_init = lambda key: jnp.array([0, -100.0 * self.curvature]) + elif initialization == 'posterior': + self.sample_init = lambda key: self.posterior_draw(key) + elif initialization == 'wide': + self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.array([10.0, 5.0]) * 2 + else: + raise ValueError('initialization = '+initialization +' is not a valid option.') + + def logdensity_fn(self, x): + mu2 = self.curvature * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + + def posterior_draw(self, key): + z = jax.random.normal(key, shape = (2, )) + x0 = 10.0 * z[0] + x1 = self.curvature * (x0 ** 2 - 100) + z[1] + return jnp.array([x0, x1]) + + def ground_truth(self): + x = jax.vmap(self.posterior_draw)(jax.random.split(jax.random.PRNGKey(0), 100000000)) + print(jnp.average(x, axis=0)) + print(jnp.average(jnp.square(x), axis=0)) + print(jnp.std(jnp.square(x[:, 0])) ** 2, jnp.std(jnp.square(x[:, 1])) ** 2) + + if __name__ == "__main__": absltest.main() + From 6972f235b5dbb553ff38276ac596fdc645a762d5 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 3 Feb 2025 13:44:45 -0500 Subject: [PATCH 09/63] mid cleanup --- blackjax/adaptation/ensemble_mclmc.py | 107 ++++++++++++------------ blackjax/adaptation/ensemble_umclmc.py | 2 + blackjax/mcmc/metrics.py | 9 +- blackjax/optimizers/lbfgs.py | 4 +- blackjax/sgmcmc/csgld.py | 1 + blackjax/sgmcmc/sgnht.py | 1 + blackjax/smc/tuning/from_kernel_info.py | 1 + blackjax/smc/tuning/from_particles.py | 1 + tests/adaptation/test_mass_matrix.py | 1 + tests/mcmc/test_sampling.py | 107 +++++++++++++++--------- tests/mcmc/test_trajectory.py | 1 + tests/mcmc/test_uturn.py | 1 + tests/optimizers/test_optimizers.py | 1 + tests/optimizers/test_pathfinder.py | 1 + tests/smc/test_resampling.py | 1 + tests/smc/test_smc.py | 1 + tests/smc/test_smc_ess.py | 1 + tests/smc/test_solver.py | 1 + tests/smc/test_tempered_smc.py | 9 +- tests/test_benchmarks.py | 1 + tests/test_compilation.py | 1 + tests/test_diagnostics.py | 1 + 22 files changed, 149 insertions(+), 105 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index c01791daa..786d55195 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -22,8 +22,6 @@ from blackjax.adaptation.ensemble_umclmc import ( equipartition_diagonal, equipartition_diagonal_loss, - equipartition_fullrank, - equipartition_fullrank_loss, ) from blackjax.adaptation.step_size import bisection_monotonic_fn from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt @@ -37,30 +35,38 @@ class AdaptationState(NamedTuple): + steps_per_sample: float step_size: float - epsadap_state: Any - sample_count: int + stepsize_adaptation_state: ( + Any # the state of the bisection algorithm to find a stepsize + ) + iteration: int + -# put the arguments of build_kernel in a suitable order build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( logdensity_fn=logdensity_fn, integrator=integrator, inverse_mass_matrix=inverse_mass_matrix, -)(rng_key=key, state=state, step_size=adap.step_size, num_integration_steps=adap.steps_per_sample, L_proposal_factor=1.25) - +)( + rng_key=key, + state=state, + step_size=adap.step_size, + num_integration_steps=adap.steps_per_sample, + L_proposal_factor=1.25, +) class Adaptation: def __init__( self, - adap_state, - num_adaptation_samples, - steps_per_sample, + adaptation_state, + num_adaptation_samples, # amount of tuning in the adjusted phase before fixing params + steps_per_sample, # L/eps (same for each chain: currently fixed to 15) acc_prob_target=0.8, - observables=lambda x: 0.0, - observables_for_bias=lambda x: 0.0, - contract=lambda x: 0.0, + observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep + observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains + contract=lambda x: 0.0, # just for diagnostics: observabiels for bias, contracted over dimensions ): self.num_adaptation_samples = num_adaptation_samples self.observables = observables @@ -76,27 +82,30 @@ def __init__( # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 # With the current eps, we had sigma^2 = EEVPD * d for N = 1. # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 - # adjustment_factor = jnp.power(0.82 / (num_dims * adap_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = adap_state.step_size # * integrator_factor * adjustment_factor + # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adaptation_state.step_size # * integrator_factor * adjustment_factor # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) # Initialize the dual averaging adaptation # # da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) - # epsadap_state = da_init_fn(step_size) + # stepsize_adaptation_state = da_init_fn(step_size) # Initialize the bisection for finding the step size - epsadap_state, self.epsadap_update = bisection_monotonic_fn(acc_prob_target) + stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn( + acc_prob_target + ) self.initial_state = AdaptationState( - steps_per_sample, step_size, epsadap_state, 0 + steps_per_sample, step_size, stepsize_adaptation_state, 0 ) def summary_statistics_fn(self, state, info, rng_key): return { "acceptance_probability": info.acceptance_rate, - "equipartition_diagonal": equipartition_diagonal(state), - "equipartition_fullrank": equipartition_fullrank(state, rng_key), + "equipartition_diagonal": equipartition_diagonal( + state + ), # metric for bias: equipartition theorem gives todo... "observables": self.observables(state.position), "observables_for_bias": self.observables_for_bias(state.position), } @@ -106,8 +115,7 @@ def update(self, adaptation_state, Etheta): acc_prob = Etheta["acceptance_probability"] # acc_prob = 1./Etheta['inv_acceptance_probability'] equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) - equi_full = equipartition_fullrank_loss(Etheta["equipartition_fullrank"]) - true_bias = self.contract(Etheta["observables_for_bias"]) + true_bias = self.contract(Etheta["observables_for_bias"]) # remove info_to_be_stored = { "L": adaptation_state.step_size * adaptation_state.steps_per_sample, @@ -115,38 +123,23 @@ def update(self, adaptation_state, Etheta): "step_size": adaptation_state.step_size, "acc_prob": acc_prob, "equi_diag": equi_diag, - "equi_full": equi_full, "bias": true_bias, "observables": Etheta["observables"], } - # hyperparameter adaptation - - # Dual Averaging - # adaptation_phase = adaptation_state.sample_count < self.num_adaptation_samples - - # def update(_): - # da_state = self.epsadap_update(adaptation_state.epsadap_state, acc_prob) - # step_size = jnp.exp(da_state.log_step_size) - # return da_state, step_size - - # def dont_update(_): - # da_state = adaptation_state.epsadap_state - # return da_state, jnp.exp(da_state.log_step_size_avg) - - # epsadap_state, step_size = jax.lax.cond(adaptation_phase, update, dont_update, operand=None) - - # Bisection - epsadap_state, step_size = self.epsadap_update( - adaptation_state.epsadap_state, adaptation_state.step_size, acc_prob + # Bisection to find step size + stepsize_adaptation_state, step_size = self.epsadap_update( + adaptation_state.stepsize_adaptation_state, + adaptation_state.step_size, + acc_prob, ) return ( AdaptationState( adaptation_state.steps_per_sample, step_size, - epsadap_state, - adaptation_state.sample_count + 1, + stepsize_adaptation_state, + adaptation_state.iteration + 1, ), info_to_be_stored, ) @@ -174,24 +167,25 @@ def while_steps_num(cond): def emaus( model, - num_steps1, - num_steps2, + num_steps1, # max number in phase 1 + num_steps2, # fixed number in phase 2 num_chains, mesh, rng_key, - alpha=1.9, - bias_type=0, - save_frac=0.2, - C=0.1, - power=3.0 / 8.0, - early_stop=True, + alpha=1.9, # L = \sqrt{d}*\alpha*vars + bias_type=0, # eliminate (fix to diagonal rank) + save_frac=0.2, # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. min is: save_frac*num_steps1 + C=0.1, # constant in stage 1 that determines step size (eq (9) in paper) + power=3.0 / 8.0, # eliminate + early_stop=True, # for stage 1 r_end=5e-3, # stage1 parameters diagonal_preconditioning=True, - integrator_coefficients=None, + integrator_coefficients=None, # (for stage 2) steps_per_sample=10, acc_prob=None, observables=lambda x: None, ensemble_observables=None, + diagnostics=True ): observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) @@ -309,4 +303,9 @@ def emaus( ensemble_observables, ) - return info1, info2, gradient_calls_per_step, _acc_prob + if diagnostics: + info = {"phase_1" : info1, "phase_2" : info2} + else: + info = None + + return info, gradient_calls_per_step, _acc_prob, final_state diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 7e2390def..d430b767e 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -128,6 +128,7 @@ def contract_history(theta, weights): return jnp.array([jnp.max(r), jnp.average(r)]) +# used for the early stopping class History(NamedTuple): observables: Array stopping: Array @@ -224,6 +225,7 @@ def __init__( history=history, ) + # info 1 def summary_statistics_fn(self, state, info, rng_key): position_flat, unravel_fn = ravel_pytree(state.position) diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index f0720acf4..70e33d3a4 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -43,8 +43,7 @@ class KineticEnergy(Protocol): def __call__( self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> Numeric: - ... + ) -> Numeric: ... class CheckTurning(Protocol): @@ -55,8 +54,7 @@ def __call__( momentum_sum: ArrayLikeTree, position_left: Optional[ArrayLikeTree] = None, position_right: Optional[ArrayLikeTree] = None, - ) -> bool: - ... + ) -> bool: ... class Scale(Protocol): @@ -67,8 +65,7 @@ def __call__( *, inv: bool, trans: bool, - ) -> ArrayLikeTree: - ... + ) -> ArrayLikeTree: ... class Metric(NamedTuple): diff --git a/blackjax/optimizers/lbfgs.py b/blackjax/optimizers/lbfgs.py index 0dd59f003..aef55200f 100644 --- a/blackjax/optimizers/lbfgs.py +++ b/blackjax/optimizers/lbfgs.py @@ -269,9 +269,7 @@ def compute_next_alpha(s_l, z_l, alpha_lm1): b = z_l.T @ s_l c = s_l.T @ jnp.diag(1.0 / alpha_lm1) @ s_l inv_alpha_l = ( - a / (b * alpha_lm1) - + z_l**2 / b - - (a * s_l**2) / (b * c * alpha_lm1**2) + a / (b * alpha_lm1) + z_l**2 / b - (a * s_l**2) / (b * c * alpha_lm1**2) ) return 1.0 / inv_alpha_l diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py index 506740c50..02766ca3d 100644 --- a/blackjax/sgmcmc/csgld.py +++ b/blackjax/sgmcmc/csgld.py @@ -41,6 +41,7 @@ class ContourSGLDState(NamedTuple): Index `i` such that the current position belongs to :math:`S_i`. """ + position: ArrayTree energy_pdf: Array energy_idx: int diff --git a/blackjax/sgmcmc/sgnht.py b/blackjax/sgmcmc/sgnht.py index ad9547406..7bcb2ccef 100644 --- a/blackjax/sgmcmc/sgnht.py +++ b/blackjax/sgmcmc/sgnht.py @@ -35,6 +35,7 @@ class SGNHTState(NamedTuple): Scalar thermostat controlling kinetic energy. """ + position: ArrayTree momentum: ArrayTree xi: float diff --git a/blackjax/smc/tuning/from_kernel_info.py b/blackjax/smc/tuning/from_kernel_info.py index a039e66c1..5725cc363 100644 --- a/blackjax/smc/tuning/from_kernel_info.py +++ b/blackjax/smc/tuning/from_kernel_info.py @@ -2,6 +2,7 @@ strategies to tune the parameters of mcmc kernels used within smc, based on MCMC states """ + import jax import jax.numpy as jnp diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index 4c8ca98da..2d0b737fa 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -2,6 +2,7 @@ strategies to tune the parameters of mcmc kernels used within SMC, based on particles. """ + import jax import jax.numpy as jnp from jax._src.flatten_util import ravel_pytree diff --git a/tests/adaptation/test_mass_matrix.py b/tests/adaptation/test_mass_matrix.py index 622b2111c..97d6ea882 100644 --- a/tests/adaptation/test_mass_matrix.py +++ b/tests/adaptation/test_mass_matrix.py @@ -1,4 +1,5 @@ """Test the welford adaptation algorithm.""" + import itertools import chex diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 5d3dece82..7d600e14c 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -1,4 +1,5 @@ """Test the accuracy of the MCMC kernels.""" + import functools import itertools @@ -285,31 +286,50 @@ def run_adjusted_mclmc( ) return out - + def run_emaus( - self, - initial_position, - logdensity_fn, - key, - num_steps, - diagonal_preconditioning, - ): + self, + initial_position, + logdensity_fn, + key, + num_steps, + diagonal_preconditioning, + ): - mesh = jax.sharding.Mesh(jax.devices(), 'chains') + mesh = jax.sharding.Mesh(jax.devices(), "chains") - from blackjax.mcmc.integrators import velocity_verlet_coefficients, mclachlan_coefficients, omelyan_coefficients + from blackjax.mcmc.integrators import ( + velocity_verlet_coefficients, + mclachlan_coefficients, + omelyan_coefficients, + ) + integrator_coefficients = mclachlan_coefficients - integrator_coefficients = mclachlan_coefficients + info1, info2, grads_per_step, _acc_prob = emaus( + logdensity_fn, + num_steps1=1000, + num_steps2=3000, + num_chains=4000, + mesh=mesh, + rng_key=key, + alpha=1.9, + bias_type=3, + C=0.1, + power=3.0 / 8.0, + early_stop=1, + r_end=1e-2, + diagonal_preconditioning=diagonal_preconditioning, + integrator_coefficients=integrator_coefficients, + steps_per_sample=15, + acc_prob=None, + ensemble_observables=lambda x: x, + # ensemble_observables = lambda x: vec @ x + ) # run the algorithm - info1, info2, grads_per_step, _acc_prob = emaus(logdensity_fn, num_steps1=1000, num_steps2=3000, num_chains=4000, mesh=mesh, rng_key=key, - alpha = 1.9, bias_type= 3, C= 0.1, power= 3./8., - early_stop=1, r_end= 1e-2, diagonal_preconditioning= diagonal_preconditioning, integrator_coefficients= integrator_coefficients, - steps_per_sample= 15, acc_prob= None, ensemble_observables= lambda x: x - #ensemble_observables = lambda x: vec @ x - ) # run the algorithm - - return info2[1].reshape(info2[1].shape[0]*info2[1].shape[1], info2[1].shape[2]) + return info2[1].reshape( + info2[1].shape[0] * info2[1].shape[1], info2[1].shape[2] + ) @parameterized.parameters( itertools.product( @@ -483,9 +503,11 @@ def test_adjusted_mclmc(self, diagonal_preconditioning): np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) - + # TODO: add preconditioning - def test_emaus(self,): + def test_emaus( + self, + ): """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) @@ -593,8 +615,7 @@ def get_inverse_mass_matrix(): assert ( jnp.abs( jnp.dot( - (inverse_mass_matrix**2) - / jnp.linalg.norm(inverse_mass_matrix**2), + (inverse_mass_matrix**2) / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 @@ -1284,41 +1305,50 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): ) - -#TODO: remove -class Banana(): +# TODO: remove +class Banana: """Banana target fromm the Inference Gym""" - def __init__(self, initialization= 'wide'): - self.name = 'Banana' + def __init__(self, initialization="wide"): + self.name = "Banana" self.ndims = 2 self.curvature = 0.03 - + self.transform = lambda x: x - self.E_x2 = jnp.array([100.0, 19.0]) #the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. + self.E_x2 = jnp.array( + [100.0, 19.0] + ) # the first is analytic the second is by drawing 10^8 samples from the generative model. Relative accuracy is around 10^-5. self.Var_x2 = jnp.array([20000.0, 4600.898]) - if initialization == 'map': + if initialization == "map": self.sample_init = lambda key: jnp.array([0, -100.0 * self.curvature]) - elif initialization == 'posterior': + elif initialization == "posterior": self.sample_init = lambda key: self.posterior_draw(key) - elif initialization == 'wide': - self.sample_init = lambda key: jax.random.normal(key, shape=(self.ndims,)) * jnp.array([10.0, 5.0]) * 2 + elif initialization == "wide": + self.sample_init = ( + lambda key: jax.random.normal(key, shape=(self.ndims,)) + * jnp.array([10.0, 5.0]) + * 2 + ) else: - raise ValueError('initialization = '+initialization +' is not a valid option.') + raise ValueError( + "initialization = " + initialization + " is not a valid option." + ) def logdensity_fn(self, x): mu2 = self.curvature * (x[0] ** 2 - 100) return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) def posterior_draw(self, key): - z = jax.random.normal(key, shape = (2, )) + z = jax.random.normal(key, shape=(2,)) x0 = 10.0 * z[0] - x1 = self.curvature * (x0 ** 2 - 100) + z[1] + x1 = self.curvature * (x0**2 - 100) + z[1] return jnp.array([x0, x1]) def ground_truth(self): - x = jax.vmap(self.posterior_draw)(jax.random.split(jax.random.PRNGKey(0), 100000000)) + x = jax.vmap(self.posterior_draw)( + jax.random.split(jax.random.PRNGKey(0), 100000000) + ) print(jnp.average(x, axis=0)) print(jnp.average(jnp.square(x), axis=0)) print(jnp.std(jnp.square(x[:, 0])) ** 2, jnp.std(jnp.square(x[:, 1])) ** 2) @@ -1326,4 +1356,3 @@ def ground_truth(self): if __name__ == "__main__": absltest.main() - diff --git a/tests/mcmc/test_trajectory.py b/tests/mcmc/test_trajectory.py index e93280400..bc8490b19 100644 --- a/tests/mcmc/test_trajectory.py +++ b/tests/mcmc/test_trajectory.py @@ -1,4 +1,5 @@ """Test the trajectory integration""" + import chex import jax import jax.numpy as jnp diff --git a/tests/mcmc/test_uturn.py b/tests/mcmc/test_uturn.py index 7f9f597d6..ff1f261f6 100644 --- a/tests/mcmc/test_uturn.py +++ b/tests/mcmc/test_uturn.py @@ -1,4 +1,5 @@ """Test the iterative u-turn criterion.""" + import chex import jax.numpy as jnp from absl.testing import absltest, parameterized diff --git a/tests/optimizers/test_optimizers.py b/tests/optimizers/test_optimizers.py index a7549842f..47f437af2 100644 --- a/tests/optimizers/test_optimizers.py +++ b/tests/optimizers/test_optimizers.py @@ -1,4 +1,5 @@ """Test optimizers.""" + import functools import chex diff --git a/tests/optimizers/test_pathfinder.py b/tests/optimizers/test_pathfinder.py index b9b9c69be..f40e79410 100644 --- a/tests/optimizers/test_pathfinder.py +++ b/tests/optimizers/test_pathfinder.py @@ -1,4 +1,5 @@ """Test the pathfinder algorithm.""" + import functools import chex diff --git a/tests/smc/test_resampling.py b/tests/smc/test_resampling.py index 20cb0d813..e6570f8f6 100644 --- a/tests/smc/test_resampling.py +++ b/tests/smc/test_resampling.py @@ -1,4 +1,5 @@ """Test the resampling functions for SMC.""" + import itertools import chex diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index b0e86e0b0..5c5e73259 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -1,4 +1,5 @@ """Test the generic SMC sampler""" + import functools import chex diff --git a/tests/smc/test_smc_ess.py b/tests/smc/test_smc_ess.py index 570d392d9..1f02b8d61 100644 --- a/tests/smc/test_smc_ess.py +++ b/tests/smc/test_smc_ess.py @@ -1,4 +1,5 @@ """Test the ess function""" + import functools import chex diff --git a/tests/smc/test_solver.py b/tests/smc/test_solver.py index 49db84129..8bcdd6a07 100644 --- a/tests/smc/test_solver.py +++ b/tests/smc/test_solver.py @@ -1,4 +1,5 @@ """Test the solving functions""" + import itertools import chex diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index 527457d62..ef8b8cb08 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -1,4 +1,5 @@ """Test the tempered SMC steps and routine""" + import functools import chex @@ -79,9 +80,11 @@ def logprior_fn(x): base_params, jax.tree.map(lambda x: jnp.repeat(x, num_particles, axis=0), base_params), jax.tree_util.tree_map_with_path( - lambda path, x: jnp.repeat(x, num_particles, axis=0) - if path[0].key == "step_size" - else x, + lambda path, x: ( + jnp.repeat(x, num_particles, axis=0) + if path[0].key == "step_size" + else x + ), base_params, ), ] diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 2d108a48d..ea9c6aa66 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -5,6 +5,7 @@ obviously more models. It should also be run in CI. """ + import functools import jax diff --git a/tests/test_compilation.py b/tests/test_compilation.py index 7179b71ba..2d3b03a44 100644 --- a/tests/test_compilation.py +++ b/tests/test_compilation.py @@ -5,6 +5,7 @@ internal changes do not trigger more compilations than is necessary. """ + import chex import jax import jax.numpy as jnp diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py index b583c8645..1d7c74846 100644 --- a/tests/test_diagnostics.py +++ b/tests/test_diagnostics.py @@ -1,4 +1,5 @@ """Test MCMC diagnostics.""" + import functools import itertools From c96a8e82c1a2d170f47c01ab30d669ab65fe35d2 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 4 Feb 2025 13:16:03 -0500 Subject: [PATCH 10/63] fix while loop --- .../adaptation/adjusted_mclmc_adaptation.py | 2 +- blackjax/adaptation/ensemble_mclmc.py | 50 ++++++++++--------- blackjax/util.py | 29 ++++++++++- 3 files changed, 55 insertions(+), 26 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index eabb642a3..8c9fafc60 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -100,7 +100,7 @@ def adjusted_mclmc_find_L_and_step_size( state, params = adjusted_mclmc_make_adaptation_L( mclmc_kernel, frac=frac_tune3, - Lfactor=0.5, + Lfactor=0.3, max=max, eigenvector=eigenvector, )(state, params, num_steps, part2_key1) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 786d55195..95f2a9456 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -111,9 +111,7 @@ def summary_statistics_fn(self, state, info, rng_key): } def update(self, adaptation_state, Etheta): - # combine the expectation values to get useful scalars acc_prob = Etheta["acceptance_probability"] - # acc_prob = 1./Etheta['inv_acceptance_probability'] equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) true_bias = self.contract(Etheta["observables_for_bias"]) # remove @@ -173,10 +171,8 @@ def emaus( mesh, rng_key, alpha=1.9, # L = \sqrt{d}*\alpha*vars - bias_type=0, # eliminate (fix to diagonal rank) save_frac=0.2, # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. min is: save_frac*num_steps1 C=0.1, # constant in stage 1 that determines step size (eq (9) in paper) - power=3.0 / 8.0, # eliminate early_stop=True, # for stage 1 r_end=5e-3, # stage1 parameters diagonal_preconditioning=True, @@ -187,6 +183,28 @@ def emaus( ensemble_observables=None, diagnostics=True ): + + """ + model: the target density object + num_steps1: number of steps in the first phase + num_steps2: number of steps in the second phase + num_chains: number of chains + mesh: the mesh object, used for distributing the computation across cpus and nodes + rng_key: the random key + alpha: L = \sqrt{d}*\alpha*variances + save_frac: the fraction of samples used to estimate the fluctuation in the first phase + C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) + early_stop: whether to stop the first phase early + r_end + diagonal_preconditioning: whether to use diagonal preconditioning + integrator_coefficients: the coefficients of the integrator + steps_per_sample: the number of steps per sample + acc_prob: the acceptance probability + observables: the observables (for diagnostic use) + ensemble_observables: observable calculated over the ensemble (for diagnostic use) + diagnostics: whether to return diagnostics + """ + observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) @@ -201,15 +219,16 @@ def emaus( adap = umclmc.Adaptation( model.ndims, alpha=alpha, - bias_type=bias_type, + bias_type=3, save_num=save_num, C=C, - power=power, + power=3.0 / 8.0, r_end=r_end, observables=observables, observables_for_bias=observables_for_bias, contract=contract, ) + final_state, final_adaptation_state, info1 = run_eca( key_umclmc, initial_state, @@ -219,26 +238,9 @@ def emaus( num_chains, mesh, ensemble_observables, + early_stop=early_stop, ) - if ( - early_stop - ): # here I am cheating a bit, because I am not sure if it is possible to do a while loop in jax and save something at every step. Therefore I rerun burn-in with exactly the same parameters and stop at the point where the orignal while loop would have stopped. The release implementation should not have that. - num_steps_while = while_steps_num( - (info1[0] if ensemble_observables is not None else info1)["while_cond"] - ) - # print(num_steps_while, save_num) - final_state, final_adaptation_state, info1 = run_eca( - key_umclmc, - initial_state, - kernel, - adap, - num_steps_while, - num_chains, - mesh, - ensemble_observables, - ) - # refine the results with the adjusted method # _acc_prob = acc_prob if integrator_coefficients is None: diff --git a/blackjax/util.py b/blackjax/util.py index e8c42f11f..53543f662 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -11,6 +11,8 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map + +import jax from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -375,6 +377,7 @@ def run_eca( num_chains, mesh, ensemble_info=None, + early_stop=False, ): step = eca_step( kernel, @@ -396,7 +399,31 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): keys_adaptation, ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - final_state_all, info_history = lax.scan(step, initial_state_all, xs) + # ((a, Int) -> (a, Int)) + def step_while(a): + x, i, _ = a + + auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + + # output, info = step(x, (jnp.arange(num_steps)[0],keys_sampling.T[0],keys_adaptation[0])) + output, info = step(x,auxilliary_input) + + + # jax.debug.print("info {x}", x=info[0].get("while_cond")) + # jax.debug.print("info {x}", x=i) + + return (output, i + 1, info[0].get("while_cond")) + + # jax.debug.print("initial {x}", x=0) + if early_stop: + final_state_all, i, _ = lax.while_loop( + lambda a: ((a[1] < num_steps) & a[2] ), step_while, (initial_state_all, 0, True) + ) + info_history = None + + else: + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + final_state, final_adaptation_state = final_state_all return ( final_state, From 805113a2716a41e6e318a03580be5a04a33431fa Mon Sep 17 00:00:00 2001 From: = Date: Thu, 6 Feb 2025 17:55:41 -0500 Subject: [PATCH 11/63] test passes --- blackjax/adaptation/ensemble_mclmc.py | 46 +++++++++++------ blackjax/adaptation/ensemble_umclmc.py | 36 +++++++++++-- blackjax/mcmc/integrators.py | 12 +++++ blackjax/mcmc/mclmc.py | 5 ++ blackjax/util.py | 14 +++++ tests/mcmc/test_sampling.py | 71 +++++++++++++++++--------- 6 files changed, 139 insertions(+), 45 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 95f2a9456..5b95aacbd 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -164,13 +164,16 @@ def while_steps_num(cond): def emaus( - model, + logdensity_fn, + sample_init, + transform, + ndims, num_steps1, # max number in phase 1 num_steps2, # fixed number in phase 2 num_chains, mesh, rng_key, - alpha=1.9, # L = \sqrt{d}*\alpha*vars + alpha=1.9, # L = sqrt{d}*alpha*vars save_frac=0.2, # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. min is: save_frac*num_steps1 C=0.1, # constant in stage 1 that determines step size (eq (9) in paper) early_stop=True, # for stage 1 @@ -183,7 +186,6 @@ def emaus( ensemble_observables=None, diagnostics=True ): - """ model: the target density object num_steps1: number of steps in the first phase @@ -191,7 +193,7 @@ def emaus( num_chains: number of chains mesh: the mesh object, used for distributing the computation across cpus and nodes rng_key: the random key - alpha: L = \sqrt{d}*\alpha*variances + alpha: L = sqrt{d}*alpha*variances save_frac: the fraction of samples used to estimate the fluctuation in the first phase C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) early_stop: whether to stop the first phase early @@ -205,29 +207,35 @@ def emaus( diagnostics: whether to return diagnostics """ - observables_for_bias, contract = bias(model) + # observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) # initialize the chains initial_state = umclmc.initialize( - key_init, model.logdensity_fn, model.sample_init, num_chains, mesh + key_init, logdensity_fn, sample_init, num_chains, mesh ) + + # jax.debug.print("{x} foo", x=jax.flatten_util.ravel_pytree(initial_state.position)[0].shape[-1]) + ndims = 2 + # burn-in with the unadjusted method # - kernel = umclmc.build_kernel(model.logdensity_fn) + kernel = umclmc.build_kernel(logdensity_fn) save_num = (int)(jnp.rint(save_frac * num_steps1)) adap = umclmc.Adaptation( - model.ndims, + ndims, alpha=alpha, bias_type=3, save_num=save_num, C=C, power=3.0 / 8.0, r_end=r_end, - observables=observables, - observables_for_bias=observables_for_bias, - contract=contract, + # observables=observables, + observables_for_bias=lambda position: jnp.square(transform(jax.flatten_util.ravel_pytree(position)[0])), + # contract=contract, ) + + # jax.debug.print("initial_state.momentum: {x}", x=initial_state.momentum) final_state, final_adaptation_state, info1 = run_eca( key_umclmc, @@ -241,10 +249,13 @@ def emaus( early_stop=early_stop, ) + # print(final_state.position['coefs'].shape, "\n\nfoo\n\n") + # jax.debug.print("final_state.position: {x}", x=jnp.mean(final_state.position['coefs'])) + # refine the results with the adjusted method # _acc_prob = acc_prob if integrator_coefficients is None: - high_dims = model.ndims > 200 + high_dims = ndims > 200 _integrator_coefficients = ( omelyan_coefficients if high_dims else mclachlan_coefficients ) @@ -274,8 +285,11 @@ def emaus( inverse_mass_matrix = 1.0 kernel = build_kernel( - model.logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix + logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix ) + + + initial_state = HMCState( final_state.position, final_state.logdensity, final_state.logdensity_grad ) @@ -289,9 +303,9 @@ def emaus( num_adaptation_samples, steps_per_sample, _acc_prob, - observables=observables, - observables_for_bias=observables_for_bias, - contract=contract, + # observables=observables, + # observables_for_bias=observables_for_bias, + # contract=contract, ) final_state, final_adaptation_state, info2 = run_eca( diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index d430b767e..4e124421b 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -77,17 +77,33 @@ def sequential_init(key, x, args): velocity = unravel_fn( _normalized_flatten_array(flat_g)[0] ) # = grad logp/ |grad logp| + + jax.debug.print("logdensity {x}", x=logdensity_fn(position)) + # jax.debug.print("velocity {x}", x=velocity) + jax.debug.print("position {x}", x=position) + # jax.debug.print("logdensity_grad {x}", x=logdensity_grad) + # jax.debug.print("logdensity {x}", x=logdensity) + # jax.debug.print("flat_g {x}", x=flat_g) return IntegratorState(position, velocity, logdensity, logdensity_grad), None def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" - return -state.position * state.logdensity_grad + return 0 # -state.position * state.logdensity_grad + # TODO: restore! def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" - velocity = jax.tree_util.tree_map( - lambda sign, u: sign * u, signs, state.momentum + # velocity = jax.tree_util.tree_map( + # lambda sign, u: sign * u, signs, state.momentum + # ) + momentum, unflatten = jax.flatten_util.ravel_pytree(state.momentum) + + velocity_flat = jax.tree_util.tree_map( + lambda sign, u: sign*u, signs, momentum ) + + velocity = unflatten(velocity_flat) + return ( IntegratorState( state.position, velocity, state.logdensity, state.logdensity_grad @@ -103,6 +119,9 @@ def ensemble_init(key, state, signs): mesh, summary_statistics_fn=summary_statistics_fn, ) + + # jax.debug.print("initial_state {x}", x=initial_state.momentum) + signs = -2.0 * (equipartition < 1.0) + 1.0 initial_state, _ = ensemble_execute_fn( ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs @@ -112,7 +131,14 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): + # new_vals = jax.flatten_util.ravel_pytree(new_vals)[0] + # history = jax.flatten_util.ravel_pytree(history)[0] + # print(new_vals, "FOOO\n\n") + + new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) + # print(history, "FOOO\n\n") return jnp.concatenate((new_vals[None, :], history[:-1])) + # return history # TODO CHANGE BACK!!!! def update_history_scalar(new_val, history): @@ -192,7 +218,7 @@ def __init__( bias_type=0, save_num=10, observables=lambda x: 0.0, - observables_for_bias=lambda x: 0.0, + observables_for_bias=lambda x: x, contract=lambda x: 0.0, ): self.num_dims = num_dims @@ -250,6 +276,8 @@ def update(self, adaptation_state, Etheta): history_observables = update_history( Etheta["observables_for_bias"], adaptation_state.history.observables ) + # history_observables = adaptation_state.history.observables + history_weights = update_history_scalar(1.0, adaptation_state.history.weights) fluctuations = contract_history(history_observables, history_weights) history_stopping = update_history_scalar( diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 733e7e960..1a5711add 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -169,6 +169,7 @@ def update( position, kinetic_grad, ) + # jax.debug.print("new_position {x}", x=new_position) logdensity, logdensity_grad = logdensity_and_grad_fn(new_position) return new_position, logdensity, logdensity_grad, None @@ -330,6 +331,7 @@ def update( """ del is_last_call + # jax.debug.print("old momentum {x}", x=momentum) logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) flatten_grads = flatten_grads * sqrt_inverse_mass_matrix @@ -338,6 +340,7 @@ def update( normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) delta = step_size * coef * gradient_norm / (dims - 1) + # jax.debug.print("delta {x}", x=delta) zeta = jnp.exp(-delta) new_momentum_raw = ( normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) @@ -353,6 +356,8 @@ def update( ) * (dims - 1) if previous_kinetic_energy_change is not None: kinetic_energy_change += previous_kinetic_energy_change + + # jax.debug.print("new_momentum {x}", x=next_momentum) return next_momentum, gr, kinetic_energy_change return update @@ -417,11 +422,15 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): momentum with random change in angle """ + # jax.debug.print("momentum unflat {x}", x=momentum) m, unravel_fn = ravel_pytree(momentum) + # jax.debug.print("momentum {x}", x=m) dim = m.shape[0] nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) + # jax.debug.print("z {x}", x=z) new_momentum = unravel_fn((m + z) / jnp.linalg.norm(m + z)) + # jax.debug.print("new_momentum {x}", x=new_momentum) # return new_momentum return jax.lax.cond( jnp.isinf(L), @@ -435,6 +444,7 @@ def with_isokinetic_maruyama(integrator): def stochastic_integrator(init_state, step_size, L_proposal, rng_key): key1, key2 = jax.random.split(rng_key) # partial refreshment + # jax.debug.print("state before noise {x}", x=init_state.momentum) state = init_state._replace( momentum=partially_refresh_momentum( momentum=init_state.momentum, @@ -443,8 +453,10 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): step_size=step_size * 0.5, ) ) + # jax.debug.print("state after noise {x}", x=state.momentum) # one step of the deterministic dynamics state, info = integrator(state, step_size) + # jax.debug.print("state after integ {x}", x=state.position) # partial refreshment state = state._replace( diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 2299dc68e..824fd9215 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -89,10 +89,15 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, MCLMCInfo]: + # jax.debug.print("state momentum 1 {x}", x=state.momentum) (position, momentum, logdensity, logdensitygrad), kinetic_change = step( state, step_size, L, rng_key ) + # jax.debug.print("state position 2 {x}", x=position) + # jax.debug.print("state position {x}", x=state.position.mean(axis=0)) + # jax.debug.print("state position 2 {x}", x=position.mean(axis=0)) + return IntegratorState( position, momentum, logdensity, logdensitygrad ), MCLMCInfo( diff --git a/blackjax/util.py b/blackjax/util.py index 53543f662..9b0a96046 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -337,6 +337,7 @@ def _step(state_all, xs): # update the state of all chains on this device state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + # combine all the chains to compute expectation values theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) @@ -404,16 +405,29 @@ def step_while(a): x, i, _ = a auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + # jax.debug.print("momentum init {x}", x=x[0].momentum) # output, info = step(x, (jnp.arange(num_steps)[0],keys_sampling.T[0],keys_adaptation[0])) + # print(x, "\n\n") output, info = step(x,auxilliary_input) + # print(output, "\n\n\nFOOO\n\n\n") + + # jax.debug.print("\nbar\n {x}", x=output[0].position['coefs'].mean(axis=0)) + # jax.debug.print("\nbar\n {x}", x=output[0].position.mean(axis=0)) + + check_state, _ = vmap(kernel, (0, 0, None))(xs[1][i], output[0], output[1]) + # jax.debug.print("\nbaz\n {x}", x=check_state.position['coefs'].mean(axis=0)) + # jax.debug.print("\nbaz\n {x}", x=check_state.position.mean(axis=0)) # jax.debug.print("info {x}", x=info[0].get("while_cond")) # jax.debug.print("info {x}", x=i) return (output, i + 1, info[0].get("while_cond")) + # flatten with ravel: use ravel, not tree_map + # initial_state_all = ravel_pytree(initial_state_all)[0] + # jax.debug.print("initial {x}", x=0) if early_stop: final_state_all, i, _ = lax.while_loop( diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 7d600e14c..5e2d47e89 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -289,10 +289,11 @@ def run_adjusted_mclmc( def run_emaus( self, - initial_position, + sample_init, logdensity_fn, + ndims, + transform, key, - num_steps, diagonal_preconditioning, ): @@ -306,17 +307,18 @@ def run_emaus( integrator_coefficients = mclachlan_coefficients - info1, info2, grads_per_step, _acc_prob = emaus( - logdensity_fn, - num_steps1=1000, - num_steps2=3000, - num_chains=4000, + info, grads_per_step, _acc_prob, final_state = emaus( + logdensity_fn=logdensity_fn, + sample_init=sample_init, + transform=transform, + ndims=ndims, + num_steps1=100, + num_steps2=300, + num_chains=100, mesh=mesh, rng_key=key, alpha=1.9, - bias_type=3, C=0.1, - power=3.0 / 8.0, early_stop=1, r_end=1e-2, diagonal_preconditioning=diagonal_preconditioning, @@ -327,9 +329,7 @@ def run_emaus( # ensemble_observables = lambda x: vec @ x ) # run the algorithm - return info2[1].reshape( - info2[1].shape[0] * info2[1].shape[1], info2[1].shape[2] - ) + return final_state.position @parameterized.parameters( itertools.product( @@ -511,35 +511,56 @@ def test_emaus( """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + + # model = Banana() + # logdensity_fn = model.logdensity_fn + # sample_init = model.sample_init + + x_data = jax.random.normal(init_key0, shape=(1000, 1)) y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) logposterior_fn_ = functools.partial( self.regression_logprob, x=x_data, preds=y_data ) - logdensity_fn = lambda x: logposterior_fn_(**x) + # logdensity_fn = lambda x: logposterior_fn_(coefs=x[0], log_scale=x[1]) + logdensity_fn = lambda x: logposterior_fn_(coefs=x['coefs'][0], log_scale=x['log_scale'][0]) + # logdensity_fn = lambda x: logposterior_fn_(**x) - model = Banana() + # jax.debug.print("logposterior_fn_ {x}", x=logdensity_fn(jnp.array([[1.5606847], [1.719502]]))) + # jax.debug.print("logposterior_fn_ {x}", x=logdensity_fn({"coefs": jnp.array(1.5606847), "log_scale": jnp.array(1.719502)})) - states = self.run_emaus( - initial_position={"coefs": 1.0, "log_scale": 1.0}, - logdensity_fn=model, + + def sample_init(key): + key1, key2 = jax.random.split(key) + coefs = jax.random.uniform(key1, shape=(1,), minval=1, maxval=2) + log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) + return {"coefs": coefs, "log_scale": log_scale} + # return jnp.concatenate([coefs, log_scale]) + + + samples = self.run_emaus( + sample_init=sample_init, + logdensity_fn=logdensity_fn, + transform=lambda x: x, + ndims=2, key=inference_key, - num_steps=10000, diagonal_preconditioning=True, ) - # coefs_samples = states["coefs"][3000:] - # scale_samples = np.exp(states["log_scale"][3000:]) - # samples = states[3000:] - print((states**2).mean(axis=0), Banana().E_x2) + # # jax.debug.print("pos mean, {x}", x=jnp.mean(samples["coefs"][-1])) + + + coefs_samples = samples["coefs"] + scale_samples = np.exp(samples["log_scale"]) - np.testing.assert_allclose((states**2).mean(axis=0), Banana().E_x2, atol=1e-2) + jax.debug.print("coefs_samples mean {x}", x=jnp.mean(coefs_samples)) + jax.debug.print("scale_samples mean {x}", x=jnp.mean(scale_samples)) - # np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - # np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) def test_mclmc_preconditioning(self): class IllConditionedGaussian: From 16841c6bb96d6b3380b59f5693298ea4db80360a Mon Sep 17 00:00:00 2001 From: = Date: Thu, 6 Feb 2025 18:00:26 -0500 Subject: [PATCH 12/63] precommit --- blackjax/adaptation/ensemble_mclmc.py | 22 ++++++--------- blackjax/adaptation/ensemble_umclmc.py | 15 +++------- blackjax/mcmc/integrators.py | 13 +-------- blackjax/mcmc/mclmc.py | 5 ---- blackjax/mcmc/metrics.py | 9 ++++-- blackjax/optimizers/lbfgs.py | 4 ++- blackjax/util.py | 26 +++-------------- tests/mcmc/test_sampling.py | 39 ++++++-------------------- 8 files changed, 34 insertions(+), 99 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 5b95aacbd..7294f4936 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -35,7 +35,6 @@ class AdaptationState(NamedTuple): - steps_per_sample: float step_size: float stepsize_adaptation_state: ( @@ -83,7 +82,9 @@ def __init__( # With the current eps, we had sigma^2 = EEVPD * d for N = 1. # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = adaptation_state.step_size # * integrator_factor * adjustment_factor + step_size = ( + adaptation_state.step_size + ) # * integrator_factor * adjustment_factor # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) @@ -184,7 +185,7 @@ def emaus( acc_prob=None, observables=lambda x: None, ensemble_observables=None, - diagnostics=True + diagnostics=True, ): """ model: the target density object @@ -215,8 +216,6 @@ def emaus( key_init, logdensity_fn, sample_init, num_chains, mesh ) - - # jax.debug.print("{x} foo", x=jax.flatten_util.ravel_pytree(initial_state.position)[0].shape[-1]) ndims = 2 # burn-in with the unadjusted method # @@ -231,12 +230,12 @@ def emaus( power=3.0 / 8.0, r_end=r_end, # observables=observables, - observables_for_bias=lambda position: jnp.square(transform(jax.flatten_util.ravel_pytree(position)[0])), + observables_for_bias=lambda position: jnp.square( + transform(jax.flatten_util.ravel_pytree(position)[0]) + ), # contract=contract, ) - # jax.debug.print("initial_state.momentum: {x}", x=initial_state.momentum) - final_state, final_adaptation_state, info1 = run_eca( key_umclmc, initial_state, @@ -249,9 +248,6 @@ def emaus( early_stop=early_stop, ) - # print(final_state.position['coefs'].shape, "\n\nfoo\n\n") - # jax.debug.print("final_state.position: {x}", x=jnp.mean(final_state.position['coefs'])) - # refine the results with the adjusted method # _acc_prob = acc_prob if integrator_coefficients is None: @@ -288,8 +284,6 @@ def emaus( logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix ) - - initial_state = HMCState( final_state.position, final_state.logdensity, final_state.logdensity_grad ) @@ -320,7 +314,7 @@ def emaus( ) if diagnostics: - info = {"phase_1" : info1, "phase_2" : info2} + info = {"phase_1": info1, "phase_2": info2} else: info = None diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 4e124421b..f919a82e8 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -78,17 +78,12 @@ def sequential_init(key, x, args): _normalized_flatten_array(flat_g)[0] ) # = grad logp/ |grad logp| - jax.debug.print("logdensity {x}", x=logdensity_fn(position)) - # jax.debug.print("velocity {x}", x=velocity) - jax.debug.print("position {x}", x=position) - # jax.debug.print("logdensity_grad {x}", x=logdensity_grad) - # jax.debug.print("logdensity {x}", x=logdensity) - # jax.debug.print("flat_g {x}", x=flat_g) return IntegratorState(position, velocity, logdensity, logdensity_grad), None def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" - return 0 # -state.position * state.logdensity_grad + return 0 # -state.position * state.logdensity_grad + # TODO: restore! def ensemble_init(key, state, signs): @@ -99,7 +94,7 @@ def ensemble_init(key, state, signs): momentum, unflatten = jax.flatten_util.ravel_pytree(state.momentum) velocity_flat = jax.tree_util.tree_map( - lambda sign, u: sign*u, signs, momentum + lambda sign, u: sign * u, signs, momentum ) velocity = unflatten(velocity_flat) @@ -120,8 +115,6 @@ def ensemble_init(key, state, signs): summary_statistics_fn=summary_statistics_fn, ) - # jax.debug.print("initial_state {x}", x=initial_state.momentum) - signs = -2.0 * (equipartition < 1.0) + 1.0 initial_state, _ = ensemble_execute_fn( ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs @@ -277,7 +270,7 @@ def update(self, adaptation_state, Etheta): Etheta["observables_for_bias"], adaptation_state.history.observables ) # history_observables = adaptation_state.history.observables - + history_weights = update_history_scalar(1.0, adaptation_state.history.weights) fluctuations = contract_history(history_observables, history_weights) history_stopping = update_history_scalar( diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 1a5711add..49700697d 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -169,7 +169,6 @@ def update( position, kinetic_grad, ) - # jax.debug.print("new_position {x}", x=new_position) logdensity, logdensity_grad = logdensity_and_grad_fn(new_position) return new_position, logdensity, logdensity_grad, None @@ -331,7 +330,6 @@ def update( """ del is_last_call - # jax.debug.print("old momentum {x}", x=momentum) logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) flatten_grads = flatten_grads * sqrt_inverse_mass_matrix @@ -340,7 +338,6 @@ def update( normalized_gradient, gradient_norm = _normalized_flatten_array(flatten_grads) momentum_proj = jnp.dot(flatten_momentum, normalized_gradient) delta = step_size * coef * gradient_norm / (dims - 1) - # jax.debug.print("delta {x}", x=delta) zeta = jnp.exp(-delta) new_momentum_raw = ( normalized_gradient * (1 - zeta) * (1 + zeta + momentum_proj * (1 - zeta)) @@ -357,7 +354,6 @@ def update( if previous_kinetic_energy_change is not None: kinetic_energy_change += previous_kinetic_energy_change - # jax.debug.print("new_momentum {x}", x=next_momentum) return next_momentum, gr, kinetic_energy_change return update @@ -422,16 +418,12 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): momentum with random change in angle """ - # jax.debug.print("momentum unflat {x}", x=momentum) m, unravel_fn = ravel_pytree(momentum) - # jax.debug.print("momentum {x}", x=m) dim = m.shape[0] nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) - # jax.debug.print("z {x}", x=z) new_momentum = unravel_fn((m + z) / jnp.linalg.norm(m + z)) - # jax.debug.print("new_momentum {x}", x=new_momentum) - # return new_momentum + return jax.lax.cond( jnp.isinf(L), lambda _: momentum, @@ -444,7 +436,6 @@ def with_isokinetic_maruyama(integrator): def stochastic_integrator(init_state, step_size, L_proposal, rng_key): key1, key2 = jax.random.split(rng_key) # partial refreshment - # jax.debug.print("state before noise {x}", x=init_state.momentum) state = init_state._replace( momentum=partially_refresh_momentum( momentum=init_state.momentum, @@ -453,10 +444,8 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): step_size=step_size * 0.5, ) ) - # jax.debug.print("state after noise {x}", x=state.momentum) # one step of the deterministic dynamics state, info = integrator(state, step_size) - # jax.debug.print("state after integ {x}", x=state.position) # partial refreshment state = state._replace( diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 824fd9215..2299dc68e 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -89,15 +89,10 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, MCLMCInfo]: - # jax.debug.print("state momentum 1 {x}", x=state.momentum) (position, momentum, logdensity, logdensitygrad), kinetic_change = step( state, step_size, L, rng_key ) - # jax.debug.print("state position 2 {x}", x=position) - # jax.debug.print("state position {x}", x=state.position.mean(axis=0)) - # jax.debug.print("state position 2 {x}", x=position.mean(axis=0)) - return IntegratorState( position, momentum, logdensity, logdensitygrad ), MCLMCInfo( diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 70e33d3a4..f0720acf4 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -43,7 +43,8 @@ class KineticEnergy(Protocol): def __call__( self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None - ) -> Numeric: ... + ) -> Numeric: + ... class CheckTurning(Protocol): @@ -54,7 +55,8 @@ def __call__( momentum_sum: ArrayLikeTree, position_left: Optional[ArrayLikeTree] = None, position_right: Optional[ArrayLikeTree] = None, - ) -> bool: ... + ) -> bool: + ... class Scale(Protocol): @@ -65,7 +67,8 @@ def __call__( *, inv: bool, trans: bool, - ) -> ArrayLikeTree: ... + ) -> ArrayLikeTree: + ... class Metric(NamedTuple): diff --git a/blackjax/optimizers/lbfgs.py b/blackjax/optimizers/lbfgs.py index aef55200f..0dd59f003 100644 --- a/blackjax/optimizers/lbfgs.py +++ b/blackjax/optimizers/lbfgs.py @@ -269,7 +269,9 @@ def compute_next_alpha(s_l, z_l, alpha_lm1): b = z_l.T @ s_l c = s_l.T @ jnp.diag(1.0 / alpha_lm1) @ s_l inv_alpha_l = ( - a / (b * alpha_lm1) + z_l**2 / b - (a * s_l**2) / (b * c * alpha_lm1**2) + a / (b * alpha_lm1) + + z_l**2 / b + - (a * s_l**2) / (b * c * alpha_lm1**2) ) return 1.0 / inv_alpha_l diff --git a/blackjax/util.py b/blackjax/util.py index 9b0a96046..5965ac93e 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -11,8 +11,6 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map - -import jax from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -337,7 +335,6 @@ def _step(state_all, xs): # update the state of all chains on this device state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - # combine all the chains to compute expectation values theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) @@ -405,33 +402,18 @@ def step_while(a): x, i, _ = a auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) - # jax.debug.print("momentum init {x}", x=x[0].momentum) - - # output, info = step(x, (jnp.arange(num_steps)[0],keys_sampling.T[0],keys_adaptation[0])) - # print(x, "\n\n") - output, info = step(x,auxilliary_input) - # print(output, "\n\n\nFOOO\n\n\n") - - # jax.debug.print("\nbar\n {x}", x=output[0].position['coefs'].mean(axis=0)) - # jax.debug.print("\nbar\n {x}", x=output[0].position.mean(axis=0)) + output, info = step(x, auxilliary_input) check_state, _ = vmap(kernel, (0, 0, None))(xs[1][i], output[0], output[1]) - - # jax.debug.print("\nbaz\n {x}", x=check_state.position['coefs'].mean(axis=0)) - # jax.debug.print("\nbaz\n {x}", x=check_state.position.mean(axis=0)) - # jax.debug.print("info {x}", x=info[0].get("while_cond")) - # jax.debug.print("info {x}", x=i) return (output, i + 1, info[0].get("while_cond")) - # flatten with ravel: use ravel, not tree_map - # initial_state_all = ravel_pytree(initial_state_all)[0] - - # jax.debug.print("initial {x}", x=0) if early_stop: final_state_all, i, _ = lax.while_loop( - lambda a: ((a[1] < num_steps) & a[2] ), step_while, (initial_state_all, 0, True) + lambda a: ((a[1] < num_steps) & a[2]), + step_while, + (initial_state_all, 0, True), ) info_history = None diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 5e2d47e89..493b441d0 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -3,7 +3,6 @@ import functools import itertools -from blackjax.adaptation.ensemble_mclmc import emaus import chex import jax import jax.numpy as jnp @@ -16,6 +15,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info +from blackjax.adaptation.ensemble_mclmc import emaus from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -296,14 +296,9 @@ def run_emaus( key, diagonal_preconditioning, ): - mesh = jax.sharding.Mesh(jax.devices(), "chains") - from blackjax.mcmc.integrators import ( - velocity_verlet_coefficients, - mclachlan_coefficients, - omelyan_coefficients, - ) + from blackjax.mcmc.integrators import mclachlan_coefficients integrator_coefficients = mclachlan_coefficients @@ -511,33 +506,22 @@ def test_emaus( """Test the MCLMC kernel.""" init_key0, init_key1, inference_key = jax.random.split(self.key, 3) - - # model = Banana() - # logdensity_fn = model.logdensity_fn - # sample_init = model.sample_init - x_data = jax.random.normal(init_key0, shape=(1000, 1)) y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) logposterior_fn_ = functools.partial( self.regression_logprob, x=x_data, preds=y_data ) - # logdensity_fn = lambda x: logposterior_fn_(coefs=x[0], log_scale=x[1]) - logdensity_fn = lambda x: logposterior_fn_(coefs=x['coefs'][0], log_scale=x['log_scale'][0]) - # logdensity_fn = lambda x: logposterior_fn_(**x) - - # jax.debug.print("logposterior_fn_ {x}", x=logdensity_fn(jnp.array([[1.5606847], [1.719502]]))) - # jax.debug.print("logposterior_fn_ {x}", x=logdensity_fn({"coefs": jnp.array(1.5606847), "log_scale": jnp.array(1.719502)})) - + logdensity_fn = lambda x: logposterior_fn_( + coefs=x["coefs"][0], log_scale=x["log_scale"][0] + ) def sample_init(key): key1, key2 = jax.random.split(key) coefs = jax.random.uniform(key1, shape=(1,), minval=1, maxval=2) - log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) + log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) return {"coefs": coefs, "log_scale": log_scale} - # return jnp.concatenate([coefs, log_scale]) - samples = self.run_emaus( sample_init=sample_init, @@ -548,17 +532,9 @@ def sample_init(key): diagonal_preconditioning=True, ) - - - # # jax.debug.print("pos mean, {x}", x=jnp.mean(samples["coefs"][-1])) - - coefs_samples = samples["coefs"] scale_samples = np.exp(samples["log_scale"]) - jax.debug.print("coefs_samples mean {x}", x=jnp.mean(coefs_samples)) - jax.debug.print("scale_samples mean {x}", x=jnp.mean(scale_samples)) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @@ -636,7 +612,8 @@ def get_inverse_mass_matrix(): assert ( jnp.abs( jnp.dot( - (inverse_mass_matrix**2) / jnp.linalg.norm(inverse_mass_matrix**2), + (inverse_mass_matrix**2) + / jnp.linalg.norm(inverse_mass_matrix**2), eigs / jnp.linalg.norm(eigs), ) - 1 From 52ce7ad935f63e6174ef4224c5abb86dc2fbaa40 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 6 Feb 2025 18:18:55 -0500 Subject: [PATCH 13/63] update --- tests/mcmc/test_sampling.py | 150 ++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 7718ad2b6..ae83927d8 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -15,10 +15,10 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info +from blackjax.adaptation.ensemble_mclmc import emaus from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm -from blackjax.adaptation.ensemble_mclmc import emaus def orbit_samples(orbits, weights, rng_key): @@ -291,43 +291,43 @@ def run_adjusted_mclmc_static( return out def run_emaus( - self, - sample_init, - logdensity_fn, - ndims, - transform, - key, - diagonal_preconditioning, - ): - mesh = jax.sharding.Mesh(jax.devices(), "chains") - - from blackjax.mcmc.integrators import mclachlan_coefficients - - integrator_coefficients = mclachlan_coefficients - - info, grads_per_step, _acc_prob, final_state = emaus( - logdensity_fn=logdensity_fn, - sample_init=sample_init, - transform=transform, - ndims=ndims, - num_steps1=100, - num_steps2=300, - num_chains=100, - mesh=mesh, - rng_key=key, - alpha=1.9, - C=0.1, - early_stop=1, - r_end=1e-2, - diagonal_preconditioning=diagonal_preconditioning, - integrator_coefficients=integrator_coefficients, - steps_per_sample=15, - acc_prob=None, - ensemble_observables=lambda x: x, - # ensemble_observables = lambda x: vec @ x - ) # run the algorithm - - return final_state.position + self, + sample_init, + logdensity_fn, + ndims, + transform, + key, + diagonal_preconditioning, + ): + mesh = jax.sharding.Mesh(jax.devices(), "chains") + + from blackjax.mcmc.integrators import mclachlan_coefficients + + integrator_coefficients = mclachlan_coefficients + + info, grads_per_step, _acc_prob, final_state = emaus( + logdensity_fn=logdensity_fn, + sample_init=sample_init, + transform=transform, + ndims=ndims, + num_steps1=100, + num_steps2=300, + num_chains=100, + mesh=mesh, + rng_key=key, + alpha=1.9, + C=0.1, + early_stop=1, + r_end=1e-2, + diagonal_preconditioning=diagonal_preconditioning, + integrator_coefficients=integrator_coefficients, + steps_per_sample=15, + acc_prob=None, + ensemble_observables=lambda x: x, + # ensemble_observables = lambda x: vec @ x + ) # run the algorithm + + return final_state.position @parameterized.parameters( itertools.product( @@ -576,42 +576,42 @@ def get_inverse_mass_matrix(): ) def test_emaus( - self, - ): - """Test the MCLMC kernel.""" - - init_key0, init_key1, inference_key = jax.random.split(self.key, 3) - - x_data = jax.random.normal(init_key0, shape=(1000, 1)) - y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) - - logposterior_fn_ = functools.partial( - self.regression_logprob, x=x_data, preds=y_data - ) - logdensity_fn = lambda x: logposterior_fn_( - coefs=x["coefs"][0], log_scale=x["log_scale"][0] - ) - - def sample_init(key): - key1, key2 = jax.random.split(key) - coefs = jax.random.uniform(key1, shape=(1,), minval=1, maxval=2) - log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) - return {"coefs": coefs, "log_scale": log_scale} - - samples = self.run_emaus( - sample_init=sample_init, - logdensity_fn=logdensity_fn, - transform=lambda x: x, - ndims=2, - key=inference_key, - diagonal_preconditioning=True, - ) - - coefs_samples = samples["coefs"] - scale_samples = np.exp(samples["log_scale"]) - - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + self, + ): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_( + coefs=x["coefs"][0], log_scale=x["log_scale"][0] + ) + + def sample_init(key): + key1, key2 = jax.random.split(key) + coefs = jax.random.uniform(key1, shape=(1,), minval=1, maxval=2) + log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) + return {"coefs": coefs, "log_scale": log_scale} + + samples = self.run_emaus( + sample_init=sample_init, + logdensity_fn=logdensity_fn, + transform=lambda x: x, + ndims=2, + key=inference_key, + diagonal_preconditioning=True, + ) + + coefs_samples = samples["coefs"] + scale_samples = np.exp(samples["log_scale"]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) @parameterized.parameters(regression_test_cases) def test_pathfinder_adaptation( @@ -1296,4 +1296,4 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal): if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From cc7bfbd00bac75a3f0f50802dc9a5f37a0909e50 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Feb 2025 11:59:31 -0500 Subject: [PATCH 14/63] docstrings --- blackjax/sgmcmc/csgld.py | 1 - blackjax/sgmcmc/sgnht.py | 1 - blackjax/util.py | 22 +++++++++++++++++++--- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/blackjax/sgmcmc/csgld.py b/blackjax/sgmcmc/csgld.py index 02766ca3d..506740c50 100644 --- a/blackjax/sgmcmc/csgld.py +++ b/blackjax/sgmcmc/csgld.py @@ -41,7 +41,6 @@ class ContourSGLDState(NamedTuple): Index `i` such that the current position belongs to :math:`S_i`. """ - position: ArrayTree energy_pdf: Array energy_idx: int diff --git a/blackjax/sgmcmc/sgnht.py b/blackjax/sgmcmc/sgnht.py index 7bcb2ccef..ad9547406 100644 --- a/blackjax/sgmcmc/sgnht.py +++ b/blackjax/sgmcmc/sgnht.py @@ -35,7 +35,6 @@ class SGNHTState(NamedTuple): Scalar thermostat controlling kinetic energy. """ - position: ArrayTree momentum: ArrayTree xi: float diff --git a/blackjax/util.py b/blackjax/util.py index 5965ac93e..d0f97aa90 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -377,6 +377,25 @@ def run_eca( ensemble_info=None, early_stop=False, ): + """ + Run ensemble of chains in parallel on multiple devices. + ----------------------------------------------------- + Args: + rng_key: random key + initial_state: initial state of the system + kernel: kernel for the dynamics + adaptation: adaptation object + num_steps: number of steps to run + num_chains: number of chains + mesh: mesh for parallelization + ensemble_info: function that takes the state of the system and returns some information about the ensemble + early_stop: whether to stop early + Returns: + final_state: final state of the system + final_adaptation_state: final adaptation state + info_history: history of the information that was stored at each step (if early_stop is False, then this is None) + """ + step = eca_step( kernel, adaptation.summary_statistics_fn, @@ -405,8 +424,6 @@ def step_while(a): output, info = step(x, auxilliary_input) - check_state, _ = vmap(kernel, (0, 0, None))(xs[1][i], output[0], output[1]) - return (output, i + 1, info[0].get("while_cond")) if early_stop: @@ -437,7 +454,6 @@ def step_while(a): ) # produce all random keys that will be needed - # rng_key = rng_key if not isinstance(rng_key, jnp.ndarray) else rng_key[0] key_sampling, key_adaptation = split(rng_key) num_steps = jnp.array(num_steps).item() From 58d9920a2fa6281334e7832cb4a6c37dbc9181f9 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Feb 2025 12:11:55 -0500 Subject: [PATCH 15/63] remove debug statements --- blackjax/adaptation/ensemble_mclmc.py | 30 +++++++++----------------- blackjax/adaptation/ensemble_umclmc.py | 14 +----------- 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 7294f4936..7a117665f 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -84,13 +84,8 @@ def __init__( # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) step_size = ( adaptation_state.step_size - ) # * integrator_factor * adjustment_factor + ) - # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1]))) - - # Initialize the dual averaging adaptation # - # da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target) - # stepsize_adaptation_state = da_init_fn(step_size) # Initialize the bisection for finding the step size stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn( @@ -169,18 +164,18 @@ def emaus( sample_init, transform, ndims, - num_steps1, # max number in phase 1 - num_steps2, # fixed number in phase 2 + num_steps1, + num_steps2, num_chains, mesh, rng_key, - alpha=1.9, # L = sqrt{d}*alpha*vars - save_frac=0.2, # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. min is: save_frac*num_steps1 - C=0.1, # constant in stage 1 that determines step size (eq (9) in paper) - early_stop=True, # for stage 1 - r_end=5e-3, # stage1 parameters + alpha=1.9, + save_frac=0.2, + C=0.1, + early_stop=True, + r_end=5e-3, diagonal_preconditioning=True, - integrator_coefficients=None, # (for stage 2) + integrator_coefficients=None, steps_per_sample=10, acc_prob=None, observables=lambda x: None, @@ -229,11 +224,9 @@ def emaus( C=C, power=3.0 / 8.0, r_end=r_end, - # observables=observables, observables_for_bias=lambda position: jnp.square( transform(jax.flatten_util.ravel_pytree(position)[0]) ), - # contract=contract, ) final_state, final_adaptation_state, info1 = run_eca( @@ -248,7 +241,7 @@ def emaus( early_stop=early_stop, ) - # refine the results with the adjusted method # + # refine the results with the adjusted method _acc_prob = acc_prob if integrator_coefficients is None: high_dims = ndims > 200 @@ -297,9 +290,6 @@ def emaus( num_adaptation_samples, steps_per_sample, _acc_prob, - # observables=observables, - # observables_for_bias=observables_for_bias, - # contract=contract, ) final_state, final_adaptation_state, info2 = run_eca( diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index f919a82e8..830099ce6 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -88,9 +88,7 @@ def summary_statistics_fn(state): def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" - # velocity = jax.tree_util.tree_map( - # lambda sign, u: sign * u, signs, state.momentum - # ) + momentum, unflatten = jax.flatten_util.ravel_pytree(state.momentum) velocity_flat = jax.tree_util.tree_map( @@ -124,14 +122,8 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): - # new_vals = jax.flatten_util.ravel_pytree(new_vals)[0] - # history = jax.flatten_util.ravel_pytree(history)[0] - # print(new_vals, "FOOO\n\n") - new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) - # print(history, "FOOO\n\n") return jnp.concatenate((new_vals[None, :], history[:-1])) - # return history # TODO CHANGE BACK!!!! def update_history_scalar(new_val, history): @@ -146,8 +138,6 @@ def contract_history(theta, weights): return jnp.array([jnp.max(r), jnp.average(r)]) - -# used for the early stopping class History(NamedTuple): observables: Array stopping: Array @@ -224,8 +214,6 @@ def __init__( self.contract = contract self.bias_type = bias_type self.save_num = save_num - # sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype)) - r_save_num = save_num history = History( From 4274b07fdbace00a49b453364310a10ab2e805da Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Feb 2025 12:14:57 -0500 Subject: [PATCH 16/63] precommit --- blackjax/adaptation/ensemble_mclmc.py | 5 +---- blackjax/adaptation/ensemble_umclmc.py | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 7a117665f..363b2df7f 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -82,10 +82,7 @@ def __init__( # With the current eps, we had sigma^2 = EEVPD * d for N = 1. # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = ( - adaptation_state.step_size - ) - + step_size = adaptation_state.step_size # Initialize the bisection for finding the step size stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn( diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 830099ce6..e65f05c50 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -138,6 +138,7 @@ def contract_history(theta, weights): return jnp.array([jnp.max(r), jnp.average(r)]) + class History(NamedTuple): observables: Array stopping: Array From d951a4456e0caa3286123654e3a198d664d88768 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Feb 2025 12:12:18 -0500 Subject: [PATCH 17/63] modify test --- blackjax/util.py | 6 +++++- tests/mcmc/test_sampling.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index d0f97aa90..ee71af2b9 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -321,6 +321,10 @@ def incremental_value_update( def eca_step( kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None ): + """ + Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. + """ + def _step(state_all, xs): """This function operates on a single device.""" ( @@ -378,7 +382,7 @@ def run_eca( early_stop=False, ): """ - Run ensemble of chains in parallel on multiple devices. + Run ensemble chain adaptation (eca) in parallel on multiple devices. ----------------------------------------------------- Args: rng_key: random key diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index ae83927d8..5d43801d7 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -312,7 +312,7 @@ def run_emaus( ndims=ndims, num_steps1=100, num_steps2=300, - num_chains=100, + num_chains=300, mesh=mesh, rng_key=key, alpha=1.9, From ba8f6ebebf4106e844bf305fa53f25c4c9186bb2 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Feb 2025 12:30:44 -0500 Subject: [PATCH 18/63] modify test --- tests/mcmc/test_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 5d43801d7..e73d3c557 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -312,7 +312,7 @@ def run_emaus( ndims=ndims, num_steps1=100, num_steps2=300, - num_chains=300, + num_chains=400, mesh=mesh, rng_key=key, alpha=1.9, From 29dcd5440ca341fd656de94b690ea9dc7bb2843f Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Feb 2025 12:39:20 -0500 Subject: [PATCH 19/63] modify test --- tests/mcmc/test_sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index e73d3c557..0379bdd74 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -312,7 +312,7 @@ def run_emaus( ndims=ndims, num_steps1=100, num_steps2=300, - num_chains=400, + num_chains=700, mesh=mesh, rng_key=key, alpha=1.9, From 4b7d8b004cabe3b93dd2496d4a36feabe39bff77 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 11 Feb 2025 12:52:30 -0500 Subject: [PATCH 20/63] modify test --- tests/mcmc/test_sampling.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 0379bdd74..a8272cfab 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -312,7 +312,7 @@ def run_emaus( ndims=ndims, num_steps1=100, num_steps2=300, - num_chains=700, + num_chains=800, mesh=mesh, rng_key=key, alpha=1.9, @@ -610,8 +610,10 @@ def sample_init(key): coefs_samples = samples["coefs"] scale_samples = np.exp(samples["log_scale"]) - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2) - np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2) + print(np.mean(scale_samples), np.mean(coefs_samples), "foo") + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) @parameterized.parameters(regression_test_cases) def test_pathfinder_adaptation( From 67cfd715b687782103d0cc7e0269d025e0195994 Mon Sep 17 00:00:00 2001 From: = Date: Sat, 22 Feb 2025 15:32:44 -0500 Subject: [PATCH 21/63] clean up and bug fix --- blackjax/adaptation/ensemble_mclmc.py | 4 +--- blackjax/adaptation/ensemble_umclmc.py | 14 +++++++------- blackjax/mcmc/termination.py | 6 +++--- tests/mcmc/test_sampling.py | 2 -- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 363b2df7f..bbc1bb5f7 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -81,7 +81,7 @@ def __init__( # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 # With the current eps, we had sigma^2 = EEVPD * d for N = 1. # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 - # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + # adjustment_factor = jnp.power(0.82 / (ndims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) step_size = adaptation_state.step_size # Initialize the bisection for finding the step size @@ -208,8 +208,6 @@ def emaus( key_init, logdensity_fn, sample_init, num_chains, mesh ) - ndims = 2 - # burn-in with the unadjusted method # kernel = umclmc.build_kernel(logdensity_fn) save_num = (int)(jnp.rint(save_frac * num_steps1)) diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index e65f05c50..4068bee67 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -123,7 +123,7 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) - return jnp.concatenate((new_vals[None, :], history[:-1])) + return jnp.concatenate((new_vals[None, :], history[:-1, :])) def update_history_scalar(new_val, history): @@ -194,7 +194,7 @@ def equipartition_fullrank_loss(delta_z): class Adaptation: def __init__( self, - num_dims, + ndims, alpha=1.0, C=0.1, power=3.0 / 8.0, @@ -205,7 +205,7 @@ def __init__( observables_for_bias=lambda x: x, contract=lambda x: 0.0, ): - self.num_dims = num_dims + self.ndims = ndims self.alpha = alpha self.C = C self.power = power @@ -218,15 +218,15 @@ def __init__( r_save_num = save_num history = History( - observables=jnp.zeros((r_save_num, num_dims)), + observables=jnp.zeros((r_save_num, ndims)), stopping=jnp.full((save_num,), jnp.nan), weights=jnp.zeros(r_save_num), ) self.initial_state = AdaptationState( L=jnp.inf, # do not add noise for the first step - inverse_mass_matrix=jnp.ones(num_dims), - step_size=0.01 * jnp.sqrt(num_dims), + inverse_mass_matrix=jnp.ones(ndims), + step_size=0.01 * jnp.sqrt(ndims), step_count=0, EEVPD=1e-3, EEVPD_wanted=1e-3, @@ -277,7 +277,7 @@ def update(self, adaptation_state, Etheta): jnp.sum(Etheta["xsq"] - jnp.square(Etheta["x"])) ) # average over the ensemble, sum over parameters (to get sqrt(d)) inverse_mass_matrix = Etheta["xsq"] - jnp.square(Etheta["x"]) - EEVPD = (Etheta["Esq"] - jnp.square(Etheta["E"])) / self.num_dims + EEVPD = (Etheta["Esq"] - jnp.square(Etheta["E"])) / self.ndims true_bias = self.contract(Etheta["observables_for_bias"]) nans = Etheta["rejection_rate_nans"] > 0.0 # | (~jnp.isfinite(eps_factor)) diff --git a/blackjax/mcmc/termination.py b/blackjax/mcmc/termination.py index eb1276da3..9fa7cee6c 100644 --- a/blackjax/mcmc/termination.py +++ b/blackjax/mcmc/termination.py @@ -33,10 +33,10 @@ def iterative_uturn_numpyro(is_turning: CheckTurning): def new_state(chain_state, max_num_doublings) -> IterativeUTurnState: flat, _ = jax.flatten_util.ravel_pytree(chain_state.position) - num_dims = jnp.shape(flat)[0] + ndims = jnp.shape(flat)[0] return IterativeUTurnState( - jnp.zeros((max_num_doublings, num_dims)), - jnp.zeros((max_num_doublings, num_dims)), + jnp.zeros((max_num_doublings, ndims)), + jnp.zeros((max_num_doublings, ndims)), 0, 0, ) diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index a8272cfab..886cdce0d 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -610,8 +610,6 @@ def sample_init(key): coefs_samples = samples["coefs"] scale_samples = np.exp(samples["log_scale"]) - print(np.mean(scale_samples), np.mean(coefs_samples), "foo") - np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) From 5ee0b8a188e131643316249fdaa4c967d35bb31e Mon Sep 17 00:00:00 2001 From: Reuben Date: Thu, 27 Feb 2025 07:03:09 -0500 Subject: [PATCH 22/63] Update blackjax/adaptation/step_size.py Co-authored-by: Junpeng Lao --- blackjax/adaptation/step_size.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 94c634ce3..61ceed592 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -270,12 +270,14 @@ def update(state, exp_x, acc_rate_new): x = jnp.log(exp_x) def on_true(bounds): - bounds0 = jnp.max(jnp.array([bounds[0], x])) - return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift + lower, upper = bounds + lower = jnp.max(jnp.array([lower, x])) + return jnp.array([lower, upper]), lower + reduce_shift def on_false(bounds): - bounds1 = jnp.min(jnp.array([bounds[1], x])) - return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift + lower, upper = bounds + upper = jnp.min(jnp.array([upper, x])) + return jnp.array([lower, upper]), upper - reduce_shift bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) From 2b3362502f799b3e6d06b0cc7e92425e3aa8fc32 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 27 Feb 2025 07:03:32 -0500 Subject: [PATCH 23/63] clean up and bug fix --- blackjax/adaptation/ensemble_mclmc.py | 9 ++++----- blackjax/adaptation/step_size.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index bbc1bb5f7..15d386346 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -61,7 +61,7 @@ def __init__( self, adaptation_state, num_adaptation_samples, # amount of tuning in the adjusted phase before fixing params - steps_per_sample, # L/eps (same for each chain: currently fixed to 15) + steps_per_sample=15, # L/eps acc_prob_target=0.8, observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains @@ -85,9 +85,8 @@ def __init__( step_size = adaptation_state.step_size # Initialize the bisection for finding the step size - stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn( - acc_prob_target - ) + self.epsadap_update = bisection_monotonic_fn(acc_prob_target) + stepsize_adaptation_state = ((jnp.array([-jnp.inf, jnp.inf]), False),) self.initial_state = AdaptationState( steps_per_sample, step_size, stepsize_adaptation_state, 0 @@ -173,7 +172,7 @@ def emaus( r_end=5e-3, diagonal_preconditioning=True, integrator_coefficients=None, - steps_per_sample=10, + steps_per_sample=15, acc_prob=None, observables=lambda x: None, ensemble_observables=None, diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 94c634ce3..ee2298c8f 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -298,4 +298,4 @@ def bisect(bounds): return (bounds_new, terminated_new), stepsize - return (jnp.array([-jnp.inf, jnp.inf]), False), update + return update From cc5e09a57a64700069fd63b9bc549e4a164dbda9 Mon Sep 17 00:00:00 2001 From: = Date: Thu, 27 Feb 2025 07:17:23 -0500 Subject: [PATCH 24/63] clean up and bug fix --- blackjax/adaptation/ensemble_mclmc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 15d386346..a43d0eb59 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -86,7 +86,7 @@ def __init__( # Initialize the bisection for finding the step size self.epsadap_update = bisection_monotonic_fn(acc_prob_target) - stepsize_adaptation_state = ((jnp.array([-jnp.inf, jnp.inf]), False),) + stepsize_adaptation_state = (jnp.array([-jnp.inf, jnp.inf]), False) self.initial_state = AdaptationState( steps_per_sample, step_size, stepsize_adaptation_state, 0 From bd40cf9c3a88ecd6599f96f049dc19650244530b Mon Sep 17 00:00:00 2001 From: = Date: Wed, 5 Mar 2025 14:03:45 -0500 Subject: [PATCH 25/63] bug present in minimal_repro_3.py --- blackjax/adaptation/ensemble_mclmc.py | 8 + blackjax/mcmc/alternate_emaus.py | 85 +++++ tests/mcmc/minimal_repro.py | 300 +++++++++++++++ tests/mcmc/minimal_repro_2.py | 419 +++++++++++++++++++++ tests/mcmc/minimal_repro_3.py | 514 ++++++++++++++++++++++++++ tests/mcmc/test_sampling.py | 3 +- 6 files changed, 1328 insertions(+), 1 deletion(-) create mode 100644 blackjax/mcmc/alternate_emaus.py create mode 100644 tests/mcmc/minimal_repro.py create mode 100644 tests/mcmc/minimal_repro_2.py create mode 100644 tests/mcmc/minimal_repro_3.py diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index a43d0eb59..3edae76bd 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -13,6 +13,12 @@ # limitations under the License. # """Public API for the MCLMC Kernel""" +# import jax +# import jax.numpy as jnp +# from blackjax.util import run_eca +# import blackjax.adaptation.ensemble_umclmc as umclmc + + from typing import Any, NamedTuple import jax @@ -286,6 +292,8 @@ def emaus( _acc_prob, ) + + final_state, final_adaptation_state, info2 = run_eca( key_mclmc, initial_state, diff --git a/blackjax/mcmc/alternate_emaus.py b/blackjax/mcmc/alternate_emaus.py new file mode 100644 index 000000000..6010bab73 --- /dev/null +++ b/blackjax/mcmc/alternate_emaus.py @@ -0,0 +1,85 @@ +import jax +import jax.numpy as jnp +from blackjax.util import run_eca +import blackjax.adaptation.ensemble_umclmc as umclmc + + +def emaus( + logdensity_fn, + sample_init, + transform, + ndims, + num_steps1, + num_steps2, + num_chains, + mesh, + rng_key, + alpha=1.9, + save_frac=0.2, + C=0.1, + early_stop=True, + r_end=5e-3, + diagonal_preconditioning=True, + integrator_coefficients=None, + steps_per_sample=15, + acc_prob=None, + observables=lambda x: None, + ensemble_observables=None, + diagnostics=True, +): + """ + model: the target density object + num_steps1: number of steps in the first phase + num_steps2: number of steps in the second phase + num_chains: number of chains + mesh: the mesh object, used for distributing the computation across cpus and nodes + rng_key: the random key + alpha: L = sqrt{d}*alpha*variances + save_frac: the fraction of samples used to estimate the fluctuation in the first phase + C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) + early_stop: whether to stop the first phase early + r_end + diagonal_preconditioning: whether to use diagonal preconditioning + integrator_coefficients: the coefficients of the integrator + steps_per_sample: the number of steps per sample + acc_prob: the acceptance probability + observables: the observables (for diagnostic use) + ensemble_observables: observable calculated over the ensemble (for diagnostic use) + diagnostics: whether to return diagnostics + """ + + # observables_for_bias, contract = bias(model) + key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) + + # initialize the chains + initial_state = umclmc.initialize( + key_init, logdensity_fn, sample_init, num_chains, mesh + ) + + # burn-in with the unadjusted method # + kernel = umclmc.build_kernel(logdensity_fn) + save_num = (int)(jnp.rint(save_frac * num_steps1)) + adap = umclmc.Adaptation( + ndims, + alpha=alpha, + bias_type=3, + save_num=save_num, + C=C, + power=3.0 / 8.0, + r_end=r_end, + observables_for_bias=lambda position: jnp.square( + transform(jax.flatten_util.ravel_pytree(position)[0]) + ), + ) + + final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + num_steps1, + num_chains, + mesh, + ensemble_observables, + early_stop=early_stop, + ) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro.py b/tests/mcmc/minimal_repro.py new file mode 100644 index 000000000..639b624fc --- /dev/null +++ b/tests/mcmc/minimal_repro.py @@ -0,0 +1,300 @@ +import jax +import jax.numpy as jnp +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map +from jax.flatten_util import ravel_pytree +from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec +from jax.tree_util import tree_leaves, tree_map + +import blackjax.adaptation.ensemble_umclmc as umclmc + + +def eca_step( + kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None +): + """ + Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. + """ + + def _step(state_all, xs): + """This function operates on a single device.""" + ( + state, + adaptation_state, + ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + ( + _, + keys_sampling, + key_adaptation, + ) = xs # keys_sampling.shape = (chains_per_device, ) + + # update the state of all chains on this device + state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + + # combine all the chains to compute expectation values + theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) + Etheta = tree_map( + lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") + / num_chains, + theta, + ) + + # use these to adapt the hyperparameters of the dynamics + adaptation_state, info_to_be_stored = adaptation_update( + adaptation_state, Etheta + ) + + return (state, adaptation_state), info_to_be_stored + + if ensemble_info is not None: + + def step(state_all, xs): + (state, adaptation_state), info_to_be_stored = _step(state_all, xs) + return (state, adaptation_state), ( + info_to_be_stored, + vmap(ensemble_info)(state.position), + ) + + return step + + else: + return _step + + + + +def ensemble_execute_fn( + func, + rng_key, + num_chains, + mesh, + x=None, + args=None, + summary_statistics_fn=lambda y: 0.0, +): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics + """ + p, pscalar = PartitionSpec("chains"), PartitionSpec() + + if x is None: + X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X = x + + adaptation_update = lambda _, Etheta: (Etheta, None) + + _F = eca_step( + func, + lambda y, info, key: summary_statistics_fn(y), + adaptation_update, + num_chains, + ) + + def F(x, keys): + """This function operates on a single device. key is a random key for this device.""" + y, summary_statistics = _F((x, args), (None, keys, None))[0] + return y, summary_statistics + + parallel_execute = shard_map( + F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False + ) + + keys = device_put( + split(rng_key, num_chains), NamedSharding(mesh, p) + ) # random keys, distributed across devices + # apply F in parallel + return parallel_execute(X, keys) + +def run_eca( + rng_key, + initial_state, + kernel, + adaptation, + num_steps, + num_chains, + mesh, + ensemble_info=None, + early_stop=False, +): + """ + Run ensemble chain adaptation (eca) in parallel on multiple devices. + ----------------------------------------------------- + Args: + rng_key: random key + initial_state: initial state of the system + kernel: kernel for the dynamics + adaptation: adaptation object + num_steps: number of steps to run + num_chains: number of chains + mesh: mesh for parallelization + ensemble_info: function that takes the state of the system and returns some information about the ensemble + early_stop: whether to stop early + Returns: + final_state: final state of the system + final_adaptation_state: final adaptation state + info_history: history of the information that was stored at each step (if early_stop is False, then this is None) + """ + + step = eca_step( + kernel, + adaptation.summary_statistics_fn, + adaptation.update, + num_chains, + ensemble_info, + ) + + def all_steps(initial_state, keys_sampling, keys_adaptation): + """This function operates on a single device. key is a random key for this device.""" + + initial_state_all = (initial_state, adaptation.initial_state) + + # run sampling + xs = ( + jnp.arange(num_steps), + keys_sampling.T, + keys_adaptation, + ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + + # ((a, Int) -> (a, Int)) + def step_while(a): + x, i, _ = a + + auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + + output, info = step(x, auxilliary_input) + + return (output, i + 1, info[0].get("while_cond")) + + if early_stop: + final_state_all, i, _ = lax.while_loop( + lambda a: ((a[1] < num_steps) & a[2]), + step_while, + (initial_state_all, 0, True), + ) + info_history = None + + else: + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + + final_state, final_adaptation_state = final_state_all + return ( + final_state, + final_adaptation_state, + info_history, + ) # info history is composed of averages over all chains, so it is a couple of scalars + + p, pscalar = PartitionSpec("chains"), PartitionSpec() + parallel_execute = shard_map( + all_steps, + mesh=mesh, + in_specs=(p, p, pscalar), + out_specs=(p, pscalar, pscalar), + check_rep=False, + ) + + # produce all random keys that will be needed + + key_sampling, key_adaptation = split(rng_key) + num_steps = jnp.array(num_steps).item() + keys_adaptation = split(key_adaptation, num_steps) + distribute_keys = lambda key, shape: device_put( + split(key, shape), NamedSharding(mesh, p) + ) # random keys, distributed across devices + keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + + # run sampling in parallel + final_state, final_adaptation_state, info_history = parallel_execute( + initial_state, keys_sampling, keys_adaptation + ) + + return final_state, final_adaptation_state, info_history + + +mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") + +key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) + +num_chains = 128 +ndims = 2 + +def logdensity_fn(x): + mu2 = 0.03 * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + +def transform(x): + return x + +def sample_init(key): + z = jax.random.normal(key, shape=(2,)) + x0 = 10.0 * z[0] + x1 = 0.03 * (x0**2 - 100) + z[1] + return jnp.array([x0, x1]) + +# initialize the chains +initial_state = umclmc.initialize( + key_init, logdensity_fn, sample_init, num_chains, mesh +) + +alpha = 1.9 +C = 0.1 +r_end=5e-3 +ensemble_observables=lambda x: x + +# burn-in with the unadjusted method # +kernel = umclmc.build_kernel(logdensity_fn) +save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) +adap = umclmc.Adaptation( + ndims, + alpha=alpha, + bias_type=3, + save_num=save_num, + C=C, + power=3.0 / 8.0, + r_end=r_end, + observables_for_bias=lambda position: jnp.square( + transform(jax.flatten_util.ravel_pytree(position)[0]) + ), +) + + +final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + 100, + num_chains, + mesh, + ensemble_observables, + early_stop=True, + ) + + +# a = jnp.array([8.0, 4.0]) + +# def f(rng_key, x, args): +# return x + normal(rng_key, x.shape) + a, a + +# out = ensemble_execute_fn( +# func = f, +# rng_key = jax.random.PRNGKey(0), +# num_chains = 4, +# mesh = mesh, +# x = None, +# args = None, +# summary_statistics_fn = lambda y: a, +# ) + +# print(out) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro_2.py b/tests/mcmc/minimal_repro_2.py new file mode 100644 index 000000000..c73f52681 --- /dev/null +++ b/tests/mcmc/minimal_repro_2.py @@ -0,0 +1,419 @@ +import jax +import jax.numpy as jnp +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map +from jax.flatten_util import ravel_pytree +from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec +from jax.tree_util import tree_leaves, tree_map + +from blackjax.util import run_eca + +import blackjax.adaptation.ensemble_umclmc as umclmc + + +# def eca_step( +# kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None +# ): +# """ +# Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. +# """ + +# def _step(state_all, xs): +# """This function operates on a single device.""" +# ( +# state, +# adaptation_state, +# ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. +# ( +# _, +# keys_sampling, +# key_adaptation, +# ) = xs # keys_sampling.shape = (chains_per_device, ) + +# # update the state of all chains on this device +# state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + +# # combine all the chains to compute expectation values +# theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) +# Etheta = tree_map( +# lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") +# / num_chains, +# theta, +# ) + +# # use these to adapt the hyperparameters of the dynamics +# adaptation_state, info_to_be_stored = adaptation_update( +# adaptation_state, Etheta +# ) + +# return (state, adaptation_state), info_to_be_stored + +# if ensemble_info is not None: + +# def step(state_all, xs): +# (state, adaptation_state), info_to_be_stored = _step(state_all, xs) +# return (state, adaptation_state), ( +# info_to_be_stored, +# vmap(ensemble_info)(state.position), +# ) + +# return step + +# else: +# return _step + + + + +# def ensemble_execute_fn( +# func, +# rng_key, +# num_chains, +# mesh, +# x=None, +# args=None, +# summary_statistics_fn=lambda y: 0.0, +# ): +# """Given a sequential function +# func(rng_key, x, args) = y, +# evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. +# Args: +# x: array distributed over all decvices +# args: additional arguments for func, not distributed. +# summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. +# rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + +# Returns: +# y: array distributed over all decvices. Need not be of the same shape as x. +# Etheta: expected values of the summary statistics +# """ +# p, pscalar = PartitionSpec("chains"), PartitionSpec() + +# if x is None: +# X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) +# else: +# X = x + +# adaptation_update = lambda _, Etheta: (Etheta, None) + +# _F = eca_step( +# func, +# lambda y, info, key: summary_statistics_fn(y), +# adaptation_update, +# num_chains, +# ) + +# def F(x, keys): +# """This function operates on a single device. key is a random key for this device.""" +# y, summary_statistics = _F((x, args), (None, keys, None))[0] +# return y, summary_statistics + +# parallel_execute = shard_map( +# F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False +# ) + +# keys = device_put( +# split(rng_key, num_chains), NamedSharding(mesh, p) +# ) # random keys, distributed across devices +# # apply F in parallel +# return parallel_execute(X, keys) + +# def run_eca( +# rng_key, +# initial_state, +# kernel, +# adaptation, +# num_steps, +# num_chains, +# mesh, +# ensemble_info=None, +# early_stop=False, +# ): +# """ +# Run ensemble chain adaptation (eca) in parallel on multiple devices. +# ----------------------------------------------------- +# Args: +# rng_key: random key +# initial_state: initial state of the system +# kernel: kernel for the dynamics +# adaptation: adaptation object +# num_steps: number of steps to run +# num_chains: number of chains +# mesh: mesh for parallelization +# ensemble_info: function that takes the state of the system and returns some information about the ensemble +# early_stop: whether to stop early +# Returns: +# final_state: final state of the system +# final_adaptation_state: final adaptation state +# info_history: history of the information that was stored at each step (if early_stop is False, then this is None) +# """ + +# step = eca_step( +# kernel, +# adaptation.summary_statistics_fn, +# adaptation.update, +# num_chains, +# ensemble_info, +# ) + +# def all_steps(initial_state, keys_sampling, keys_adaptation): +# """This function operates on a single device. key is a random key for this device.""" + +# initial_state_all = (initial_state, adaptation.initial_state) + +# # run sampling +# xs = ( +# jnp.arange(num_steps), +# keys_sampling.T, +# keys_adaptation, +# ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + +# # ((a, Int) -> (a, Int)) +# def step_while(a): +# x, i, _ = a + +# auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + +# output, info = step(x, auxilliary_input) + +# return (output, i + 1, info[0].get("while_cond")) + +# if early_stop: +# final_state_all, i, _ = lax.while_loop( +# lambda a: ((a[1] < num_steps) & a[2]), +# step_while, +# (initial_state_all, 0, True), +# ) +# info_history = None + +# else: +# final_state_all, info_history = lax.scan(step, initial_state_all, xs) + +# final_state, final_adaptation_state = final_state_all +# return ( +# final_state, +# final_adaptation_state, +# info_history, +# ) # info history is composed of averages over all chains, so it is a couple of scalars + +# p, pscalar = PartitionSpec("chains"), PartitionSpec() +# parallel_execute = shard_map( +# all_steps, +# mesh=mesh, +# in_specs=(p, p, pscalar), +# out_specs=(p, pscalar, pscalar), +# check_rep=False, +# ) + +# # produce all random keys that will be needed + +# key_sampling, key_adaptation = split(rng_key) +# num_steps = jnp.array(num_steps).item() +# keys_adaptation = split(key_adaptation, num_steps) +# distribute_keys = lambda key, shape: device_put( +# split(key, shape), NamedSharding(mesh, p) +# ) # random keys, distributed across devices +# keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + +# # run sampling in parallel +# final_state, final_adaptation_state, info_history = parallel_execute( +# initial_state, keys_sampling, keys_adaptation +# ) + +# return final_state, final_adaptation_state, info_history + + +mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") + +# key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) + +num_chains = 128 +ndims = 2 + +def logdensity_fn(x): + mu2 = 0.03 * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + +def transform(x): + return x + +def sample_init(key): + z = jax.random.normal(key, shape=(2,)) + x0 = 10.0 * z[0] + x1 = 0.03 * (x0**2 - 100) + z[1] + return jnp.array([x0, x1]) + +# # initialize the chains +# initial_state = umclmc.initialize( +# key_init, logdensity_fn, sample_init, num_chains, mesh +# ) + +# alpha = 1.9 +# C = 0.1 +# r_end=5e-3 +# ensemble_observables=lambda x: x + +# # burn-in with the unadjusted method # +# kernel = umclmc.build_kernel(logdensity_fn) +# save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) +# adap = umclmc.Adaptation( +# ndims, +# alpha=alpha, +# bias_type=3, +# save_num=save_num, +# C=C, +# power=3.0 / 8.0, +# r_end=r_end, +# observables_for_bias=lambda position: jnp.square( +# transform(jax.flatten_util.ravel_pytree(position)[0]) +# ), +# ) + + +# final_state, final_adaptation_state, info1 = run_eca( +# key_umclmc, +# initial_state, +# kernel, +# adap, +# 100, +# num_chains, +# mesh, +# ensemble_observables, +# early_stop=True, +# ) + +from blackjax.mcmc.integrators import mclachlan_coefficients + +import sys +# sys.path.append(".") +# sys.path.append("../") +from blackjax.adaptation.ensemble_mclmc import emaus +# from blackjax.mcmc.alternate_emaus import emaus + + +# def emaus( +# logdensity_fn, +# sample_init, +# transform, +# ndims, +# num_steps1, +# num_steps2, +# num_chains, +# mesh, +# rng_key, +# alpha=1.9, +# save_frac=0.2, +# C=0.1, +# early_stop=True, +# r_end=5e-3, +# diagonal_preconditioning=True, +# integrator_coefficients=None, +# steps_per_sample=15, +# acc_prob=None, +# observables=lambda x: None, +# ensemble_observables=None, +# diagnostics=True, +# ): +# """ +# model: the target density object +# num_steps1: number of steps in the first phase +# num_steps2: number of steps in the second phase +# num_chains: number of chains +# mesh: the mesh object, used for distributing the computation across cpus and nodes +# rng_key: the random key +# alpha: L = sqrt{d}*alpha*variances +# save_frac: the fraction of samples used to estimate the fluctuation in the first phase +# C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) +# early_stop: whether to stop the first phase early +# r_end +# diagonal_preconditioning: whether to use diagonal preconditioning +# integrator_coefficients: the coefficients of the integrator +# steps_per_sample: the number of steps per sample +# acc_prob: the acceptance probability +# observables: the observables (for diagnostic use) +# ensemble_observables: observable calculated over the ensemble (for diagnostic use) +# diagnostics: whether to return diagnostics +# """ + +# # observables_for_bias, contract = bias(model) +# key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) + +# # initialize the chains +# initial_state = umclmc.initialize( +# key_init, logdensity_fn, sample_init, num_chains, mesh +# ) + +# # burn-in with the unadjusted method # +# kernel = umclmc.build_kernel(logdensity_fn) +# save_num = (int)(jnp.rint(save_frac * num_steps1)) +# adap = umclmc.Adaptation( +# ndims, +# alpha=alpha, +# bias_type=3, +# save_num=save_num, +# C=C, +# power=3.0 / 8.0, +# r_end=r_end, +# observables_for_bias=lambda position: jnp.square( +# transform(jax.flatten_util.ravel_pytree(position)[0]) +# ), +# ) + +# final_state, final_adaptation_state, info1 = run_eca( +# key_umclmc, +# initial_state, +# kernel, +# adap, +# num_steps1, +# num_chains, +# mesh, +# ensemble_observables, +# early_stop=early_stop, +# ) + +key = jax.random.key(0) + +emaus( + logdensity_fn=logdensity_fn, + sample_init=sample_init, + transform=transform, + ndims=ndims, + num_steps1=100, + num_steps2=300, + num_chains=num_chains, + mesh=mesh, + rng_key=key, + alpha=1.9, + C=0.1, + early_stop=1, + r_end=1e-2, + diagonal_preconditioning=True, + integrator_coefficients=mclachlan_coefficients, + steps_per_sample=15, + acc_prob=None, + ensemble_observables=lambda x: x, + # adap=adap, + # kernel=kernel, + # initial_state=initial_state, + # key_umclmc=key_umclmc, + # ensemble_observables = lambda x: vec @ x + ) # run the algorithm + + +# a = jnp.array([8.0, 4.0]) + +# def f(rng_key, x, args): +# return x + normal(rng_key, x.shape) + a, a + +# out = ensemble_execute_fn( +# func = f, +# rng_key = jax.random.PRNGKey(0), +# num_chains = 4, +# mesh = mesh, +# x = None, +# args = None, +# summary_statistics_fn = lambda y: a, +# ) + +# print(out) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro_3.py b/tests/mcmc/minimal_repro_3.py new file mode 100644 index 000000000..56e08caeb --- /dev/null +++ b/tests/mcmc/minimal_repro_3.py @@ -0,0 +1,514 @@ + + +from typing import Any, NamedTuple + +import jax +import jax.numpy as jnp + +import blackjax.adaptation.ensemble_umclmc as umclmc +from blackjax.adaptation.ensemble_umclmc import ( + equipartition_diagonal, + equipartition_diagonal_loss, +) +from blackjax.adaptation.step_size import bisection_monotonic_fn +from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.integrators import ( + generate_isokinetic_integrator, + mclachlan_coefficients, + omelyan_coefficients, +) +import jax +import jax.numpy as jnp +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map +from jax.flatten_util import ravel_pytree +from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec +from jax.tree_util import tree_leaves, tree_map + +import blackjax.adaptation.ensemble_umclmc as umclmc + + +def eca_step( + kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None +): + """ + Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. + """ + + def _step(state_all, xs): + """This function operates on a single device.""" + ( + state, + adaptation_state, + ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + ( + _, + keys_sampling, + key_adaptation, + ) = xs # keys_sampling.shape = (chains_per_device, ) + + # update the state of all chains on this device + state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + + # combine all the chains to compute expectation values + theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) + Etheta = tree_map( + lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") + / num_chains, + theta, + ) + + # use these to adapt the hyperparameters of the dynamics + adaptation_state, info_to_be_stored = adaptation_update( + adaptation_state, Etheta + ) + + return (state, adaptation_state), info_to_be_stored + + if ensemble_info is not None: + + def step(state_all, xs): + (state, adaptation_state), info_to_be_stored = _step(state_all, xs) + return (state, adaptation_state), ( + info_to_be_stored, + vmap(ensemble_info)(state.position), + ) + + return step + + else: + return _step + + + + +def ensemble_execute_fn( + func, + rng_key, + num_chains, + mesh, + x=None, + args=None, + summary_statistics_fn=lambda y: 0.0, +): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. + Args: + x: array distributed over all decvices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + + Returns: + y: array distributed over all decvices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics + """ + p, pscalar = PartitionSpec("chains"), PartitionSpec() + + if x is None: + X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X = x + + adaptation_update = lambda _, Etheta: (Etheta, None) + + _F = eca_step( + func, + lambda y, info, key: summary_statistics_fn(y), + adaptation_update, + num_chains, + ) + + def F(x, keys): + """This function operates on a single device. key is a random key for this device.""" + y, summary_statistics = _F((x, args), (None, keys, None))[0] + return y, summary_statistics + + parallel_execute = shard_map( + F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False + ) + + keys = device_put( + split(rng_key, num_chains), NamedSharding(mesh, p) + ) # random keys, distributed across devices + # apply F in parallel + return parallel_execute(X, keys) + +def run_eca( + rng_key, + initial_state, + kernel, + adaptation, + num_steps, + num_chains, + mesh, + ensemble_info=None, + early_stop=False, +): + """ + Run ensemble chain adaptation (eca) in parallel on multiple devices. + ----------------------------------------------------- + Args: + rng_key: random key + initial_state: initial state of the system + kernel: kernel for the dynamics + adaptation: adaptation object + num_steps: number of steps to run + num_chains: number of chains + mesh: mesh for parallelization + ensemble_info: function that takes the state of the system and returns some information about the ensemble + early_stop: whether to stop early + Returns: + final_state: final state of the system + final_adaptation_state: final adaptation state + info_history: history of the information that was stored at each step (if early_stop is False, then this is None) + """ + + step = eca_step( + kernel, + adaptation.summary_statistics_fn, + adaptation.update, + num_chains, + ensemble_info, + ) + + def all_steps(initial_state, keys_sampling, keys_adaptation): + """This function operates on a single device. key is a random key for this device.""" + + initial_state_all = (initial_state, adaptation.initial_state) + + # run sampling + xs = ( + jnp.arange(num_steps), + keys_sampling.T, + keys_adaptation, + ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + + # ((a, Int) -> (a, Int)) + def step_while(a): + x, i, _ = a + + auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + + output, info = step(x, auxilliary_input) + + return (output, i + 1, info[0].get("while_cond")) + + if early_stop: + final_state_all, i, _ = lax.while_loop( + lambda a: ((a[1] < num_steps) & a[2]), + step_while, + (initial_state_all, 0, True), + ) + info_history = None + + else: + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + + final_state, final_adaptation_state = final_state_all + return ( + final_state, + final_adaptation_state, + info_history, + ) # info history is composed of averages over all chains, so it is a couple of scalars + + p, pscalar = PartitionSpec("chains"), PartitionSpec() + parallel_execute = shard_map( + all_steps, + mesh=mesh, + in_specs=(p, p, pscalar), + out_specs=(p, pscalar, pscalar), + check_rep=False, + ) + + # produce all random keys that will be needed + + key_sampling, key_adaptation = split(rng_key) + num_steps = jnp.array(num_steps).item() + keys_adaptation = split(key_adaptation, num_steps) + distribute_keys = lambda key, shape: device_put( + split(key, shape), NamedSharding(mesh, p) + ) # random keys, distributed across devices + keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + + # run sampling in parallel + final_state, final_adaptation_state, info_history = parallel_execute( + initial_state, keys_sampling, keys_adaptation + ) + + return final_state, final_adaptation_state, info_history + +# from blackjax.util import run_eca + + + +class AdaptationState(NamedTuple): + steps_per_sample: float + step_size: float + stepsize_adaptation_state: ( + Any # the state of the bisection algorithm to find a stepsize + ) + iteration: int + + +build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( + logdensity_fn=logdensity_fn, + integrator=integrator, + inverse_mass_matrix=inverse_mass_matrix, +)( + rng_key=key, + state=state, + step_size=adap.step_size, + num_integration_steps=adap.steps_per_sample, + L_proposal_factor=1.25, +) + + +class Adaptation: + def __init__( + self, + adaptation_state, + num_adaptation_samples, # amount of tuning in the adjusted phase before fixing params + steps_per_sample=15, # L/eps + acc_prob_target=0.8, + observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep + observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains + contract=lambda x: 0.0, # just for diagnostics: observabiels for bias, contracted over dimensions + ): + self.num_adaptation_samples = num_adaptation_samples + self.observables = observables + self.observables_for_bias = observables_for_bias + self.contract = contract + + # Determine the initial hyperparameters # + + # stepsize # + # if we switched to the more accurate integrator we can use longer step size + # integrator_factor = jnp.sqrt(10.) if mclachlan else 1. + # Let's use the stepsize which will be optimal for the adjusted method. The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 + # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 + # With the current eps, we had sigma^2 = EEVPD * d for N = 1. + # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 + # adjustment_factor = jnp.power(0.82 / (ndims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) + step_size = adaptation_state.step_size + + # Initialize the bisection for finding the step size + self.epsadap_update = bisection_monotonic_fn(acc_prob_target) + stepsize_adaptation_state = (jnp.array([-jnp.inf, jnp.inf]), False) + + self.initial_state = AdaptationState( + steps_per_sample, step_size, stepsize_adaptation_state, 0 + ) + + def summary_statistics_fn(self, state, info, rng_key): + return { + "acceptance_probability": info.acceptance_rate, + "equipartition_diagonal": equipartition_diagonal( + state + ), # metric for bias: equipartition theorem gives todo... + "observables": self.observables(state.position), + "observables_for_bias": self.observables_for_bias(state.position), + } + + def update(self, adaptation_state, Etheta): + acc_prob = Etheta["acceptance_probability"] + equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) + true_bias = self.contract(Etheta["observables_for_bias"]) # remove + + info_to_be_stored = { + "L": adaptation_state.step_size * adaptation_state.steps_per_sample, + "steps_per_sample": adaptation_state.steps_per_sample, + "step_size": adaptation_state.step_size, + "acc_prob": acc_prob, + "equi_diag": equi_diag, + "bias": true_bias, + "observables": Etheta["observables"], + } + + # Bisection to find step size + stepsize_adaptation_state, step_size = self.epsadap_update( + adaptation_state.stepsize_adaptation_state, + adaptation_state.step_size, + acc_prob, + ) + + return ( + AdaptationState( + adaptation_state.steps_per_sample, + step_size, + stepsize_adaptation_state, + adaptation_state.iteration + 1, + ), + info_to_be_stored, + ) + + +def bias(model): + """should be transfered to benchmarks/""" + + def observables(position): + return jnp.square(model.transform(position)) + + def contract(sampler_E_x2): + bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 + return jnp.array([jnp.max(bsq), jnp.average(bsq)]) + + return observables, contract + + +def while_steps_num(cond): + if jnp.all(cond): + return len(cond) + else: + return jnp.argmin(cond) + 1 + + + +def logdensity_fn(x): + mu2 = 0.03 * (x[0] ** 2 - 100) + return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) + +def transform(x): + return x + +def sample_init(key): + z = jax.random.normal(key, shape=(2,)) + x0 = 10.0 * z[0] + x1 = 0.03 * (x0**2 - 100) + z[1] + return jnp.array([x0, x1]) + +num_chains = 128 + +mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") + +key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) + +integrator_coefficients = mclachlan_coefficients + +acc_prob = None + +# initialize the chains +initial_state = umclmc.initialize( + key_init, logdensity_fn, sample_init, num_chains, mesh +) + +diagonal_preconditioning = False +ndims = 2 + +alpha = 1.9 +C = 0.1 +r_end=5e-3 +ensemble_observables=lambda x: x + +# burn-in with the unadjusted method # +kernel = umclmc.build_kernel(logdensity_fn) +save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) +adap = umclmc.Adaptation( + ndims, + alpha=alpha, + bias_type=3, + save_num=save_num, + C=C, + power=3.0 / 8.0, + r_end=r_end, + observables_for_bias=lambda position: jnp.square( + transform(jax.flatten_util.ravel_pytree(position)[0]) + ), +) + +final_state, final_adaptation_state, info1 = run_eca( + key_umclmc, + initial_state, + kernel, + adap, + 100, + num_chains, + mesh, + ensemble_observables, + early_stop=True, + ) + +# refine the results with the adjusted method +_acc_prob = acc_prob +if integrator_coefficients is None: + high_dims = ndims > 200 + _integrator_coefficients = ( + omelyan_coefficients if high_dims else mclachlan_coefficients + ) + if acc_prob is None: + _acc_prob = 0.9 if high_dims else 0.7 + +else: + _integrator_coefficients = integrator_coefficients + if acc_prob is None: + _acc_prob = 0.9 + +integrator = generate_isokinetic_integrator(_integrator_coefficients) +gradient_calls_per_step = ( + len(_integrator_coefficients) // 2 +) # scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. + +if diagonal_preconditioning: + inverse_mass_matrix = jnp.sqrt(final_adaptation_state.inverse_mass_matrix) + + # scale the stepsize so that it reflects averag scale change of the preconditioning + average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) + final_adaptation_state = final_adaptation_state._replace( + step_size=final_adaptation_state.step_size / average_scale_change + ) + +else: + inverse_mass_matrix = 1.0 + +kernel = build_kernel( + logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix +) +steps_per_sample = 15 +num_steps2 = 100 + + +initial_state = HMCState( + final_state.position, final_state.logdensity, final_state.logdensity_grad + ) + +print(initial_state.position.shape, "bar\n\n") + +# pos = jax.random.normal(key_mclmc, shape=(num_chains, ndims)) + + + +# print("baz", logdensity_fn(pos)) + +# initial_state = HMCState( +# pos, logdensity_fn(pos[0]), jax.grad(logdensity_fn)(pos[0]) +# ) + + +num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) +num_adaptation_samples = ( + num_samples // 2 +) # number of samples after which the stepsize is fixed. + +adap = Adaptation( + final_adaptation_state, + num_adaptation_samples, + steps_per_sample, + _acc_prob, +) + + + +final_state, final_adaptation_state, info2 = run_eca( + key_mclmc, + initial_state, + kernel, + adap, + num_samples, + num_chains, + mesh, + ensemble_observables, +) + diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 886cdce0d..57c8aedf7 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -5,6 +5,7 @@ import chex import jax +# jax.config.update("jax_traceback_filtering", "off") import jax.numpy as jnp import jax.scipy.stats as stats import numpy as np @@ -299,7 +300,7 @@ def run_emaus( key, diagonal_preconditioning, ): - mesh = jax.sharding.Mesh(jax.devices(), "chains") + mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") from blackjax.mcmc.integrators import mclachlan_coefficients From f7f8d86ed633e431872154b6a5e47e16bb668bed Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Thu, 6 Mar 2025 08:12:14 -0800 Subject: [PATCH 26/63] wip --- blackjax/adaptation/ensemble_mclmc.py | 18 ++++++++++-------- blackjax/adaptation/ensemble_umclmc.py | 4 ++-- blackjax/adaptation/mclmc_adaptation.py | 4 +++- blackjax/util.py | 10 +++++++++- 4 files changed, 24 insertions(+), 12 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index bbc1bb5f7..cc5ff60c9 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -65,7 +65,7 @@ def __init__( acc_prob_target=0.8, observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains - contract=lambda x: 0.0, # just for diagnostics: observabiels for bias, contracted over dimensions + contract=lambda x: 0.0, # just for diagnostics: observables for bias, contracted over dimensions ): self.num_adaptation_samples = num_adaptation_samples self.observables = observables @@ -106,7 +106,7 @@ def summary_statistics_fn(self, state, info, rng_key): def update(self, adaptation_state, Etheta): acc_prob = Etheta["acceptance_probability"] equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) - true_bias = self.contract(Etheta["observables_for_bias"]) # remove + true_bias = self.contract(Etheta["observables_for_bias"]) info_to_be_stored = { "L": adaptation_state.step_size * adaptation_state.steps_per_sample, @@ -159,7 +159,6 @@ def while_steps_num(cond): def emaus( logdensity_fn, sample_init, - transform, ndims, num_steps1, num_steps2, @@ -175,9 +174,10 @@ def emaus( integrator_coefficients=None, steps_per_sample=10, acc_prob=None, - observables=lambda x: None, + observables_for_bias=lambda x: 0.0, ensemble_observables=None, diagnostics=True, + contract=lambda x: 0.0, ): """ model: the target density object @@ -210,7 +210,7 @@ def emaus( # burn-in with the unadjusted method # kernel = umclmc.build_kernel(logdensity_fn) - save_num = (int)(jnp.rint(save_frac * num_steps1)) + save_num = (jnp.rint(save_frac * num_steps1)).astype(int) adap = umclmc.Adaptation( ndims, alpha=alpha, @@ -219,9 +219,8 @@ def emaus( C=C, power=3.0 / 8.0, r_end=r_end, - observables_for_bias=lambda position: jnp.square( - transform(jax.flatten_util.ravel_pytree(position)[0]) - ), + observables_for_bias=observables_for_bias, + contract=contract, ) final_state, final_adaptation_state, info1 = run_eca( @@ -285,8 +284,11 @@ def emaus( num_adaptation_samples, steps_per_sample, _acc_prob, + contract=contract, + observables_for_bias=observables_for_bias, ) + final_state, final_adaptation_state, info2 = run_eca( key_mclmc, initial_state, diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 4068bee67..1f1bf518b 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -72,6 +72,7 @@ def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): def sequential_init(key, x, args): """initialize the position using sample_init and the velocity along the gradient""" position = sample_init(key) + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) flat_g, unravel_fn = ravel_pytree(logdensity_grad) velocity = unravel_fn( @@ -82,9 +83,8 @@ def sequential_init(key, x, args): def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" - return 0 # -state.position * state.logdensity_grad + return -state.position * state.logdensity_grad - # TODO: restore! def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 60fd46359..16612c05c 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -50,6 +50,7 @@ def mclmc_find_L_and_step_size( desired_energy_var=5e-4, trust_in_estimate=1.5, num_effective_samples=150, + params=None, diagonal_preconditioning=True, ): """ @@ -105,7 +106,8 @@ def mclmc_find_L_and_step_size( ) """ dim = pytree_size(state.position) - params = MCLMCAdaptationState( + if params is None: + params = MCLMCAdaptationState( jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) ) part1_key, part2_key = jax.random.split(rng_key, 2) diff --git a/blackjax/util.py b/blackjax/util.py index ee71af2b9..4668befee 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -10,7 +10,7 @@ from jax.random import normal, split from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map - +import jax from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -352,11 +352,14 @@ def _step(state_all, xs): adaptation_state, info_to_be_stored = adaptation_update( adaptation_state, Etheta ) + return (state, adaptation_state), info_to_be_stored + if ensemble_info is not None: + def step(state_all, xs): (state, adaptation_state), info_to_be_stored = _step(state_all, xs) return (state, adaptation_state), ( @@ -381,6 +384,7 @@ def run_eca( ensemble_info=None, early_stop=False, ): + """ Run ensemble chain adaptation (eca) in parallel on multiple devices. ----------------------------------------------------- @@ -413,6 +417,7 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): initial_state_all = (initial_state, adaptation.initial_state) + # run sampling xs = ( jnp.arange(num_steps), @@ -441,6 +446,8 @@ def step_while(a): else: final_state_all, info_history = lax.scan(step, initial_state_all, xs) + + final_state, final_adaptation_state = final_state_all return ( final_state, @@ -448,6 +455,7 @@ def step_while(a): info_history, ) # info history is composed of averages over all chains, so it is a couple of scalars + p, pscalar = PartitionSpec("chains"), PartitionSpec() parallel_execute = shard_map( all_steps, From a5eb4f4fc1a3bb3561c852822fbd5a5894191395 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Thu, 6 Mar 2025 08:19:06 -0800 Subject: [PATCH 27/63] wip --- blackjax/adaptation/mclmc_adaptation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 2a2d13831..5a673414f 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -52,7 +52,6 @@ def mclmc_find_L_and_step_size( num_effective_samples=150, params=None, diagonal_preconditioning=True, - params=None, ): """ Finds the optimal value of the parameters for the MCLMC algorithm. From f40489803bf17e5da42fbba4aead158c50a24f89 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 13:07:43 -0400 Subject: [PATCH 28/63] bug fix --- blackjax/adaptation/ensemble_mclmc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 3edae76bd..ea5f51c48 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -285,6 +285,10 @@ def emaus( num_samples // 2 ) # number of samples after which the stepsize is fixed. + final_adaptation_state = final_adaptation_state._replace( + step_size=final_adaptation_state.step_size.item() + ) + adap = Adaptation( final_adaptation_state, num_adaptation_samples, @@ -292,8 +296,6 @@ def emaus( _acc_prob, ) - - final_state, final_adaptation_state, info2 = run_eca( key_mclmc, initial_state, From 9b00e2811e866948bc00e41f3d504bfcd4550ce3 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 13:28:51 -0400 Subject: [PATCH 29/63] bug fix --- blackjax/adaptation/ensemble_mclmc.py | 10 +++------- blackjax/adaptation/ensemble_umclmc.py | 14 +++++++++----- blackjax/util.py | 10 +--------- tests/mcmc/test_sampling.py | 6 ++---- 4 files changed, 15 insertions(+), 25 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index e153f8a65..22de73887 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -101,9 +101,7 @@ def __init__( def summary_statistics_fn(self, state, info, rng_key): return { "acceptance_probability": info.acceptance_rate, - "equipartition_diagonal": equipartition_diagonal( - state - ), # metric for bias: equipartition theorem gives todo... + "equipartition_diagonal": equipartition_diagonal(state), "observables": self.observables(state.position), "observables_for_bias": self.observables_for_bias(state.position), } @@ -111,7 +109,7 @@ def summary_statistics_fn(self, state, info, rng_key): def update(self, adaptation_state, Etheta): acc_prob = Etheta["acceptance_probability"] equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) - true_bias = self.contract(Etheta["observables_for_bias"]) + true_bias = self.contract(Etheta["observables_for_bias"]) info_to_be_stored = { "L": adaptation_state.step_size * adaptation_state.steps_per_sample, @@ -179,7 +177,7 @@ def emaus( integrator_coefficients=None, steps_per_sample=15, acc_prob=None, - observables_for_bias=lambda x: 0.0, + observables_for_bias=lambda x: x, ensemble_observables=None, diagnostics=True, contract=lambda x: 0.0, @@ -205,7 +203,6 @@ def emaus( diagnostics: whether to return diagnostics """ - # observables_for_bias, contract = bias(model) key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) # initialize the chains @@ -297,7 +294,6 @@ def emaus( observables_for_bias=observables_for_bias, ) - final_state, final_adaptation_state, info2 = run_eca( key_mclmc, initial_state, diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 1f1bf518b..3cc62ae00 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -72,7 +72,7 @@ def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): def sequential_init(key, x, args): """initialize the position using sample_init and the velocity along the gradient""" position = sample_init(key) - + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) flat_g, unravel_fn = ravel_pytree(logdensity_grad) velocity = unravel_fn( @@ -83,8 +83,12 @@ def sequential_init(key, x, args): def summary_statistics_fn(state): """compute the diagonal elements of the equipartition matrix""" - return -state.position * state.logdensity_grad + flat_pos, unflatten = jax.flatten_util.ravel_pytree(state.position) + flat_g, unravel_fn = ravel_pytree(state.logdensity_grad) + return unravel_fn(-flat_pos * flat_g) + # return 0 + # -state.position # * state.logdensity_grad def ensemble_init(key, state, signs): """flip the velocity, depending on the equipartition condition""" @@ -113,7 +117,9 @@ def ensemble_init(key, state, signs): summary_statistics_fn=summary_statistics_fn, ) - signs = -2.0 * (equipartition < 1.0) + 1.0 + flat_equi, _ = ravel_pytree(equipartition) + + signs = -2.0 * (flat_equi < 1.0) + 1.0 initial_state, _ = ensemble_execute_fn( ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs ) @@ -122,7 +128,6 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): - new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) return jnp.concatenate((new_vals[None, :], history[:-1, :])) @@ -258,7 +263,6 @@ def update(self, adaptation_state, Etheta): history_observables = update_history( Etheta["observables_for_bias"], adaptation_state.history.observables ) - # history_observables = adaptation_state.history.observables history_weights = update_history_scalar(1.0, adaptation_state.history.weights) fluctuations = contract_history(history_observables, history_weights) diff --git a/blackjax/util.py b/blackjax/util.py index 4668befee..ee71af2b9 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -10,7 +10,7 @@ from jax.random import normal, split from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map -import jax + from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey @@ -352,14 +352,11 @@ def _step(state_all, xs): adaptation_state, info_to_be_stored = adaptation_update( adaptation_state, Etheta ) - return (state, adaptation_state), info_to_be_stored - if ensemble_info is not None: - def step(state_all, xs): (state, adaptation_state), info_to_be_stored = _step(state_all, xs) return (state, adaptation_state), ( @@ -384,7 +381,6 @@ def run_eca( ensemble_info=None, early_stop=False, ): - """ Run ensemble chain adaptation (eca) in parallel on multiple devices. ----------------------------------------------------- @@ -417,7 +413,6 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): initial_state_all = (initial_state, adaptation.initial_state) - # run sampling xs = ( jnp.arange(num_steps), @@ -446,8 +441,6 @@ def step_while(a): else: final_state_all, info_history = lax.scan(step, initial_state_all, xs) - - final_state, final_adaptation_state = final_state_all return ( final_state, @@ -455,7 +448,6 @@ def step_while(a): info_history, ) # info history is composed of averages over all chains, so it is a couple of scalars - p, pscalar = PartitionSpec("chains"), PartitionSpec() parallel_execute = shard_map( all_steps, diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 57c8aedf7..d2cbd1501 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -5,6 +5,7 @@ import chex import jax + # jax.config.update("jax_traceback_filtering", "off") import jax.numpy as jnp import jax.scipy.stats as stats @@ -296,11 +297,10 @@ def run_emaus( sample_init, logdensity_fn, ndims, - transform, key, diagonal_preconditioning, ): - mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") + mesh = jax.sharding.Mesh(devices=jax.devices(), axis_names="chains") from blackjax.mcmc.integrators import mclachlan_coefficients @@ -309,7 +309,6 @@ def run_emaus( info, grads_per_step, _acc_prob, final_state = emaus( logdensity_fn=logdensity_fn, sample_init=sample_init, - transform=transform, ndims=ndims, num_steps1=100, num_steps2=300, @@ -602,7 +601,6 @@ def sample_init(key): samples = self.run_emaus( sample_init=sample_init, logdensity_fn=logdensity_fn, - transform=lambda x: x, ndims=2, key=inference_key, diagonal_preconditioning=True, From f35f98eaffa8dc8ffbc112a4dbf5cf010f7ab7e6 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 13:33:42 -0400 Subject: [PATCH 30/63] bug fix --- tests/mcmc/minimal_repro.py | 300 -------------------- tests/mcmc/minimal_repro_2.py | 419 --------------------------- tests/mcmc/minimal_repro_3.py | 514 ---------------------------------- 3 files changed, 1233 deletions(-) delete mode 100644 tests/mcmc/minimal_repro.py delete mode 100644 tests/mcmc/minimal_repro_2.py delete mode 100644 tests/mcmc/minimal_repro_3.py diff --git a/tests/mcmc/minimal_repro.py b/tests/mcmc/minimal_repro.py deleted file mode 100644 index 639b624fc..000000000 --- a/tests/mcmc/minimal_repro.py +++ /dev/null @@ -1,300 +0,0 @@ -import jax -import jax.numpy as jnp -from jax import device_put, jit, lax, vmap -from jax.experimental.shard_map import shard_map -from jax.flatten_util import ravel_pytree -from jax.random import normal, split -from jax.sharding import NamedSharding, PartitionSpec -from jax.tree_util import tree_leaves, tree_map - -import blackjax.adaptation.ensemble_umclmc as umclmc - - -def eca_step( - kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None -): - """ - Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. - """ - - def _step(state_all, xs): - """This function operates on a single device.""" - ( - state, - adaptation_state, - ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. - ( - _, - keys_sampling, - key_adaptation, - ) = xs # keys_sampling.shape = (chains_per_device, ) - - # update the state of all chains on this device - state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - - # combine all the chains to compute expectation values - theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) - Etheta = tree_map( - lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") - / num_chains, - theta, - ) - - # use these to adapt the hyperparameters of the dynamics - adaptation_state, info_to_be_stored = adaptation_update( - adaptation_state, Etheta - ) - - return (state, adaptation_state), info_to_be_stored - - if ensemble_info is not None: - - def step(state_all, xs): - (state, adaptation_state), info_to_be_stored = _step(state_all, xs) - return (state, adaptation_state), ( - info_to_be_stored, - vmap(ensemble_info)(state.position), - ) - - return step - - else: - return _step - - - - -def ensemble_execute_fn( - func, - rng_key, - num_chains, - mesh, - x=None, - args=None, - summary_statistics_fn=lambda y: 0.0, -): - """Given a sequential function - func(rng_key, x, args) = y, - evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. - Args: - x: array distributed over all decvices - args: additional arguments for func, not distributed. - summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. - rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. - - Returns: - y: array distributed over all decvices. Need not be of the same shape as x. - Etheta: expected values of the summary statistics - """ - p, pscalar = PartitionSpec("chains"), PartitionSpec() - - if x is None: - X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) - else: - X = x - - adaptation_update = lambda _, Etheta: (Etheta, None) - - _F = eca_step( - func, - lambda y, info, key: summary_statistics_fn(y), - adaptation_update, - num_chains, - ) - - def F(x, keys): - """This function operates on a single device. key is a random key for this device.""" - y, summary_statistics = _F((x, args), (None, keys, None))[0] - return y, summary_statistics - - parallel_execute = shard_map( - F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False - ) - - keys = device_put( - split(rng_key, num_chains), NamedSharding(mesh, p) - ) # random keys, distributed across devices - # apply F in parallel - return parallel_execute(X, keys) - -def run_eca( - rng_key, - initial_state, - kernel, - adaptation, - num_steps, - num_chains, - mesh, - ensemble_info=None, - early_stop=False, -): - """ - Run ensemble chain adaptation (eca) in parallel on multiple devices. - ----------------------------------------------------- - Args: - rng_key: random key - initial_state: initial state of the system - kernel: kernel for the dynamics - adaptation: adaptation object - num_steps: number of steps to run - num_chains: number of chains - mesh: mesh for parallelization - ensemble_info: function that takes the state of the system and returns some information about the ensemble - early_stop: whether to stop early - Returns: - final_state: final state of the system - final_adaptation_state: final adaptation state - info_history: history of the information that was stored at each step (if early_stop is False, then this is None) - """ - - step = eca_step( - kernel, - adaptation.summary_statistics_fn, - adaptation.update, - num_chains, - ensemble_info, - ) - - def all_steps(initial_state, keys_sampling, keys_adaptation): - """This function operates on a single device. key is a random key for this device.""" - - initial_state_all = (initial_state, adaptation.initial_state) - - # run sampling - xs = ( - jnp.arange(num_steps), - keys_sampling.T, - keys_adaptation, - ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - - # ((a, Int) -> (a, Int)) - def step_while(a): - x, i, _ = a - - auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) - - output, info = step(x, auxilliary_input) - - return (output, i + 1, info[0].get("while_cond")) - - if early_stop: - final_state_all, i, _ = lax.while_loop( - lambda a: ((a[1] < num_steps) & a[2]), - step_while, - (initial_state_all, 0, True), - ) - info_history = None - - else: - final_state_all, info_history = lax.scan(step, initial_state_all, xs) - - final_state, final_adaptation_state = final_state_all - return ( - final_state, - final_adaptation_state, - info_history, - ) # info history is composed of averages over all chains, so it is a couple of scalars - - p, pscalar = PartitionSpec("chains"), PartitionSpec() - parallel_execute = shard_map( - all_steps, - mesh=mesh, - in_specs=(p, p, pscalar), - out_specs=(p, pscalar, pscalar), - check_rep=False, - ) - - # produce all random keys that will be needed - - key_sampling, key_adaptation = split(rng_key) - num_steps = jnp.array(num_steps).item() - keys_adaptation = split(key_adaptation, num_steps) - distribute_keys = lambda key, shape: device_put( - split(key, shape), NamedSharding(mesh, p) - ) # random keys, distributed across devices - keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) - - # run sampling in parallel - final_state, final_adaptation_state, info_history = parallel_execute( - initial_state, keys_sampling, keys_adaptation - ) - - return final_state, final_adaptation_state, info_history - - -mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") - -key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) - -num_chains = 128 -ndims = 2 - -def logdensity_fn(x): - mu2 = 0.03 * (x[0] ** 2 - 100) - return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) - -def transform(x): - return x - -def sample_init(key): - z = jax.random.normal(key, shape=(2,)) - x0 = 10.0 * z[0] - x1 = 0.03 * (x0**2 - 100) + z[1] - return jnp.array([x0, x1]) - -# initialize the chains -initial_state = umclmc.initialize( - key_init, logdensity_fn, sample_init, num_chains, mesh -) - -alpha = 1.9 -C = 0.1 -r_end=5e-3 -ensemble_observables=lambda x: x - -# burn-in with the unadjusted method # -kernel = umclmc.build_kernel(logdensity_fn) -save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) -adap = umclmc.Adaptation( - ndims, - alpha=alpha, - bias_type=3, - save_num=save_num, - C=C, - power=3.0 / 8.0, - r_end=r_end, - observables_for_bias=lambda position: jnp.square( - transform(jax.flatten_util.ravel_pytree(position)[0]) - ), -) - - -final_state, final_adaptation_state, info1 = run_eca( - key_umclmc, - initial_state, - kernel, - adap, - 100, - num_chains, - mesh, - ensemble_observables, - early_stop=True, - ) - - -# a = jnp.array([8.0, 4.0]) - -# def f(rng_key, x, args): -# return x + normal(rng_key, x.shape) + a, a - -# out = ensemble_execute_fn( -# func = f, -# rng_key = jax.random.PRNGKey(0), -# num_chains = 4, -# mesh = mesh, -# x = None, -# args = None, -# summary_statistics_fn = lambda y: a, -# ) - -# print(out) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro_2.py b/tests/mcmc/minimal_repro_2.py deleted file mode 100644 index c73f52681..000000000 --- a/tests/mcmc/minimal_repro_2.py +++ /dev/null @@ -1,419 +0,0 @@ -import jax -import jax.numpy as jnp -from jax import device_put, jit, lax, vmap -from jax.experimental.shard_map import shard_map -from jax.flatten_util import ravel_pytree -from jax.random import normal, split -from jax.sharding import NamedSharding, PartitionSpec -from jax.tree_util import tree_leaves, tree_map - -from blackjax.util import run_eca - -import blackjax.adaptation.ensemble_umclmc as umclmc - - -# def eca_step( -# kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None -# ): -# """ -# Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. -# """ - -# def _step(state_all, xs): -# """This function operates on a single device.""" -# ( -# state, -# adaptation_state, -# ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. -# ( -# _, -# keys_sampling, -# key_adaptation, -# ) = xs # keys_sampling.shape = (chains_per_device, ) - -# # update the state of all chains on this device -# state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - -# # combine all the chains to compute expectation values -# theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) -# Etheta = tree_map( -# lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") -# / num_chains, -# theta, -# ) - -# # use these to adapt the hyperparameters of the dynamics -# adaptation_state, info_to_be_stored = adaptation_update( -# adaptation_state, Etheta -# ) - -# return (state, adaptation_state), info_to_be_stored - -# if ensemble_info is not None: - -# def step(state_all, xs): -# (state, adaptation_state), info_to_be_stored = _step(state_all, xs) -# return (state, adaptation_state), ( -# info_to_be_stored, -# vmap(ensemble_info)(state.position), -# ) - -# return step - -# else: -# return _step - - - - -# def ensemble_execute_fn( -# func, -# rng_key, -# num_chains, -# mesh, -# x=None, -# args=None, -# summary_statistics_fn=lambda y: 0.0, -# ): -# """Given a sequential function -# func(rng_key, x, args) = y, -# evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. -# Args: -# x: array distributed over all decvices -# args: additional arguments for func, not distributed. -# summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. -# rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. - -# Returns: -# y: array distributed over all decvices. Need not be of the same shape as x. -# Etheta: expected values of the summary statistics -# """ -# p, pscalar = PartitionSpec("chains"), PartitionSpec() - -# if x is None: -# X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) -# else: -# X = x - -# adaptation_update = lambda _, Etheta: (Etheta, None) - -# _F = eca_step( -# func, -# lambda y, info, key: summary_statistics_fn(y), -# adaptation_update, -# num_chains, -# ) - -# def F(x, keys): -# """This function operates on a single device. key is a random key for this device.""" -# y, summary_statistics = _F((x, args), (None, keys, None))[0] -# return y, summary_statistics - -# parallel_execute = shard_map( -# F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False -# ) - -# keys = device_put( -# split(rng_key, num_chains), NamedSharding(mesh, p) -# ) # random keys, distributed across devices -# # apply F in parallel -# return parallel_execute(X, keys) - -# def run_eca( -# rng_key, -# initial_state, -# kernel, -# adaptation, -# num_steps, -# num_chains, -# mesh, -# ensemble_info=None, -# early_stop=False, -# ): -# """ -# Run ensemble chain adaptation (eca) in parallel on multiple devices. -# ----------------------------------------------------- -# Args: -# rng_key: random key -# initial_state: initial state of the system -# kernel: kernel for the dynamics -# adaptation: adaptation object -# num_steps: number of steps to run -# num_chains: number of chains -# mesh: mesh for parallelization -# ensemble_info: function that takes the state of the system and returns some information about the ensemble -# early_stop: whether to stop early -# Returns: -# final_state: final state of the system -# final_adaptation_state: final adaptation state -# info_history: history of the information that was stored at each step (if early_stop is False, then this is None) -# """ - -# step = eca_step( -# kernel, -# adaptation.summary_statistics_fn, -# adaptation.update, -# num_chains, -# ensemble_info, -# ) - -# def all_steps(initial_state, keys_sampling, keys_adaptation): -# """This function operates on a single device. key is a random key for this device.""" - -# initial_state_all = (initial_state, adaptation.initial_state) - -# # run sampling -# xs = ( -# jnp.arange(num_steps), -# keys_sampling.T, -# keys_adaptation, -# ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - -# # ((a, Int) -> (a, Int)) -# def step_while(a): -# x, i, _ = a - -# auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) - -# output, info = step(x, auxilliary_input) - -# return (output, i + 1, info[0].get("while_cond")) - -# if early_stop: -# final_state_all, i, _ = lax.while_loop( -# lambda a: ((a[1] < num_steps) & a[2]), -# step_while, -# (initial_state_all, 0, True), -# ) -# info_history = None - -# else: -# final_state_all, info_history = lax.scan(step, initial_state_all, xs) - -# final_state, final_adaptation_state = final_state_all -# return ( -# final_state, -# final_adaptation_state, -# info_history, -# ) # info history is composed of averages over all chains, so it is a couple of scalars - -# p, pscalar = PartitionSpec("chains"), PartitionSpec() -# parallel_execute = shard_map( -# all_steps, -# mesh=mesh, -# in_specs=(p, p, pscalar), -# out_specs=(p, pscalar, pscalar), -# check_rep=False, -# ) - -# # produce all random keys that will be needed - -# key_sampling, key_adaptation = split(rng_key) -# num_steps = jnp.array(num_steps).item() -# keys_adaptation = split(key_adaptation, num_steps) -# distribute_keys = lambda key, shape: device_put( -# split(key, shape), NamedSharding(mesh, p) -# ) # random keys, distributed across devices -# keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) - -# # run sampling in parallel -# final_state, final_adaptation_state, info_history = parallel_execute( -# initial_state, keys_sampling, keys_adaptation -# ) - -# return final_state, final_adaptation_state, info_history - - -mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") - -# key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) - -num_chains = 128 -ndims = 2 - -def logdensity_fn(x): - mu2 = 0.03 * (x[0] ** 2 - 100) - return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) - -def transform(x): - return x - -def sample_init(key): - z = jax.random.normal(key, shape=(2,)) - x0 = 10.0 * z[0] - x1 = 0.03 * (x0**2 - 100) + z[1] - return jnp.array([x0, x1]) - -# # initialize the chains -# initial_state = umclmc.initialize( -# key_init, logdensity_fn, sample_init, num_chains, mesh -# ) - -# alpha = 1.9 -# C = 0.1 -# r_end=5e-3 -# ensemble_observables=lambda x: x - -# # burn-in with the unadjusted method # -# kernel = umclmc.build_kernel(logdensity_fn) -# save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) -# adap = umclmc.Adaptation( -# ndims, -# alpha=alpha, -# bias_type=3, -# save_num=save_num, -# C=C, -# power=3.0 / 8.0, -# r_end=r_end, -# observables_for_bias=lambda position: jnp.square( -# transform(jax.flatten_util.ravel_pytree(position)[0]) -# ), -# ) - - -# final_state, final_adaptation_state, info1 = run_eca( -# key_umclmc, -# initial_state, -# kernel, -# adap, -# 100, -# num_chains, -# mesh, -# ensemble_observables, -# early_stop=True, -# ) - -from blackjax.mcmc.integrators import mclachlan_coefficients - -import sys -# sys.path.append(".") -# sys.path.append("../") -from blackjax.adaptation.ensemble_mclmc import emaus -# from blackjax.mcmc.alternate_emaus import emaus - - -# def emaus( -# logdensity_fn, -# sample_init, -# transform, -# ndims, -# num_steps1, -# num_steps2, -# num_chains, -# mesh, -# rng_key, -# alpha=1.9, -# save_frac=0.2, -# C=0.1, -# early_stop=True, -# r_end=5e-3, -# diagonal_preconditioning=True, -# integrator_coefficients=None, -# steps_per_sample=15, -# acc_prob=None, -# observables=lambda x: None, -# ensemble_observables=None, -# diagnostics=True, -# ): -# """ -# model: the target density object -# num_steps1: number of steps in the first phase -# num_steps2: number of steps in the second phase -# num_chains: number of chains -# mesh: the mesh object, used for distributing the computation across cpus and nodes -# rng_key: the random key -# alpha: L = sqrt{d}*alpha*variances -# save_frac: the fraction of samples used to estimate the fluctuation in the first phase -# C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) -# early_stop: whether to stop the first phase early -# r_end -# diagonal_preconditioning: whether to use diagonal preconditioning -# integrator_coefficients: the coefficients of the integrator -# steps_per_sample: the number of steps per sample -# acc_prob: the acceptance probability -# observables: the observables (for diagnostic use) -# ensemble_observables: observable calculated over the ensemble (for diagnostic use) -# diagnostics: whether to return diagnostics -# """ - -# # observables_for_bias, contract = bias(model) -# key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) - -# # initialize the chains -# initial_state = umclmc.initialize( -# key_init, logdensity_fn, sample_init, num_chains, mesh -# ) - -# # burn-in with the unadjusted method # -# kernel = umclmc.build_kernel(logdensity_fn) -# save_num = (int)(jnp.rint(save_frac * num_steps1)) -# adap = umclmc.Adaptation( -# ndims, -# alpha=alpha, -# bias_type=3, -# save_num=save_num, -# C=C, -# power=3.0 / 8.0, -# r_end=r_end, -# observables_for_bias=lambda position: jnp.square( -# transform(jax.flatten_util.ravel_pytree(position)[0]) -# ), -# ) - -# final_state, final_adaptation_state, info1 = run_eca( -# key_umclmc, -# initial_state, -# kernel, -# adap, -# num_steps1, -# num_chains, -# mesh, -# ensemble_observables, -# early_stop=early_stop, -# ) - -key = jax.random.key(0) - -emaus( - logdensity_fn=logdensity_fn, - sample_init=sample_init, - transform=transform, - ndims=ndims, - num_steps1=100, - num_steps2=300, - num_chains=num_chains, - mesh=mesh, - rng_key=key, - alpha=1.9, - C=0.1, - early_stop=1, - r_end=1e-2, - diagonal_preconditioning=True, - integrator_coefficients=mclachlan_coefficients, - steps_per_sample=15, - acc_prob=None, - ensemble_observables=lambda x: x, - # adap=adap, - # kernel=kernel, - # initial_state=initial_state, - # key_umclmc=key_umclmc, - # ensemble_observables = lambda x: vec @ x - ) # run the algorithm - - -# a = jnp.array([8.0, 4.0]) - -# def f(rng_key, x, args): -# return x + normal(rng_key, x.shape) + a, a - -# out = ensemble_execute_fn( -# func = f, -# rng_key = jax.random.PRNGKey(0), -# num_chains = 4, -# mesh = mesh, -# x = None, -# args = None, -# summary_statistics_fn = lambda y: a, -# ) - -# print(out) \ No newline at end of file diff --git a/tests/mcmc/minimal_repro_3.py b/tests/mcmc/minimal_repro_3.py deleted file mode 100644 index 56e08caeb..000000000 --- a/tests/mcmc/minimal_repro_3.py +++ /dev/null @@ -1,514 +0,0 @@ - - -from typing import Any, NamedTuple - -import jax -import jax.numpy as jnp - -import blackjax.adaptation.ensemble_umclmc as umclmc -from blackjax.adaptation.ensemble_umclmc import ( - equipartition_diagonal, - equipartition_diagonal_loss, -) -from blackjax.adaptation.step_size import bisection_monotonic_fn -from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt -from blackjax.mcmc.hmc import HMCState -from blackjax.mcmc.integrators import ( - generate_isokinetic_integrator, - mclachlan_coefficients, - omelyan_coefficients, -) -import jax -import jax.numpy as jnp -from jax import device_put, jit, lax, vmap -from jax.experimental.shard_map import shard_map -from jax.flatten_util import ravel_pytree -from jax.random import normal, split -from jax.sharding import NamedSharding, PartitionSpec -from jax.tree_util import tree_leaves, tree_map - -import blackjax.adaptation.ensemble_umclmc as umclmc - - -def eca_step( - kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None -): - """ - Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. - """ - - def _step(state_all, xs): - """This function operates on a single device.""" - ( - state, - adaptation_state, - ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. - ( - _, - keys_sampling, - key_adaptation, - ) = xs # keys_sampling.shape = (chains_per_device, ) - - # update the state of all chains on this device - state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) - - # combine all the chains to compute expectation values - theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) - Etheta = tree_map( - lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") - / num_chains, - theta, - ) - - # use these to adapt the hyperparameters of the dynamics - adaptation_state, info_to_be_stored = adaptation_update( - adaptation_state, Etheta - ) - - return (state, adaptation_state), info_to_be_stored - - if ensemble_info is not None: - - def step(state_all, xs): - (state, adaptation_state), info_to_be_stored = _step(state_all, xs) - return (state, adaptation_state), ( - info_to_be_stored, - vmap(ensemble_info)(state.position), - ) - - return step - - else: - return _step - - - - -def ensemble_execute_fn( - func, - rng_key, - num_chains, - mesh, - x=None, - args=None, - summary_statistics_fn=lambda y: 0.0, -): - """Given a sequential function - func(rng_key, x, args) = y, - evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over ensemble. - Args: - x: array distributed over all decvices - args: additional arguments for func, not distributed. - summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. - rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. - - Returns: - y: array distributed over all decvices. Need not be of the same shape as x. - Etheta: expected values of the summary statistics - """ - p, pscalar = PartitionSpec("chains"), PartitionSpec() - - if x is None: - X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) - else: - X = x - - adaptation_update = lambda _, Etheta: (Etheta, None) - - _F = eca_step( - func, - lambda y, info, key: summary_statistics_fn(y), - adaptation_update, - num_chains, - ) - - def F(x, keys): - """This function operates on a single device. key is a random key for this device.""" - y, summary_statistics = _F((x, args), (None, keys, None))[0] - return y, summary_statistics - - parallel_execute = shard_map( - F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False - ) - - keys = device_put( - split(rng_key, num_chains), NamedSharding(mesh, p) - ) # random keys, distributed across devices - # apply F in parallel - return parallel_execute(X, keys) - -def run_eca( - rng_key, - initial_state, - kernel, - adaptation, - num_steps, - num_chains, - mesh, - ensemble_info=None, - early_stop=False, -): - """ - Run ensemble chain adaptation (eca) in parallel on multiple devices. - ----------------------------------------------------- - Args: - rng_key: random key - initial_state: initial state of the system - kernel: kernel for the dynamics - adaptation: adaptation object - num_steps: number of steps to run - num_chains: number of chains - mesh: mesh for parallelization - ensemble_info: function that takes the state of the system and returns some information about the ensemble - early_stop: whether to stop early - Returns: - final_state: final state of the system - final_adaptation_state: final adaptation state - info_history: history of the information that was stored at each step (if early_stop is False, then this is None) - """ - - step = eca_step( - kernel, - adaptation.summary_statistics_fn, - adaptation.update, - num_chains, - ensemble_info, - ) - - def all_steps(initial_state, keys_sampling, keys_adaptation): - """This function operates on a single device. key is a random key for this device.""" - - initial_state_all = (initial_state, adaptation.initial_state) - - # run sampling - xs = ( - jnp.arange(num_steps), - keys_sampling.T, - keys_adaptation, - ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - - # ((a, Int) -> (a, Int)) - def step_while(a): - x, i, _ = a - - auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) - - output, info = step(x, auxilliary_input) - - return (output, i + 1, info[0].get("while_cond")) - - if early_stop: - final_state_all, i, _ = lax.while_loop( - lambda a: ((a[1] < num_steps) & a[2]), - step_while, - (initial_state_all, 0, True), - ) - info_history = None - - else: - final_state_all, info_history = lax.scan(step, initial_state_all, xs) - - final_state, final_adaptation_state = final_state_all - return ( - final_state, - final_adaptation_state, - info_history, - ) # info history is composed of averages over all chains, so it is a couple of scalars - - p, pscalar = PartitionSpec("chains"), PartitionSpec() - parallel_execute = shard_map( - all_steps, - mesh=mesh, - in_specs=(p, p, pscalar), - out_specs=(p, pscalar, pscalar), - check_rep=False, - ) - - # produce all random keys that will be needed - - key_sampling, key_adaptation = split(rng_key) - num_steps = jnp.array(num_steps).item() - keys_adaptation = split(key_adaptation, num_steps) - distribute_keys = lambda key, shape: device_put( - split(key, shape), NamedSharding(mesh, p) - ) # random keys, distributed across devices - keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) - - # run sampling in parallel - final_state, final_adaptation_state, info_history = parallel_execute( - initial_state, keys_sampling, keys_adaptation - ) - - return final_state, final_adaptation_state, info_history - -# from blackjax.util import run_eca - - - -class AdaptationState(NamedTuple): - steps_per_sample: float - step_size: float - stepsize_adaptation_state: ( - Any # the state of the bisection algorithm to find a stepsize - ) - iteration: int - - -build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( - logdensity_fn=logdensity_fn, - integrator=integrator, - inverse_mass_matrix=inverse_mass_matrix, -)( - rng_key=key, - state=state, - step_size=adap.step_size, - num_integration_steps=adap.steps_per_sample, - L_proposal_factor=1.25, -) - - -class Adaptation: - def __init__( - self, - adaptation_state, - num_adaptation_samples, # amount of tuning in the adjusted phase before fixing params - steps_per_sample=15, # L/eps - acc_prob_target=0.8, - observables=lambda x: 0.0, # just for diagnostics: some function of a given chain at given timestep - observables_for_bias=lambda x: 0.0, # just for diagnostics: the above, but averaged over all chains - contract=lambda x: 0.0, # just for diagnostics: observabiels for bias, contracted over dimensions - ): - self.num_adaptation_samples = num_adaptation_samples - self.observables = observables - self.observables_for_bias = observables_for_bias - self.contract = contract - - # Determine the initial hyperparameters # - - # stepsize # - # if we switched to the more accurate integrator we can use longer step size - # integrator_factor = jnp.sqrt(10.) if mclachlan else 1. - # Let's use the stepsize which will be optimal for the adjusted method. The energy variance after N steps scales as sigma^2 ~ N^2 eps^6 = eps^4 L^2 - # In the adjusted method we want sigma^2 = 2 mu = 2 * 0.41 = 0.82 - # With the current eps, we had sigma^2 = EEVPD * d for N = 1. - # Combining the two we have EEVPD * d / 0.82 = eps^6 / eps_new^4 L^2 - # adjustment_factor = jnp.power(0.82 / (ndims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample) - step_size = adaptation_state.step_size - - # Initialize the bisection for finding the step size - self.epsadap_update = bisection_monotonic_fn(acc_prob_target) - stepsize_adaptation_state = (jnp.array([-jnp.inf, jnp.inf]), False) - - self.initial_state = AdaptationState( - steps_per_sample, step_size, stepsize_adaptation_state, 0 - ) - - def summary_statistics_fn(self, state, info, rng_key): - return { - "acceptance_probability": info.acceptance_rate, - "equipartition_diagonal": equipartition_diagonal( - state - ), # metric for bias: equipartition theorem gives todo... - "observables": self.observables(state.position), - "observables_for_bias": self.observables_for_bias(state.position), - } - - def update(self, adaptation_state, Etheta): - acc_prob = Etheta["acceptance_probability"] - equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"]) - true_bias = self.contract(Etheta["observables_for_bias"]) # remove - - info_to_be_stored = { - "L": adaptation_state.step_size * adaptation_state.steps_per_sample, - "steps_per_sample": adaptation_state.steps_per_sample, - "step_size": adaptation_state.step_size, - "acc_prob": acc_prob, - "equi_diag": equi_diag, - "bias": true_bias, - "observables": Etheta["observables"], - } - - # Bisection to find step size - stepsize_adaptation_state, step_size = self.epsadap_update( - adaptation_state.stepsize_adaptation_state, - adaptation_state.step_size, - acc_prob, - ) - - return ( - AdaptationState( - adaptation_state.steps_per_sample, - step_size, - stepsize_adaptation_state, - adaptation_state.iteration + 1, - ), - info_to_be_stored, - ) - - -def bias(model): - """should be transfered to benchmarks/""" - - def observables(position): - return jnp.square(model.transform(position)) - - def contract(sampler_E_x2): - bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2 - return jnp.array([jnp.max(bsq), jnp.average(bsq)]) - - return observables, contract - - -def while_steps_num(cond): - if jnp.all(cond): - return len(cond) - else: - return jnp.argmin(cond) + 1 - - - -def logdensity_fn(x): - mu2 = 0.03 * (x[0] ** 2 - 100) - return -0.5 * (jnp.square(x[0] / 10.0) + jnp.square(x[1] - mu2)) - -def transform(x): - return x - -def sample_init(key): - z = jax.random.normal(key, shape=(2,)) - x0 = 10.0 * z[0] - x1 = 0.03 * (x0**2 - 100) + z[1] - return jnp.array([x0, x1]) - -num_chains = 128 - -mesh = jax.sharding.Mesh(devices=jax.devices(),axis_names= "chains") - -key_init, key_umclmc, key_mclmc = jax.random.split(jax.random.key(0), 3) - -integrator_coefficients = mclachlan_coefficients - -acc_prob = None - -# initialize the chains -initial_state = umclmc.initialize( - key_init, logdensity_fn, sample_init, num_chains, mesh -) - -diagonal_preconditioning = False -ndims = 2 - -alpha = 1.9 -C = 0.1 -r_end=5e-3 -ensemble_observables=lambda x: x - -# burn-in with the unadjusted method # -kernel = umclmc.build_kernel(logdensity_fn) -save_num = 20 # (int)(jnp.rint(save_frac * num_steps1)) -adap = umclmc.Adaptation( - ndims, - alpha=alpha, - bias_type=3, - save_num=save_num, - C=C, - power=3.0 / 8.0, - r_end=r_end, - observables_for_bias=lambda position: jnp.square( - transform(jax.flatten_util.ravel_pytree(position)[0]) - ), -) - -final_state, final_adaptation_state, info1 = run_eca( - key_umclmc, - initial_state, - kernel, - adap, - 100, - num_chains, - mesh, - ensemble_observables, - early_stop=True, - ) - -# refine the results with the adjusted method -_acc_prob = acc_prob -if integrator_coefficients is None: - high_dims = ndims > 200 - _integrator_coefficients = ( - omelyan_coefficients if high_dims else mclachlan_coefficients - ) - if acc_prob is None: - _acc_prob = 0.9 if high_dims else 0.7 - -else: - _integrator_coefficients = integrator_coefficients - if acc_prob is None: - _acc_prob = 0.9 - -integrator = generate_isokinetic_integrator(_integrator_coefficients) -gradient_calls_per_step = ( - len(_integrator_coefficients) // 2 -) # scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. - -if diagonal_preconditioning: - inverse_mass_matrix = jnp.sqrt(final_adaptation_state.inverse_mass_matrix) - - # scale the stepsize so that it reflects averag scale change of the preconditioning - average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) - final_adaptation_state = final_adaptation_state._replace( - step_size=final_adaptation_state.step_size / average_scale_change - ) - -else: - inverse_mass_matrix = 1.0 - -kernel = build_kernel( - logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix -) -steps_per_sample = 15 -num_steps2 = 100 - - -initial_state = HMCState( - final_state.position, final_state.logdensity, final_state.logdensity_grad - ) - -print(initial_state.position.shape, "bar\n\n") - -# pos = jax.random.normal(key_mclmc, shape=(num_chains, ndims)) - - - -# print("baz", logdensity_fn(pos)) - -# initial_state = HMCState( -# pos, logdensity_fn(pos[0]), jax.grad(logdensity_fn)(pos[0]) -# ) - - -num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) -num_adaptation_samples = ( - num_samples // 2 -) # number of samples after which the stepsize is fixed. - -adap = Adaptation( - final_adaptation_state, - num_adaptation_samples, - steps_per_sample, - _acc_prob, -) - - - -final_state, final_adaptation_state, info2 = run_eca( - key_mclmc, - initial_state, - kernel, - adap, - num_samples, - num_chains, - mesh, - ensemble_observables, -) - From b55ab0df1666efc8757fd1999c98e2146dfbcf22 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 13:36:39 -0400 Subject: [PATCH 31/63] bug fix --- blackjax/adaptation/ensemble_umclmc.py | 1 + blackjax/mcmc/alternate_emaus.py | 85 -------------------------- 2 files changed, 1 insertion(+), 85 deletions(-) delete mode 100644 blackjax/mcmc/alternate_emaus.py diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 3cc62ae00..d5df4ee92 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -128,6 +128,7 @@ def ensemble_init(key, state, signs): def update_history(new_vals, history): + new_vals, _ = jax.flatten_util.ravel_pytree(new_vals) return jnp.concatenate((new_vals[None, :], history[:-1, :])) diff --git a/blackjax/mcmc/alternate_emaus.py b/blackjax/mcmc/alternate_emaus.py deleted file mode 100644 index 6010bab73..000000000 --- a/blackjax/mcmc/alternate_emaus.py +++ /dev/null @@ -1,85 +0,0 @@ -import jax -import jax.numpy as jnp -from blackjax.util import run_eca -import blackjax.adaptation.ensemble_umclmc as umclmc - - -def emaus( - logdensity_fn, - sample_init, - transform, - ndims, - num_steps1, - num_steps2, - num_chains, - mesh, - rng_key, - alpha=1.9, - save_frac=0.2, - C=0.1, - early_stop=True, - r_end=5e-3, - diagonal_preconditioning=True, - integrator_coefficients=None, - steps_per_sample=15, - acc_prob=None, - observables=lambda x: None, - ensemble_observables=None, - diagnostics=True, -): - """ - model: the target density object - num_steps1: number of steps in the first phase - num_steps2: number of steps in the second phase - num_chains: number of chains - mesh: the mesh object, used for distributing the computation across cpus and nodes - rng_key: the random key - alpha: L = sqrt{d}*alpha*variances - save_frac: the fraction of samples used to estimate the fluctuation in the first phase - C: constant in stage 1 that determines step size (eq (9) of EMAUS paper) - early_stop: whether to stop the first phase early - r_end - diagonal_preconditioning: whether to use diagonal preconditioning - integrator_coefficients: the coefficients of the integrator - steps_per_sample: the number of steps per sample - acc_prob: the acceptance probability - observables: the observables (for diagnostic use) - ensemble_observables: observable calculated over the ensemble (for diagnostic use) - diagnostics: whether to return diagnostics - """ - - # observables_for_bias, contract = bias(model) - key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) - - # initialize the chains - initial_state = umclmc.initialize( - key_init, logdensity_fn, sample_init, num_chains, mesh - ) - - # burn-in with the unadjusted method # - kernel = umclmc.build_kernel(logdensity_fn) - save_num = (int)(jnp.rint(save_frac * num_steps1)) - adap = umclmc.Adaptation( - ndims, - alpha=alpha, - bias_type=3, - save_num=save_num, - C=C, - power=3.0 / 8.0, - r_end=r_end, - observables_for_bias=lambda position: jnp.square( - transform(jax.flatten_util.ravel_pytree(position)[0]) - ), - ) - - final_state, final_adaptation_state, info1 = run_eca( - key_umclmc, - initial_state, - kernel, - adap, - num_steps1, - num_chains, - mesh, - ensemble_observables, - early_stop=early_stop, - ) \ No newline at end of file From e6da5c2a951685ffc3c9683ed8195f2b496e72a3 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Mon, 10 Mar 2025 12:19:58 -0700 Subject: [PATCH 32/63] changes --- blackjax/adaptation/ensemble_mclmc.py | 6 +++--- blackjax/util.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index ebc87cd5a..8c16008b2 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -222,7 +222,7 @@ def emaus( contract=contract, ) - final_state, final_adaptation_state, info1 = run_eca( + final_state, final_adaptation_state, info1, steps_done_phase_1 = run_eca( key_umclmc, initial_state, kernel, @@ -288,7 +288,7 @@ def emaus( ) - final_state, final_adaptation_state, info2 = run_eca( + final_state, final_adaptation_state, info2, steps_done_phase_2 = run_eca( key_mclmc, initial_state, kernel, @@ -300,7 +300,7 @@ def emaus( ) if diagnostics: - info = {"phase_1": info1, "phase_2": info2} + info = {"phase_1": {'steps_done' : steps_done_phase_1}, "phase_2": info2} else: info = None diff --git a/blackjax/util.py b/blackjax/util.py index 4668befee..2c87d0e8e 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -441,11 +441,12 @@ def step_while(a): step_while, (initial_state_all, 0, True), ) + steps_done = i info_history = None else: final_state_all, info_history = lax.scan(step, initial_state_all, xs) - + steps_done = num_steps final_state, final_adaptation_state = final_state_all @@ -453,6 +454,7 @@ def step_while(a): final_state, final_adaptation_state, info_history, + steps_done ) # info history is composed of averages over all chains, so it is a couple of scalars @@ -461,7 +463,7 @@ def step_while(a): all_steps, mesh=mesh, in_specs=(p, p, pscalar), - out_specs=(p, pscalar, pscalar), + out_specs=(p, pscalar, pscalar, pscalar), check_rep=False, ) @@ -476,11 +478,13 @@ def step_while(a): keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) # run sampling in parallel - final_state, final_adaptation_state, info_history = parallel_execute( + final_state, final_adaptation_state, info_history, steps_done = parallel_execute( initial_state, keys_sampling, keys_adaptation ) - return final_state, final_adaptation_state, info_history + + + return final_state, final_adaptation_state, info_history, steps_done def ensemble_execute_fn( From 0bd14143170dad939b7d1c30a4c4df86381c5231 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 15:24:24 -0400 Subject: [PATCH 33/63] bug fix --- blackjax/adaptation/ensemble_mclmc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index cb763d94a..0b418c9e9 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -294,8 +294,7 @@ def emaus( observables_for_bias=observables_for_bias, ) - - final_state, final_adaptation_state, info2, steps_done_phase_2 = run_eca( + final_state, final_adaptation_state, info2, _ = run_eca( key_mclmc, initial_state, kernel, @@ -307,7 +306,7 @@ def emaus( ) if diagnostics: - info = {"phase_1": {'steps_done' : steps_done_phase_1}, "phase_2": info2} + info = {"phase_1": {"steps_done": steps_done_phase_1}, "phase_2": info2} else: info = None From 13a375ca0ba1cb516992f8114b535945bd6da767 Mon Sep 17 00:00:00 2001 From: = Date: Mon, 10 Mar 2025 15:29:01 -0400 Subject: [PATCH 34/63] bug fix --- blackjax/util.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 91dcc01db..d6a2d15ca 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -442,14 +442,13 @@ def step_while(a): else: final_state_all, info_history = lax.scan(step, initial_state_all, xs) steps_done = num_steps - final_state, final_adaptation_state = final_state_all return ( final_state, final_adaptation_state, info_history, - steps_done + steps_done, ) # info history is composed of averages over all chains, so it is a couple of scalars p, pscalar = PartitionSpec("chains"), PartitionSpec() @@ -476,8 +475,6 @@ def step_while(a): initial_state, keys_sampling, keys_adaptation ) - - return final_state, final_adaptation_state, info_history, steps_done From 04e5b61a08b6735b73aa1f69d59f1a24cbab8108 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Sat, 29 Mar 2025 12:30:27 -0700 Subject: [PATCH 35/63] first attempt at langevin --- blackjax/mcmc/integrators.py | 19 ++++++++++++++----- blackjax/mcmc/nuts.py | 4 ++++ blackjax/mcmc/underdamped_langevin.py | 10 ++++------ 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index d240822ba..11721460a 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -504,7 +504,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): return stochastic_integrator -def with_maruyama(integrator): +def with_maruyama(integrator, kinetic_energy): def stochastic_integrator(init_state, step_size, L_proposal, rng_key): key1, key2 = jax.random.split(rng_key) # partial refreshment @@ -520,19 +520,28 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): # jax.debug.print("state 1.5 {x}",x=state) # state = init_state # TODO: add noise back! # one step of the deterministic dynamics - state = integrator(state, step_size) + new_state = integrator(state, step_size) # jax.debug.print("state 2 {x}",x=state) + + kinetic_change = - kinetic_energy(new_state.momentum) + kinetic_energy( + state.momentum + ) + energy_change = kinetic_change - new_state.logdensity + state.logdensity # partial refreshment - state = state._replace( + state = new_state._replace( momentum=partially_refresh_momentum( - momentum=state.momentum, + momentum=new_state.momentum, rng_key=key2, L=L_proposal, step_size=step_size * 0.5, ) ) - return state + + + + + return state, (kinetic_change, energy_change) return stochastic_integrator diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index c75ecdec6..73d8aab4c 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -120,6 +120,7 @@ def kernel( ) -> tuple[hmc.HMCState, NUTSInfo]: """Generate a new sample with the NUTS kernel.""" + metric = metrics.default_metric(inverse_mass_matrix) symplectic_integrator = integrator(logdensity_fn, metric.kinetic_energy) proposal_generator = iterative_nuts_proposal( @@ -142,6 +143,9 @@ def kernel( proposal = hmc.HMCState( proposal.position, proposal.logdensity, proposal.logdensity_grad ) + # jax.debug.print("step {x}",x=info.acceptance_rate) + # jax.debug.print("num steps {x}",x=info.num_integration_steps) + # jax.debug.print("step size {x}",x=step_size) return proposal, info return kernel diff --git a/blackjax/mcmc/underdamped_langevin.py b/blackjax/mcmc/underdamped_langevin.py index 7417249ee..68683f1c8 100644 --- a/blackjax/mcmc/underdamped_langevin.py +++ b/blackjax/mcmc/underdamped_langevin.py @@ -82,18 +82,16 @@ def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): # ) metric = metrics.default_metric(inverse_mass_matrix) - step = with_maruyama(integrator(logdensity_fn, metric.kinetic_energy)) + step = with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, LangevinInfo]: - (position, momentum, logdensity, logdensitygrad) = step( + (position, momentum, logdensity, logdensitygrad), (kinetic_change, energy_change) = step( state, step_size, L, rng_key ) - kinetic_change = - metric.kinetic_energy(momentum) + metric.kinetic_energy( - state.momentum - ) + # kinetic_change = - momentum@momentum/2 + state.momentum@state.momentum/2 @@ -101,7 +99,7 @@ def kernel( position, momentum, logdensity, logdensitygrad ), LangevinInfo( logdensity=logdensity, - energy_change= kinetic_change - logdensity + state.logdensity, + energy_change=energy_change, kinetic_change=kinetic_change ) From 3906c0e370aafc0e6710c5465f575e7031fc0879 Mon Sep 17 00:00:00 2001 From: = Date: Tue, 1 Apr 2025 18:22:46 -0400 Subject: [PATCH 36/63] emaus diagnostics --- blackjax/adaptation/ensemble_mclmc.py | 2 +- blackjax/util.py | 40 ++++++++++++++++++++++++--- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 0b418c9e9..7b47b2164 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -306,7 +306,7 @@ def emaus( ) if diagnostics: - info = {"phase_1": {"steps_done": steps_done_phase_1}, "phase_2": info2} + info = {"phase_1": info1, "phase_2": info2} else: info = None diff --git a/blackjax/util.py b/blackjax/util.py index d6a2d15ca..b074b79e7 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -420,15 +420,35 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): keys_adaptation, ) # keys for all steps that will be performed. keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) - # ((a, Int) -> (a, Int)) + EEVPD = jnp.zeros((num_steps,)) + EEVPD_wanted = jnp.zeros((num_steps,)) + L = jnp.zeros((num_steps,)) + entropy = jnp.zeros((num_steps,)) + equi_diag = jnp.zeros((num_steps,)) + equi_full = jnp.zeros((num_steps,)) + observables = jnp.zeros((num_steps,)) + r_avg = jnp.zeros((num_steps,)) + r_max = jnp.zeros((num_steps,)) + step_size = jnp.zeros((num_steps,)) + def step_while(a): x, i, _ = a auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) - output, info = step(x, auxilliary_input) + output, (info, pos) = step(x, auxilliary_input) + EEVPD.at[i].set(info.get("EEVPD")) + EEVPD_wanted.at[i].set(info.get("EEVPD_wanted")) + L.at[i].set(info.get("L")) + entropy.at[i].set(info.get("entropy")) + equi_diag.at[i].set(info.get("equi_diag")) + equi_full.at[i].set(info.get("equi_full")) + observables.at[i].set(info.get("observables")) + r_avg.at[i].set(info.get("r_avg")) + r_max.at[i].set(info.get("r_max")) + step_size.at[i].set(info.get("step_size")) - return (output, i + 1, info[0].get("while_cond")) + return (output, i + 1, info.get("while_cond")) if early_stop: final_state_all, i, _ = lax.while_loop( @@ -437,7 +457,19 @@ def step_while(a): (initial_state_all, 0, True), ) steps_done = i - info_history = None + info_history = { + "EEVPD": EEVPD, + "EEVPD_wanted": EEVPD_wanted, + "L": L, + "entropy": entropy, + "equi_diag": equi_diag, + "equi_full": equi_full, + "observables": observables, + "r_avg": r_avg, + "r_max": r_max, + "step_size": step_size, + "steps_done": steps_done, + } else: final_state_all, info_history = lax.scan(step, initial_state_all, xs) From 3d465f296a5c11f2c4583869e4f3a91f66865b56 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Thu, 3 Apr 2025 20:13:58 -0700 Subject: [PATCH 37/63] tuning for hmc --- .../adaptation/adjusted_mclmc_adaptation.py | 10 +++++- blackjax/adaptation/mclmc_adaptation.py | 24 ++++++++++--- blackjax/mcmc/dynamic_hmc.py | 4 +++ blackjax/mcmc/hmc.py | 4 +++ blackjax/mcmc/integrators.py | 3 +- blackjax/mcmc/underdamped_langevin.py | 36 ++----------------- 6 files changed, 40 insertions(+), 41 deletions(-) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index c86085a54..6236b19e3 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -28,6 +28,7 @@ def adjusted_mclmc_find_L_and_step_size( max="avg", num_windows=1, tuning_factor=1.3, + euclidean=False, ): """ Finds the optimal value of the parameters for the MH-MCHMC algorithm. @@ -73,7 +74,14 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: - params = MCLMCAdaptationState( + if euclidean: + + params = MCLMCAdaptationState( + 1.0, 0.2, inverse_mass_matrix=jnp.ones((dim,)) + ) + + else: + params = MCLMCAdaptationState( jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,)) ) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 5a673414f..b121fed0b 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -52,6 +52,7 @@ def mclmc_find_L_and_step_size( num_effective_samples=150, params=None, diagonal_preconditioning=True, + euclidean=False ): """ Finds the optimal value of the parameters for the MCLMC algorithm. @@ -82,6 +83,7 @@ def mclmc_find_L_and_step_size( Whether to do diagonal preconditioning (i.e. a mass matrix) params Initial params to start tuning from (optional) + euclidean: if this tuning is used for HMC or underdamped LMC, there are sqrt{d} factors that need to be taken into account (because L is parametrized differently) Returns ------- @@ -109,9 +111,15 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) if params is None: - params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) - ) + if euclidean: + params = MCLMCAdaptationState( + 1, 0.25, inverse_mass_matrix=jnp.ones((dim,)) + ) + else: + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) + ) + part1_key, part2_key = jax.random.split(rng_key, 2) total_num_tuning_integrator_steps = 0 @@ -131,6 +139,7 @@ def mclmc_find_L_and_step_size( trust_in_estimate=trust_in_estimate, num_effective_samples=num_effective_samples, diagonal_preconditioning=diagonal_preconditioning, + euclidean=euclidean )(state, params, num_steps, part1_key) total_num_tuning_integrator_steps += num_steps1 + num_steps2 @@ -152,6 +161,7 @@ def make_L_step_size_adaptation( desired_energy_var=1e-3, trust_in_estimate=1.5, num_effective_samples=150, + euclidean=False ): """Adapts the stepsize and L of the MCLMC kernel. Designed for unadjusted MCLMC""" @@ -265,12 +275,16 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) - L = jnp.sqrt(jnp.sum(variances)) + L = jnp.sqrt(jnp.sum(variances)) # lmc: should be jnp.mean + if euclidean: + L /= jnp.sqrt(dim) if diagonal_preconditioning: inverse_mass_matrix = variances params = params._replace(inverse_mass_matrix=inverse_mass_matrix) - L = jnp.sqrt(dim) + L = jnp.sqrt(dim) # lmc: 1 + if euclidean: + L /= jnp.sqrt(dim) # readjust the stepsize steps = round(num_steps2 / 3) # we do some small number of steps diff --git a/blackjax/mcmc/dynamic_hmc.py b/blackjax/mcmc/dynamic_hmc.py index de77be825..848922774 100644 --- a/blackjax/mcmc/dynamic_hmc.py +++ b/blackjax/mcmc/dynamic_hmc.py @@ -101,6 +101,10 @@ def kernel( inverse_mass_matrix, num_integration_steps, ) + + # jax.debug.print("logdensity {x}", x=hmc_proposal.logdensity) + # jax.debug.print("acceptance {x}", x=info) + next_random_arg = next_random_arg_fn(state.random_generator_arg) return ( DynamicHMCState( diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 452b94e44..1fdb82893 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -136,6 +136,8 @@ def kernel( position, logdensity, logdensity_grad = state momentum = metric.sample_momentum(key_momentum, position) + import jax.numpy as jnp + # jax.debug.print("momentum nan? {x}",x=jnp.any(jnp.isnan(momentum))) integrator_state = integrators.IntegratorState( position, momentum, logdensity, logdensity_grad @@ -299,7 +301,9 @@ def generate( proposal_energy = hmc_energy_fn(state) new_energy = hmc_energy_fn(end_state) delta_energy = safe_energy_diff(proposal_energy, new_energy) + # jax.debug.print("delta_energy {x}",x=delta_energy) is_diverging = -delta_energy > divergence_threshold + # jax.debug.print("is_diverging {x}",x=is_diverging) sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) do_accept, p_accept, other_proposal_info = info diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 11721460a..682ff6e9a 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -100,7 +100,7 @@ def generalized_two_stage_integrator( def one_step(state: IntegratorState, step_size: float): position, momentum, _, logdensity_grad = state - # jax.debug.print("initial state {x}", x=state) + # jax.debug.print("initial state {x}", x=jnp.any(jnp.isnan(momentum))) # auxiliary infomation generated during integration for diagnostics. It is # updated by the operator1 and operator2 at each call. momentum_update_info = None @@ -167,6 +167,7 @@ def update( auxiliary_info=None, ): + # jax.debug.print("nan? {x}", x=jnp.any(jnp.isnan(kinetic_grad))) # jax.debug.print("position {x}", x=position) del auxiliary_info new_position = jax.tree_util.tree_map( diff --git a/blackjax/mcmc/underdamped_langevin.py b/blackjax/mcmc/underdamped_langevin.py index 68683f1c8..b6dfa2be4 100644 --- a/blackjax/mcmc/underdamped_langevin.py +++ b/blackjax/mcmc/underdamped_langevin.py @@ -91,6 +91,8 @@ def kernel( state, step_size, L, rng_key ) + # jax.debug.print("energy change {x}", x=energy_change) + # kinetic_change = - momentum@momentum/2 + state.momentum@state.momentum/2 @@ -121,42 +123,8 @@ def as_top_level_api( We also add the general kernel and state generator as an attribute to this class so users only need to pass `blackjax.langevin` to SMC, adaptation, etc. algorithms. - Examples - -------- - - A new langevin kernel can be initialized and used with the following code: - - .. code:: - - langevin = blackjax.mcmc.langevin.langevin( - logdensity_fn=logdensity_fn, - L=L, - step_size=step_size - ) - state = langevin.init(position) - new_state, info = langevin.step(rng_key, state) - - Kernels are not jit-compiled by default so you will need to do it manually: - .. code:: - step = jax.jit(langevin.step) - new_state, info = step(rng_key, state) - - Parameters - ---------- - logdensity_fn - The log-density function we wish to draw samples from. - L - the momentum decoherence rate - step_size - step size of the integrator - integrator - an integrator. We recommend using the default here. - - Returns - ------- - A ``SamplingAlgorithm``. """ kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator) From e17b1cc19032e6982fdcb333ae22135e5e14ef7d Mon Sep 17 00:00:00 2001 From: = Date: Wed, 9 Apr 2025 12:02:08 -0400 Subject: [PATCH 38/63] energy error monitoring --- blackjax/mcmc/mclmc.py | 61 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index ff9638a1f..e0e4ee304 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -15,6 +15,7 @@ from typing import Callable, NamedTuple import jax +import jax.numpy as jnp from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import ( @@ -60,7 +61,13 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) -def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): +def build_kernel( + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4, +): """Build a HMC kernel. Parameters @@ -91,14 +98,46 @@ def kernel( state, step_size, L, rng_key ) - return IntegratorState( - position, momentum, logdensity, logdensitygrad - ), MCLMCInfo( - logdensity=logdensity, - energy_change=kinetic_change - logdensity + state.logdensity, - kinetic_change=kinetic_change, + energy_error = kinetic_change - logdensity + state.logdensity + + eev_max = desired_energy_var_max_ratio * desired_energy_var + # if energy_error > jnp.sqrt(eev_max): + # return state, MCLMCInfo( + # logdensity=state.logdensity, + # energy_change=0, + # kinetic_change=0, + # ) + + new_state, new_info = jax.lax.cond( + energy_error > jnp.sqrt(eev_max), + lambda: ( + state, + MCLMCInfo( + logdensity=state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + ), + ), + lambda: ( + IntegratorState(position, momentum, logdensity, logdensitygrad), + MCLMCInfo( + logdensity=logdensity, + energy_change=energy_error, + kinetic_change=kinetic_change, + ), + ), ) + return new_state, new_info + + # return IntegratorState( + # position, momentum, logdensity, logdensitygrad + # ), MCLMCInfo( + # logdensity=logdensity, + # energy_change=energy_error, + # kinetic_change=kinetic_change, + # ) + return kernel @@ -108,6 +147,7 @@ def as_top_level_api( step_size, integrator=isokinetic_mclachlan, inverse_mass_matrix=1.0, + desired_energy_var_max_ratio=jnp.inf, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,7 +195,12 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator) + kernel = build_kernel( + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=desired_energy_var_max_ratio, + ) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) From 1015774b420ff91df1d1a70b53870e405a514c7b Mon Sep 17 00:00:00 2001 From: = Date: Wed, 9 Apr 2025 16:15:26 -0400 Subject: [PATCH 39/63] energy error monitoring --- blackjax/mcmc/mclmc.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e0e4ee304..771c542f7 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -100,16 +100,11 @@ def kernel( energy_error = kinetic_change - logdensity + state.logdensity - eev_max = desired_energy_var_max_ratio * desired_energy_var - # if energy_error > jnp.sqrt(eev_max): - # return state, MCLMCInfo( - # logdensity=state.logdensity, - # energy_change=0, - # kinetic_change=0, - # ) + eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var + ndims = pytree_size(position) new_state, new_info = jax.lax.cond( - energy_error > jnp.sqrt(eev_max), + energy_error > jnp.sqrt(ndims * eev_max_per_dim), lambda: ( state, MCLMCInfo( @@ -130,14 +125,6 @@ def kernel( return new_state, new_info - # return IntegratorState( - # position, momentum, logdensity, logdensitygrad - # ), MCLMCInfo( - # logdensity=logdensity, - # energy_change=energy_error, - # kinetic_change=kinetic_change, - # ) - return kernel From a164c63c8aaf05abeee258603bc1abbcfe06d942 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Wed, 23 Apr 2025 10:28:42 -0700 Subject: [PATCH 40/63] windows for unadjusted --- blackjax/adaptation/mclmc_adaptation.py | 29 +++++++++++++++---------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index db919687f..38c1a6f94 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -53,6 +53,7 @@ def mclmc_find_L_and_step_size( num_effective_samples=150, params=None, diagonal_preconditioning=True, + num_windows=1, euclidean=False ): """ @@ -122,6 +123,8 @@ def mclmc_find_L_and_step_size( ) + + part1_key, part2_key = jax.random.split(rng_key, 2) total_num_tuning_integrator_steps = 0 @@ -131,17 +134,21 @@ def mclmc_find_L_and_step_size( num_steps2 += diagonal_preconditioning * (num_steps2 // 3) num_steps3 = round(num_steps * frac_tune3) - state, params = make_L_step_size_adaptation( - kernel=mclmc_kernel, - dim=dim, - frac_tune1=frac_tune1, - frac_tune2=frac_tune2, - desired_energy_var=desired_energy_var, - trust_in_estimate=trust_in_estimate, - num_effective_samples=num_effective_samples, - diagonal_preconditioning=diagonal_preconditioning, - euclidean=euclidean - )(state, params, num_steps, part1_key) + for i in range(num_windows): + window_key = jax.random.fold_in(part1_key, i) + + state, params = make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1/num_windows, + frac_tune2=frac_tune2/num_windows, + desired_energy_var=desired_energy_var, + trust_in_estimate=trust_in_estimate, + num_effective_samples=num_effective_samples, + diagonal_preconditioning=diagonal_preconditioning, + euclidean=euclidean + )(state, params, num_steps, window_key) + total_num_tuning_integrator_steps += num_steps1 + num_steps2 if num_steps3 >= 2: # at least 2 samples for ESS estimation From 644873587f4417494453ce4b291dd46e96f61f0d Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Fri, 25 Apr 2025 07:05:08 -0700 Subject: [PATCH 41/63] add preconditioning for ulmc --- blackjax/adaptation/mclmc_adaptation.py | 2 +- blackjax/mcmc/integrators.py | 11 +++++++++-- blackjax/mcmc/underdamped_langevin.py | 6 +----- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 38c1a6f94..c3f805003 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -315,7 +315,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(mask, L_step_size_adaptation_keys), state=state, params=params ) - jax.debug.print("step size {x}", x=(params.step_size, x_average)) + # jax.debug.print("step size {x}", x=(params.step_size, x_average)) L = params.L # determine L diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 63e12d88e..f88c5eeaa 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -21,6 +21,7 @@ from blackjax.mcmc.metrics import KineticEnergy from blackjax.types import ArrayTree +from blackjax.mcmc.metrics import default_metric __all__ = [ "mclachlan", @@ -442,7 +443,7 @@ def partially_refresh_momentum_isokinetic(momentum, rng_key, step_size, L): ) -def partially_refresh_momentum(momentum, rng_key, step_size, L): +def partially_refresh_momentum(momentum, rng_key, step_size, L, inverse_mass_matrix): """Adds a small noise to momentum and normalizes. Parameters @@ -463,10 +464,14 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): # TODO m, unravel_fn = ravel_pytree(momentum) + # m = jax.tree_util.tree_map(lambda x: x * jnp.sqrt(inverse_mass_matrix), m) # dim = m.shape[0] c1 = jnp.exp(-step_size/L) c2 = jnp.sqrt((1-c1**2)) z = normal(rng_key, shape=m.shape, dtype=m.dtype) + metric = default_metric(inverse_mass_matrix) + z = metric.sample_momentum(rng_key, m) + # normal(rng_key, shape=m.shape, dtype=m.dtype) new_momentum = unravel_fn(c1*m + c2*z) return jax.lax.cond( @@ -506,7 +511,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): return stochastic_integrator -def with_maruyama(integrator, kinetic_energy): +def with_maruyama(integrator, kinetic_energy,inverse_mass_matrix): def stochastic_integrator(init_state, step_size, L_proposal, rng_key): key1, key2 = jax.random.split(rng_key) # partial refreshment @@ -517,6 +522,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): rng_key=key1, L=L_proposal, step_size=step_size * 0.5, + inverse_mass_matrix=inverse_mass_matrix, ) ) # jax.debug.print("state 1.5 {x}",x=state) @@ -537,6 +543,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): rng_key=key2, L=L_proposal, step_size=step_size * 0.5, + inverse_mass_matrix=inverse_mass_matrix, ) ) diff --git a/blackjax/mcmc/underdamped_langevin.py b/blackjax/mcmc/underdamped_langevin.py index be6816446..d92de73e0 100644 --- a/blackjax/mcmc/underdamped_langevin.py +++ b/blackjax/mcmc/underdamped_langevin.py @@ -84,12 +84,8 @@ def build_kernel( """ - # step = with_isokinetic_maruyama( - # integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) - # ) - metric = metrics.default_metric(inverse_mass_matrix) - step = with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy) + step = with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy,inverse_mass_matrix) def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float From a612d39afbd48c9cfbf97ef688dfc93686c37a17 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Fri, 25 Apr 2025 17:55:28 -0700 Subject: [PATCH 42/63] fix emaus code --- blackjax/adaptation/ensemble_umclmc.py | 10 +++++---- blackjax/util.py | 30 +++++++++++++------------- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index d5df4ee92..ff10b7094 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -44,12 +44,14 @@ def nan_reject(nonans, old, new): def build_kernel(logdensity_fn): """MCLMC kernel (with nan rejection)""" - kernel = mclmc.build_kernel( - logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet - ) + # kernel = mclmc.build_kernel( + # logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet + # ) def sequential_kernel(key, state, adap): - new_state, info = kernel(key, state, adap.L, adap.step_size) + new_state, info = mclmc.build_kernel( + logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet, inverse_mass_matrix=adap.inverse_mass_matrix + )(key, state, adap.L, adap.step_size) # reject the new state if there were nans nonans = no_nans(new_state) diff --git a/blackjax/util.py b/blackjax/util.py index b074b79e7..295fdc691 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -432,29 +432,29 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): step_size = jnp.zeros((num_steps,)) def step_while(a): - x, i, _ = a + x, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, observables, r_avg, r_max, step_size = a auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) output, (info, pos) = step(x, auxilliary_input) - EEVPD.at[i].set(info.get("EEVPD")) - EEVPD_wanted.at[i].set(info.get("EEVPD_wanted")) - L.at[i].set(info.get("L")) - entropy.at[i].set(info.get("entropy")) - equi_diag.at[i].set(info.get("equi_diag")) - equi_full.at[i].set(info.get("equi_full")) - observables.at[i].set(info.get("observables")) - r_avg.at[i].set(info.get("r_avg")) - r_max.at[i].set(info.get("r_max")) - step_size.at[i].set(info.get("step_size")) - - return (output, i + 1, info.get("while_cond")) + new_EEVPD = EEVPD.at[i].set(info.get("EEVPD")) + new_EEVPD_wanted = EEVPD_wanted.at[i].set(info.get("EEVPD_wanted")) + new_L = L.at[i].set(info.get("L")) + new_entropy = entropy.at[i].set(info.get("entropy")) + new_equi_diag = equi_diag.at[i].set(info.get("equi_diag")) + new_equi_full = equi_full.at[i].set(info.get("equi_full")) + new_observables = observables.at[i].set(info.get("observables")) + new_r_avg = r_avg.at[i].set(info.get("r_avg")) + new_r_max = r_max.at[i].set(info.get("r_max")) + new_step_size = step_size.at[i].set(info.get("step_size")) + + return (output, i + 1, info.get("while_cond"), new_EEVPD, new_EEVPD_wanted, new_L, new_entropy, new_equi_diag, new_equi_full, new_observables, new_r_avg, new_r_max, new_step_size) if early_stop: - final_state_all, i, _ = lax.while_loop( + final_state_all, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, observables, r_avg, r_max, step_size = lax.while_loop( lambda a: ((a[1] < num_steps) & a[2]), step_while, - (initial_state_all, 0, True), + (initial_state_all, 0, True, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, observables, r_avg, r_max, step_size), ) steps_done = i info_history = { From 0ced55f7fb45f7a4d534b64d59a365f6e8d885e0 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Thu, 15 May 2025 13:33:00 -0700 Subject: [PATCH 43/63] FOR NEURIPS --- blackjax/__init__.py | 4 + .../adaptation/adjusted_mclmc_adaptation.py | 17 +- blackjax/adaptation/mclmc_adaptation.py | 2 +- blackjax/mcmc/__init__.py | 4 + blackjax/mcmc/dynamic_hmc.py | 5 +- blackjax/mcmc/dynamic_malt.py | 190 +++++++++++ blackjax/mcmc/hmc.py | 6 +- blackjax/mcmc/integrators.py | 4 +- blackjax/mcmc/malt.py | 306 ++++++++++++++++++ blackjax/mcmc/trajectory.py | 25 ++ blackjax/mcmc/underdamped_langevin.py | 20 +- 11 files changed, 568 insertions(+), 15 deletions(-) create mode 100644 blackjax/mcmc/dynamic_malt.py create mode 100644 blackjax/mcmc/malt.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index dbe1d87b8..de7e94d1b 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -16,9 +16,11 @@ from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc +from .mcmc import dynamic_malt as _dynamic_malt from .mcmc import elliptical_slice as _elliptical_slice from .mcmc import ghmc as _ghmc from .mcmc import hmc as _hmc +from .mcmc import malt as _malt from .mcmc import mala as _mala from .mcmc import marginal_latent_gaussian from .mcmc import mclmc as _mclmc @@ -97,12 +99,14 @@ def generate_top_level_api_from(module): # MCMC hmc = generate_top_level_api_from(_hmc) +malt = generate_top_level_api_from(_malt) nuts = generate_top_level_api_from(_nuts) rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh) irmh = GenerateSamplingAPI( irmh_as_top_level_api, random_walk.init, random_walk.build_irmh ) dynamic_hmc = generate_top_level_api_from(_dynamic_hmc) +dynamic_malt = generate_top_level_api_from(_dynamic_malt) rmhmc = generate_top_level_api_from(_rmhmc) mala = generate_top_level_api_from(_mala) mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian) diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 6236b19e3..f3a214e37 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -104,6 +104,7 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, + euclidean=euclidean, )( state, params, num_steps, window_key ) @@ -164,6 +165,7 @@ def adjusted_mclmc_make_L_step_size_adaptation( fix_L_first_da=False, max="avg", tuning_factor=1.0, + euclidean=False, ): """Adapts the stepsize and L of the MCLMC kernel. Designed for adjusted MCLMC""" @@ -215,6 +217,7 @@ def step(iteration_state, weight_and_key): step_size = jax.lax.clamp( 1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1 ) + # jax.debug.print("step size in adaptation {x}",x=step_size) adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) x = ravel_pytree(state.position)[0] @@ -264,6 +267,8 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): num_steps * frac_tune2 ) + # jax.debug.print("num steps1 {x}",x=num_steps1) + check_key, rng_key = jax.random.split(rng_key, 2) rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) @@ -309,16 +314,24 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): else: raise ValueError("max should be either 'max' or 'avg'") + new_L = params.L + if euclidean: + new_L /= jnp.sqrt(dim) + change = jax.lax.clamp( Lratio_lowerbound, - contract(variances) / params.L, + contract(variances) / new_L, Lratio_upperbound, ) + params = params._replace( L=params.L * change, step_size=params.step_size * change ) if diagonal_preconditioning: - params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) + if euclidean: + params = params._replace(inverse_mass_matrix=variances, L=1.) + else: + params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index c3f805003..48ccaec1f 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -75,7 +75,7 @@ def mclmc_find_L_and_step_size( The fraction of tuning for the second step of the adaptation. frac_tune3 The fraction of tuning for the third step of the adaptation. - desired_energy_va + desired_energy_var The desired energy variance for the MCMC algorithm. trust_in_estimate The trust in the estimate of optimal stepsize. diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 686a7082c..203fbd44c 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -1,6 +1,8 @@ from . import ( adjusted_mclmc, adjusted_mclmc_dynamic, + malt, + dynamic_malt, barker, elliptical_slice, ghmc, @@ -30,4 +32,6 @@ "underdamped_langevin" "adjusted_mclmc_dynamic", "adjusted_mclmc", + "dynamic_malt", + "malt", ] diff --git a/blackjax/mcmc/dynamic_hmc.py b/blackjax/mcmc/dynamic_hmc.py index 848922774..cdf323805 100644 --- a/blackjax/mcmc/dynamic_hmc.py +++ b/blackjax/mcmc/dynamic_hmc.py @@ -1,3 +1,4 @@ + # Copyright 2020- The Blackjax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -101,10 +102,6 @@ def kernel( inverse_mass_matrix, num_integration_steps, ) - - # jax.debug.print("logdensity {x}", x=hmc_proposal.logdensity) - # jax.debug.print("acceptance {x}", x=info) - next_random_arg = next_random_arg_fn(state.random_generator_arg) return ( DynamicHMCState( diff --git a/blackjax/mcmc/dynamic_malt.py b/blackjax/mcmc/dynamic_malt.py new file mode 100644 index 000000000..3260e6ca8 --- /dev/null +++ b/blackjax/mcmc/dynamic_malt.py @@ -0,0 +1,190 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Dynamic HMC Kernel""" +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.hmc import HMCInfo, HMCState +from blackjax.mcmc.malt import build_kernel as build_static_hmc_kernel +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.mcmc.dynamic_hmc import DynamicHMCState + +__all__ = [ + "DynamicHMCState", + "init", + "build_kernel", + "halton_sequence", + "as_top_level_api", +] + + + + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + + +def build_kernel( + integrator: Callable = integrators.velocity_verlet, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), + L_proposal_factor: float = jnp.inf, +): + """Build a Dynamic HMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The symplectic integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + + """ + hmc_base = build_static_hmc_kernel(integrator, divergence_threshold, L_proposal_factor) + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, + **integration_steps_kwargs, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the HMC kernel.""" + num_integration_steps = integration_steps_fn( + state.random_generator_arg, **integration_steps_kwargs + ) + hmc_state = HMCState(state.position, state.logdensity, state.logdensity_grad) + hmc_proposal, info = hmc_base( + rng_key, + hmc_state, + logdensity_fn, + step_size, + inverse_mass_matrix, + num_integration_steps, + ) + + # jax.debug.print("logdensity {x}", x=hmc_proposal.logdensity) + # jax.debug.print("acceptance {x}", x=info) + + next_random_arg = next_random_arg_fn(state.random_generator_arg) + return ( + DynamicHMCState( + hmc_proposal.position, + hmc_proposal.logdensity, + hmc_proposal.logdensity_grad, + next_random_arg, + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), + L_proposal_factor: float = jnp.inf, +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic HMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + inverse_mass_matrix + The value to use for the inverse mass matrix when drawing a value for + the momentum and computing the kinetic energy. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + kernel = build_kernel( + integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn, L_proposal_factor + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + # Note that rng_key here is not necessarily a PRNGKey, could be a Array that + # for generates a sequence of pseudo or quasi-random numbers (previously + # named as `random_generator_arg`) + return init(position, logdensity_fn, rng_key) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + ) + + return SamplingAlgorithm(init_fn, step_fn) + + +def halton_sequence(i: Array, max_bits: int = 10) -> float: + bit_masks = 2 ** jnp.arange(max_bits, dtype=i.dtype) + return jnp.einsum("i,i->", jnp.mod((i + 1) // bit_masks, 2), 0.5 / bit_masks) + + +def rescale(mu): + # Returns s, such that `round(U(0, 1) * s + 0.5)` has expected value mu. + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def halton_trajectory_length( + i: Array, trajectory_length_adjustment: float, max_bits: int = 10 +) -> int: + """Generate a quasi-random number of integration steps.""" + s = rescale(trajectory_length_adjustment) + return jnp.asarray(jnp.rint(0.5 + halton_sequence(i, max_bits) * s), dtype=int) diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 8582342e6..3869a715b 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -136,8 +136,6 @@ def kernel( position, logdensity, logdensity_grad = state momentum = metric.sample_momentum(key_momentum, position) - # import jax.numpy as jnp - # jax.debug.print("momentum nan? {x}",x=jnp.any(jnp.isnan(momentum))) integrator_state = integrators.IntegratorState( position, momentum, logdensity, logdensity_grad @@ -301,9 +299,7 @@ def generate( proposal_energy = hmc_energy_fn(state) new_energy = hmc_energy_fn(end_state) delta_energy = safe_energy_diff(proposal_energy, new_energy) - # jax.debug.print("delta_energy {x}",x=delta_energy) is_diverging = -delta_energy > divergence_threshold - # jax.debug.print("is_diverging {x}",x=is_diverging) sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) do_accept, p_accept, other_proposal_info = info @@ -339,4 +335,4 @@ def flip_momentum( flipped_momentum, state.logdensity, state.logdensity_grad, - ) + ) \ No newline at end of file diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index f88c5eeaa..23b36f630 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -531,8 +531,8 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): new_state = integrator(state, step_size) # jax.debug.print("state 2 {x}",x=state) - kinetic_change = - kinetic_energy(new_state.momentum) + kinetic_energy( - state.momentum + kinetic_change = - kinetic_energy(state.momentum) + kinetic_energy( + new_state.momentum ) energy_change = kinetic_change - new_state.logdensity + state.logdensity diff --git a/blackjax/mcmc/malt.py b/blackjax/mcmc/malt.py new file mode 100644 index 000000000..2630c1a86 --- /dev/null +++ b/blackjax/mcmc/malt.py @@ -0,0 +1,306 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the HMC Kernel""" +from typing import Callable, NamedTuple, Union + +import jax +import jax.numpy as jnp +import blackjax.mcmc.integrators as integrators +import blackjax.mcmc.metrics as metrics +import blackjax.mcmc.trajectory as trajectory +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.proposal import safe_energy_diff, static_binomial_sampling +from blackjax.mcmc.trajectory import hmc_energy +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.mcmc.hmc import HMCState, HMCInfo + +__all__ = [ + "HMCState", + "HMCInfo", + "init", + "build_kernel", + "as_top_level_api", +] + + + +def init(position: ArrayLikeTree, logdensity_fn: Callable): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return HMCState(position, logdensity, logdensity_grad) + + +def build_kernel( + integrator: Callable = integrators.velocity_verlet, + divergence_threshold: float = 1000, + L_proposal_factor = jnp.inf, +): + """Build a HMC kernel. + + Parameters + ---------- + integrator + The symplectic integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is + divergent. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + + """ + + def kernel( + rng_key: PRNGKey, + state: HMCState, + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: metrics.MetricTypes, + num_integration_steps: int, + + ) -> tuple[HMCState, HMCInfo]: + """Generate a new sample with the HMC kernel.""" + + L = num_integration_steps * step_size + L_proposal = L_proposal_factor * L + + # jax.debug.print("L_proposal {x}",x=L_proposal) + + key_trajectory, key_momentum, key_integrator = jax.random.split(rng_key, 3) + metric = metrics.default_metric(inverse_mass_matrix) + symplectic_integrator = lambda state, step_size, rng_key: integrators.with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy, inverse_mass_matrix)(state, step_size, L_proposal=L_proposal, rng_key=rng_key) + proposal_generator = hmc_proposal( + symplectic_integrator, + metric.kinetic_energy, + step_size, + num_integration_steps, + divergence_threshold, + rng_key=key_trajectory, + ) + + + position, logdensity, logdensity_grad = state + momentum = metric.sample_momentum(key_momentum, position) + # import jax.numpy as jnp + # jax.debug.print("momentum nan? {x}",x=jnp.any(jnp.isnan(momentum))) + + integrator_state = integrators.IntegratorState( + position, momentum, logdensity, logdensity_grad + ) + proposal, info, _ = proposal_generator(key_integrator, integrator_state) + proposal = HMCState( + proposal.position, proposal.logdensity, proposal.logdensity_grad + ) + + return proposal, info + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: metrics.MetricTypes, + num_integration_steps: int, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, + L_proposal_factor = jnp.inf, +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the HMC kernel. + + The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias + `blackjax.hmc.build_kernel`) can be cumbersome to manipulate. Since most users only + need to specify the kernel parameters at initialization time, we provide a helper + function that specializes the general kernel. + + We also add the general kernel and state generator as an attribute to this class so + users only need to pass `blackjax.hmc` to SMC, adaptation, etc. algorithms. + + Examples + -------- + + A new HMC kernel can be initialized and used with the following code: + + .. code:: + + hmc = blackjax.hmc( + logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps + ) + state = hmc.init(position) + new_state, info = hmc.step(rng_key, state) + + Kernels are not jit-compiled by default so you will need to do it manually: + + .. code:: + + step = jax.jit(hmc.step) + new_state, info = step(rng_key, state) + + Should you need to you can always use the base kernel directly: + + .. code:: + + import blackjax.mcmc.integrators as integrators + + kernel = blackjax.hmc.build_kernel(integrators.mclachlan) + state = blackjax.hmc.init(position, logdensity_fn) + state, info = kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + num_integration_steps, + ) + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + inverse_mass_matrix + The value to use for the inverse mass matrix when drawing a value for + the momentum and computing the kinetic energy. This argument will be + passed to the ``metrics.default_metric`` function so it supports the + full interface presented there. + num_integration_steps + The number of steps we take with the symplectic integrator at each + sample step before returning a sample. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the + trajectory. + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel(integrator, divergence_threshold) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + num_integration_steps, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, step_fn) + + +def hmc_proposal( + integrator: Callable, + kinetic_energy: metrics.KineticEnergy, + step_size: Union[float, ArrayLikeTree], + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, + rng_key: PRNGKey, +) -> Callable: + """Vanilla HMC algorithm. + + The algorithm integrates the trajectory applying a symplectic integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + Symplectic integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the symplectic integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. + + """ + build_trajectory = trajectory.langevin_integration(integrator, rng_key) + hmc_energy_fn = hmc_energy(kinetic_energy) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, delta_energy = build_trajectory(state, step_size, num_integration_steps) + end_state = flip_momentum(end_state) + # proposal_energy = hmc_energy_fn(state) + new_energy = hmc_energy_fn(end_state) + # delta_energy = safe_energy_diff(proposal_energy, new_energy) + + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + # jax.debug.print("delta_energy_trajectory {x}",x=(delta_energy_trajectory, delta_energy)) + # jax.debug.print("delta_energy {x}",x=delta_energy) + is_diverging = -delta_energy > divergence_threshold + # jax.debug.print("is_diverging {x}",x=is_diverging) + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def flip_momentum( + state: integrators.IntegratorState, +) -> integrators.IntegratorState: + """Flip the momentum at the end of the trajectory. + + To guarantee time-reversibility (hence detailed balance) we + need to flip the last state's momentum. If we run the hamiltonian + dynamics starting from the last state with flipped momentum we + should indeed retrieve the initial state (with flipped momentum). + + """ + flipped_momentum = jax.tree_util.tree_map(lambda m: -1.0 * m, state.momentum) + return integrators.IntegratorState( + state.position, + flipped_momentum, + state.logdensity, + state.logdensity_grad, + ) diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 7bb1b35a5..14851c03c 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -127,6 +127,31 @@ def one_step(_, state): return integrate +def langevin_integration( + integrator: Callable, + rng_key: PRNGKey, + direction: int = 1, +) -> Callable: + """Generate a trajectory by integrating several times in one direction.""" + + def integrate( + initial_state: IntegratorState, step_size, num_integration_steps + ) -> IntegratorState: + directed_step_size = jax.tree_util.tree_map( + lambda step_size: direction * step_size, step_size + ) + + def one_step(_, accum): + state, delta_energy, rng_key = accum + rng_key, key2 = jax.random.split(rng_key) + new_state, (_, delta_energy_new) = integrator(state, directed_step_size, rng_key) + return (new_state, delta_energy+delta_energy_new, key2) + + state, delta_energy, _ = jax.lax.fori_loop(0, num_integration_steps, one_step, (initial_state, 0.0, rng_key)) + return state, -delta_energy + + return integrate + class DynamicIntegrationState(NamedTuple): step: int diff --git a/blackjax/mcmc/underdamped_langevin.py b/blackjax/mcmc/underdamped_langevin.py index d92de73e0..b064ce007 100644 --- a/blackjax/mcmc/underdamped_langevin.py +++ b/blackjax/mcmc/underdamped_langevin.py @@ -26,7 +26,7 @@ import blackjax.mcmc.metrics as metrics import jax.numpy as jnp from blackjax.util import pytree_size - +from blackjax.adaptation.mclmc_adaptation import handle_high_energy __all__ = ["LangevinInfo", "init", "build_kernel", "as_top_level_api"] @@ -112,6 +112,24 @@ def kernel( ndims = pytree_size(position) # jax.debug.print("diagnostics {x}", x=(eev_max_per_dim, jnp.abs(energy_error), jnp.abs(energy_error) > jnp.sqrt(ndims * eev_max_per_dim))) + energy_key, rng_key = jax.random.split(rng_key) + + energy, new_state = handle_high_energy( + previous_state=state, + next_state=IntegratorState(position, momentum, logdensity, logdensitygrad), + energy_change=energy_error, + key=energy_key, + inverse_mass_matrix=inverse_mass_matrix, + cutoff=jnp.sqrt(ndims * eev_max_per_dim), + euclidean=True + ) + + return new_state, LangevinInfo( + logdensity=new_state.logdensity, + energy_change=energy, + kinetic_change=kinetic_change + ) + new_state, new_info = jax.lax.cond( jnp.abs(energy_error) > jnp.sqrt(ndims * eev_max_per_dim), lambda: ( From 89ccc8b2cefde4f60eebd0527e09aa2e2f7a2b84 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Tue, 3 Jun 2025 05:14:43 -0700 Subject: [PATCH 44/63] updates --- blackjax/__init__.py | 4 + blackjax/mcmc/__init__.py | 6 +- blackjax/mcmc/hmc.py | 2 + blackjax/mcmc/integrators.py | 1 + blackjax/mcmc/malt.py | 1 + blackjax/mcmc/mchmc.py | 173 +++++++++++++++++++++++++ blackjax/mcmc/uhmc.py | 180 ++++++++++++++++++++++++++ blackjax/mcmc/underdamped_langevin.py | 40 +++--- 8 files changed, 386 insertions(+), 21 deletions(-) create mode 100644 blackjax/mcmc/mchmc.py create mode 100644 blackjax/mcmc/uhmc.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index de7e94d1b..6b5671351 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -20,10 +20,12 @@ from .mcmc import elliptical_slice as _elliptical_slice from .mcmc import ghmc as _ghmc from .mcmc import hmc as _hmc +from .mcmc import uhmc as _uhmc from .mcmc import malt as _malt from .mcmc import mala as _mala from .mcmc import marginal_latent_gaussian from .mcmc import mclmc as _mclmc +from .mcmc import mchmc as _mchmc from .mcmc import underdamped_langevin as _langevin from .mcmc import nuts as _nuts from .mcmc import periodic_orbital, random_walk @@ -99,6 +101,7 @@ def generate_top_level_api_from(module): # MCMC hmc = generate_top_level_api_from(_hmc) +uhmc = generate_top_level_api_from(_uhmc) malt = generate_top_level_api_from(_malt) nuts = generate_top_level_api_from(_nuts) rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh) @@ -119,6 +122,7 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) mclmc = generate_top_level_api_from(_mclmc) +mchmc = generate_top_level_api_from(_mchmc) langevin = generate_top_level_api_from(_langevin) adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic) adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 203fbd44c..47bc0837f 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -7,9 +7,11 @@ elliptical_slice, ghmc, hmc, + uhmc, mala, marginal_latent_gaussian, mclmc, + mchmc, nuts, periodic_orbital, random_walk, @@ -29,7 +31,9 @@ "marginal_latent_gaussian", "random_walk", "mclmc", - "underdamped_langevin" + "mchmc", + "underdamped_langevin", + "uhmc", "adjusted_mclmc_dynamic", "adjusted_mclmc", "dynamic_malt", diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 3869a715b..6d5b8ed8e 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -121,6 +121,7 @@ def kernel( num_integration_steps: int, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the HMC kernel.""" + metric = metrics.default_metric(inverse_mass_matrix) symplectic_integrator = integrator(logdensity_fn, metric.kinetic_energy) @@ -302,6 +303,7 @@ def generate( is_diverging = -delta_energy > divergence_threshold sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) do_accept, p_accept, other_proposal_info = info + info = HMCInfo( state.momentum, diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 23b36f630..3cd611140 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -167,6 +167,7 @@ def update( coef: float, auxiliary_info=None, ): + # jax.debug.print("nan? {x}", x=jnp.any(jnp.isnan(kinetic_grad))) # jax.debug.print("position {x}", x=position) diff --git a/blackjax/mcmc/malt.py b/blackjax/mcmc/malt.py index 2630c1a86..65b53ef88 100644 --- a/blackjax/mcmc/malt.py +++ b/blackjax/mcmc/malt.py @@ -270,6 +270,7 @@ def generate( # jax.debug.print("is_diverging {x}",x=is_diverging) sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) do_accept, p_accept, other_proposal_info = info + jax.debug.print("p_accept {p_accept}", p_accept=(p_accept, delta_energy)) info = HMCInfo( state.momentum, diff --git a/blackjax/mcmc/mchmc.py b/blackjax/mcmc/mchmc.py new file mode 100644 index 000000000..544eb79c8 --- /dev/null +++ b/blackjax/mcmc/mchmc.py @@ -0,0 +1,173 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the MCLMC Kernel""" +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.integrators import ( + IntegratorState, + isokinetic_mclachlan, + with_isokinetic_maruyama, +) +from blackjax.types import ArrayLike, PRNGKey +from blackjax.util import generate_unit_vector, pytree_size +from blackjax.mcmc.adjusted_mclmc_dynamic import rescale + +__all__ = ["MCLMCInfo", "init", "build_kernel", "as_top_level_api"] + +class MCHMCState(NamedTuple): + position: ArrayLike + momentum: ArrayLike + logdensity: float + logdensity_grad: ArrayLike + steps_until_refresh: int + +class MCLMCInfo(NamedTuple): + """ + Additional information on the MCLMC transition. + + logdensity + The log-density of the distribution at the current step of the MCLMC chain. + kinetic_change + The difference in kinetic energy between the current and previous step. + energy_change + The difference in energy between the current and previous step. + """ + + logdensity: float + kinetic_change: float + energy_change: float + + +def init(position: ArrayLike, logdensity_fn, rng_key): + if pytree_size(position) < 2: + raise ValueError( + "The target distribution must have more than 1 dimension for MCLMC." + ) + l, g = jax.value_and_grad(logdensity_fn)(position) + + return MCHMCState( + position=position, + momentum=generate_unit_vector(rng_key, position), + logdensity=l, + logdensity_grad=g, + steps_until_refresh=0, + ) + +def integrator_state(state: MCHMCState) -> IntegratorState: + return IntegratorState( + position=state.position, + momentum=state.momentum, + logdensity=state.logdensity, + logdensity_grad=state.logdensity_grad, + ) + + +def build_kernel( + # integration_steps_fn, + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4, +): + """ + """ + + step = integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) + + + def kernel( + rng_key: PRNGKey, state: MCHMCState, L: float, step_size: float + ) -> tuple[MCHMCState, MCLMCInfo]: + (position, momentum, logdensity, logdensitygrad), kinetic_change = step( + integrator_state(state), step_size + ) + + # num_integration_steps = integration_steps_fn(state.random_generator_arg) + jitter_key, refresh_key = jax.random.split(rng_key) + + num_steps_per_traj = jnp.ceil(L/step_size).astype(int) + + + num_steps_per_traj = jnp.ceil( + jax.random.uniform(jitter_key) * rescale(num_steps_per_traj) + ).astype(int) + + + + energy_error = kinetic_change - logdensity + state.logdensity + + eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var + ndims = pytree_size(position) + + momentum=(state.steps_until_refresh==0) * generate_unit_vector(refresh_key, state.position) + (state.steps_until_refresh>0) * momentum + # new_state = new_state._replace(momentum=generate_unit_vector(refresh_key, new_state.position)) + + steps_until_refresh = (state.steps_until_refresh==0) * num_steps_per_traj + (state.steps_until_refresh>0) * (state.steps_until_refresh - 1) + # jax.debug.print("steps_until_refresh: {x}", x=steps_until_refresh) + + new_state, new_info = jax.lax.cond( + energy_error > jnp.sqrt(ndims * eev_max_per_dim), + lambda: ( + state, + MCLMCInfo( + logdensity=state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + ), + ), + lambda: ( + MCHMCState(position, momentum, logdensity, logdensitygrad, steps_until_refresh), + MCLMCInfo( + logdensity=logdensity, + energy_change=energy_error, + kinetic_change=kinetic_change, + ), + ), + ) + + return new_state, new_info + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + L, + step_size, + integrator=isokinetic_mclachlan, + inverse_mass_matrix=1.0, + desired_energy_var_max_ratio=jnp.inf, +) -> SamplingAlgorithm: + """ + """ + + kernel = build_kernel( + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=desired_energy_var_max_ratio, + ) + + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key, state): + return kernel(rng_key, state, L, step_size) + + return SamplingAlgorithm(init_fn, update_fn) diff --git a/blackjax/mcmc/uhmc.py b/blackjax/mcmc/uhmc.py new file mode 100644 index 000000000..f0eb14444 --- /dev/null +++ b/blackjax/mcmc/uhmc.py @@ -0,0 +1,180 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Underdamped Langevin Kernel""" +from typing import Callable, NamedTuple + +import jax + +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.integrators import ( + IntegratorState, + with_maruyama, + velocity_verlet, +) +from blackjax.types import ArrayLike, PRNGKey +import blackjax.mcmc.metrics as metrics +import jax.numpy as jnp +from blackjax.util import pytree_size, generate_unit_vector +from blackjax.adaptation.mclmc_adaptation import handle_high_energy +__all__ = ["LangevinInfo", "init", "build_kernel", "as_top_level_api"] + +class UHMCState(NamedTuple): + position: ArrayLike + momentum: ArrayLike + logdensity: float + logdensity_grad: ArrayLike + steps_until_refresh: int + +def integrator_state(state: UHMCState) -> IntegratorState: + return IntegratorState( + position=state.position, + momentum=state.momentum, + logdensity=state.logdensity, + logdensity_grad=state.logdensity_grad, + ) + +class LangevinInfo(NamedTuple): + """ + Additional information on the Langevin transition. + + logdensity + The log-density of the distribution at the current step of the Langevin chain. + kinetic_change + The difference in kinetic energy between the current and previous step. + energy_change + The difference in energy between the current and previous step. + """ + + logdensity: float + kinetic_change: float + energy_change: float + + +def init(position: ArrayLike, logdensity_fn, metric, rng_key): + + l, g = jax.value_and_grad(logdensity_fn)(position) + + return UHMCState( + position=position, + momentum = metric.sample_momentum(rng_key, position), + logdensity=l, + logdensity_grad=g, + steps_until_refresh=0, + ) + + +def build_kernel( + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4,): + """Build a HMC kernel. + + Parameters + ---------- + integrator + The symplectic integrator to use to integrate the Langevin dynamics. + L + the momentum decoherence rate. + step_size + step size of the integrator. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + + """ + + metric = metrics.default_metric(inverse_mass_matrix) + step = with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy,inverse_mass_matrix) + + def kernel( + rng_key: PRNGKey, state: UHMCState, L: float, step_size: float + ) -> tuple[UHMCState, LangevinInfo]: + + refresh_key, energy_key, run_key = jax.random.split(rng_key, 3) + + (position, momentum, logdensity, logdensitygrad), (kinetic_change, energy_error) = step( + integrator_state(state), step_size, jnp.inf, run_key + ) + + + num_steps_per_traj = jnp.ceil(L/step_size).astype(jnp.int64) + momentum = (state.steps_until_refresh==0) * metric.sample_momentum(refresh_key, position) + (state.steps_until_refresh>0) * momentum + steps_until_refresh = (state.steps_until_refresh==0) * num_steps_per_traj + (state.steps_until_refresh>0) * (state.steps_until_refresh - 1) + + eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var + ndims = pytree_size(position) + + + energy, new_integrator_state = handle_high_energy( + previous_state=integrator_state(state), + next_state=IntegratorState(position, momentum, logdensity, logdensitygrad), + energy_change=energy_error, + key=energy_key, + inverse_mass_matrix=inverse_mass_matrix, + cutoff=jnp.sqrt(ndims * eev_max_per_dim), + euclidean=True + ) + + return UHMCState(new_integrator_state.position, new_integrator_state.momentum, new_integrator_state.logdensity, new_integrator_state.logdensity_grad, steps_until_refresh), LangevinInfo( + logdensity=logdensity, + energy_change=energy, + kinetic_change=kinetic_change + ) + + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + L, + step_size, + integrator=velocity_verlet, + inverse_mass_matrix=1.0, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4, +) -> SamplingAlgorithm: + """The general Langevin kernel builder (:meth:`blackjax.mcmc.langevin.build_kernel`, alias `blackjax.langevin.build_kernel`) can be + cumbersome to manipulate. Since most users only need to specify the kernel + parameters at initialization time, we provide a helper function that + specializes the general kernel. + + We also add the general kernel and state generator as an attribute to this class so + users only need to pass `blackjax.langevin` to SMC, adaptation, etc. algorithms. + + + + """ + + kernel = build_kernel( + logdensity_fn, + inverse_mass_matrix, + integrator, + desired_energy_var_max_ratio=desired_energy_var_max_ratio, + desired_energy_var=desired_energy_var, + ) + metric = metrics.default_metric(inverse_mass_matrix) + + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return init(position, logdensity_fn, metric, rng_key) + + def update_fn(rng_key, state): + return kernel(rng_key, state, L, step_size) + + return SamplingAlgorithm(init_fn, update_fn) diff --git a/blackjax/mcmc/underdamped_langevin.py b/blackjax/mcmc/underdamped_langevin.py index b064ce007..272e79105 100644 --- a/blackjax/mcmc/underdamped_langevin.py +++ b/blackjax/mcmc/underdamped_langevin.py @@ -130,27 +130,27 @@ def kernel( kinetic_change=kinetic_change ) - new_state, new_info = jax.lax.cond( - jnp.abs(energy_error) > jnp.sqrt(ndims * eev_max_per_dim), - lambda: ( - state, - LangevinInfo( - logdensity=state.logdensity, - energy_change=0.0, - kinetic_change=0.0, - ), - ), - lambda: ( - IntegratorState(position, momentum, logdensity, logdensitygrad), - LangevinInfo( - logdensity=logdensity, - energy_change=energy_error, - kinetic_change=kinetic_change, - ), - ), - ) + # new_state, new_info = jax.lax.cond( + # jnp.abs(energy_error) > jnp.sqrt(ndims * eev_max_per_dim), + # lambda: ( + # state, + # LangevinInfo( + # logdensity=state.logdensity, + # energy_change=0.0, + # kinetic_change=0.0, + # ), + # ), + # lambda: ( + # IntegratorState(position, momentum, logdensity, logdensitygrad), + # LangevinInfo( + # logdensity=logdensity, + # energy_change=energy_error, + # kinetic_change=kinetic_change, + # ), + # ), + # ) - return new_state, new_info + # return new_state, new_info return kernel From 6fe396379d4ff297e2e34adea56bc3361efdba49 Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Wed, 4 Jun 2025 05:35:48 -0700 Subject: [PATCH 45/63] update --- blackjax/util.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/blackjax/util.py b/blackjax/util.py index 295fdc691..6010bf05d 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -426,13 +426,15 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): entropy = jnp.zeros((num_steps,)) equi_diag = jnp.zeros((num_steps,)) equi_full = jnp.zeros((num_steps,)) + bias0 = jnp.zeros((num_steps,)) + bias1 = jnp.zeros((num_steps,)) observables = jnp.zeros((num_steps,)) r_avg = jnp.zeros((num_steps,)) r_max = jnp.zeros((num_steps,)) step_size = jnp.zeros((num_steps,)) def step_while(a): - x, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, observables, r_avg, r_max, step_size = a + x, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, step_size = a auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) @@ -443,18 +445,20 @@ def step_while(a): new_entropy = entropy.at[i].set(info.get("entropy")) new_equi_diag = equi_diag.at[i].set(info.get("equi_diag")) new_equi_full = equi_full.at[i].set(info.get("equi_full")) + new_bias0 = bias0.at[i].set(info.get("bias")[0]) + new_bias1 = bias1.at[i].set(info.get("bias")[1]) new_observables = observables.at[i].set(info.get("observables")) new_r_avg = r_avg.at[i].set(info.get("r_avg")) new_r_max = r_max.at[i].set(info.get("r_max")) new_step_size = step_size.at[i].set(info.get("step_size")) - return (output, i + 1, info.get("while_cond"), new_EEVPD, new_EEVPD_wanted, new_L, new_entropy, new_equi_diag, new_equi_full, new_observables, new_r_avg, new_r_max, new_step_size) + return (output, i + 1, info.get("while_cond"), new_EEVPD, new_EEVPD_wanted, new_L, new_entropy, new_equi_diag, new_equi_full, new_bias0, new_bias1, new_observables, new_r_avg, new_r_max, new_step_size) if early_stop: - final_state_all, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, observables, r_avg, r_max, step_size = lax.while_loop( + final_state_all, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, step_size = lax.while_loop( lambda a: ((a[1] < num_steps) & a[2]), step_while, - (initial_state_all, 0, True, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, observables, r_avg, r_max, step_size), + (initial_state_all, 0, True, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, step_size), ) steps_done = i info_history = { @@ -464,6 +468,8 @@ def step_while(a): "entropy": entropy, "equi_diag": equi_diag, "equi_full": equi_full, + "bias0": bias0, + "bias1": bias1, "observables": observables, "r_avg": r_avg, "r_max": r_max, From 1ea9ac4e6725b9594238fc8256da825a888ce66a Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Thu, 5 Jun 2025 06:35:14 -0700 Subject: [PATCH 46/63] fixed unadjusted phase (removed diagonal precond), tried out nuts-style L tuning --- blackjax/adaptation/ensemble_umclmc.py | 16 ++++++---------- blackjax/mcmc/mclmc.py | 1 + 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index ff10b7094..5feef4b0c 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -50,7 +50,7 @@ def build_kernel(logdensity_fn): def sequential_kernel(key, state, adap): new_state, info = mclmc.build_kernel( - logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet, inverse_mass_matrix=adap.inverse_mass_matrix + logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet, inverse_mass_matrix= jnp.ones(adap.inverse_mass_matrix.shape) )(key, state, adap.L, adap.step_size) # reject the new state if there were nans @@ -256,6 +256,7 @@ def summary_statistics_fn(self, state, info, rng_key): "observables_for_bias": self.observables_for_bias(state.position), "observables": self.observables(state.position), "entropy": -info["logdensity"], + "uturn": jnp.sqrt(jnp.sum(jnp.square(state.logdensity_grad - jnp.dot(state.logdensity_grad, state.momentum) * state.momentum))) / (self.ndims - 1) } def update(self, adaptation_state, Etheta): @@ -280,9 +281,8 @@ def update(self, adaptation_state, Etheta): ) history = History(history_observables, history_stopping, history_weights) - L = self.alpha * jnp.sqrt( - jnp.sum(Etheta["xsq"] - jnp.square(Etheta["x"])) - ) # average over the ensemble, sum over parameters (to get sqrt(d)) + L = self.alpha * jnp.sqrt(jnp.sum(Etheta["xsq"] - jnp.square(Etheta["x"]))) # average over the ensemble, sum over parameters (to get sqrt(d)) + #L = self.alpha / Etheta["uturn"] inverse_mass_matrix = Etheta["xsq"] - jnp.square(Etheta["x"]) EEVPD = (Etheta["Esq"] - jnp.square(Etheta["E"])) / self.ndims true_bias = self.contract(Etheta["observables_for_bias"]) @@ -290,17 +290,13 @@ def update(self, adaptation_state, Etheta): # hyperparameter adaptation # estimate bias - bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[ - self.bias_type - ] # r_max, r_avg, equi_full, equi_diag + bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[self.bias_type] # r_max, r_avg, equi_full, equi_diag EEVPD_wanted = self.C * jnp.power(bias, self.power) eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1.0 / 6.0) eps_factor = jnp.clip(eps_factor, 0.3, 3.0) - eps_factor = nan_reject( - 1 - nans, 0.5, eps_factor - ) # reduce the stepsize if there were nans + eps_factor = nan_reject(1 - nans, 0.5, eps_factor) # reduce the stepsize if there were nans # determine if we want to finish this stage (i.e. if loss is no longer decreassing) # increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). Do not be tempted to simply change to while_cond = history[0] < history[-1] diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 771c542f7..88438ce53 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -61,6 +61,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): ) + def build_kernel( logdensity_fn, inverse_mass_matrix, From 82ed6da8ed2cde807de39515f2b223be725d80b0 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Thu, 12 Jun 2025 10:19:01 -0700 Subject: [PATCH 47/63] working branch --- blackjax/mcmc/malt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/blackjax/mcmc/malt.py b/blackjax/mcmc/malt.py index 65b53ef88..572264589 100644 --- a/blackjax/mcmc/malt.py +++ b/blackjax/mcmc/malt.py @@ -270,7 +270,7 @@ def generate( # jax.debug.print("is_diverging {x}",x=is_diverging) sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) do_accept, p_accept, other_proposal_info = info - jax.debug.print("p_accept {p_accept}", p_accept=(p_accept, delta_energy)) + # jax.debug.print("p_accept {p_accept}", p_accept=(p_accept, delta_energy)) info = HMCInfo( state.momentum, From c44dce3311fb176c392078ee436576b2c81c359a Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Thu, 19 Jun 2025 04:27:07 -0700 Subject: [PATCH 48/63] attempted to change the nuts angle --- blackjax/adaptation/window_adaptation.py | 3 ++- blackjax/mcmc/metrics.py | 11 +++++++---- blackjax/mcmc/nuts.py | 6 ++++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 69a098325..4e1848a75 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -251,6 +251,7 @@ def window_adaptation( progress_bar: bool = False, adaptation_info_fn: Callable = return_all_adapt_info, integrator=mcmc.integrators.velocity_verlet, + cos_angle_termination = 0., **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of @@ -296,7 +297,7 @@ def window_adaptation( """ - mcmc_kernel = algorithm.build_kernel(integrator) + mcmc_kernel = algorithm.build_kernel(integrator, cos_angle_termination= cos_angle_termination) adapt_init, adapt_step, adapt_final = base( is_mass_matrix_diagonal, diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index f0720acf4..5ff6364f4 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -81,7 +81,7 @@ class Metric(NamedTuple): MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]] -def default_metric(metric: MetricTypes) -> Metric: +def default_metric(metric: MetricTypes, cos_angle_termination= 0.) -> Metric: """Convert an input metric into a ``Metric`` object following sensible default rules The metric can be specified in three different ways: @@ -102,11 +102,12 @@ def default_metric(metric: MetricTypes) -> Metric: # If we make it here then the argument should be an array, and we'll assume # that it specifies a static inverse mass matrix. - return gaussian_euclidean(metric) + return gaussian_euclidean(metric, cos_angle_termination= cos_angle_termination) def gaussian_euclidean( inverse_mass_matrix: Array, + cos_angle_termination: float, ) -> Metric: r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum :cite:p:`betancourt2013general`. @@ -186,8 +187,10 @@ def is_turning( # rho = m_sum rho = m_sum - (m_right + m_left) / 2 - turning_at_left = jnp.dot(velocity_left, rho) <= 0 - turning_at_right = jnp.dot(velocity_right, rho) <= 0 + cos_distance = lambda v1, v2: jnp.dot(v1, v2) / jnp.sqrt(jnp.dot(v1, v1) * jnp.dot(v2, v2)) + + turning_at_left = cos_distance(velocity_left, rho) <= cos_angle_termination + turning_at_right = cos_distance(velocity_right, rho) <= cos_angle_termination return turning_at_left | turning_at_right def scale( diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 73d8aab4c..2740c88e0 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -77,6 +77,7 @@ class NUTSInfo(NamedTuple): def build_kernel( integrator: Callable = integrators.velocity_verlet, divergence_threshold: int = 1000, + cos_angle_termination = 0. ): """Build an iterative NUTS kernel. @@ -121,7 +122,7 @@ def kernel( """Generate a new sample with the NUTS kernel.""" - metric = metrics.default_metric(inverse_mass_matrix) + metric = metrics.default_metric(inverse_mass_matrix, cos_angle_termination=cos_angle_termination) symplectic_integrator = integrator(logdensity_fn, metric.kinetic_energy) proposal_generator = iterative_nuts_proposal( symplectic_integrator, @@ -159,6 +160,7 @@ def as_top_level_api( max_num_doublings: int = 10, divergence_threshold: int = 1000, integrator: Callable = integrators.velocity_verlet, + cos_angle_termination: float = 0.0, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the nuts kernel. @@ -214,7 +216,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(integrator, divergence_threshold) + kernel = build_kernel(integrator, divergence_threshold, cos_angle_termination= cos_angle_termination) def init_fn(position: ArrayLikeTree, rng_key=None): del rng_key From abc0cb713dc832e7974293bba376b91e87d4d33f Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Thu, 19 Jun 2025 04:30:01 -0700 Subject: [PATCH 49/63] Revert "attempted to change the nuts angle" This reverts commit c44dce3311fb176c392078ee436576b2c81c359a. --- blackjax/adaptation/window_adaptation.py | 3 +-- blackjax/mcmc/metrics.py | 11 ++++------- blackjax/mcmc/nuts.py | 6 ++---- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 4e1848a75..69a098325 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -251,7 +251,6 @@ def window_adaptation( progress_bar: bool = False, adaptation_info_fn: Callable = return_all_adapt_info, integrator=mcmc.integrators.velocity_verlet, - cos_angle_termination = 0., **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of @@ -297,7 +296,7 @@ def window_adaptation( """ - mcmc_kernel = algorithm.build_kernel(integrator, cos_angle_termination= cos_angle_termination) + mcmc_kernel = algorithm.build_kernel(integrator) adapt_init, adapt_step, adapt_final = base( is_mass_matrix_diagonal, diff --git a/blackjax/mcmc/metrics.py b/blackjax/mcmc/metrics.py index 5ff6364f4..f0720acf4 100644 --- a/blackjax/mcmc/metrics.py +++ b/blackjax/mcmc/metrics.py @@ -81,7 +81,7 @@ class Metric(NamedTuple): MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]] -def default_metric(metric: MetricTypes, cos_angle_termination= 0.) -> Metric: +def default_metric(metric: MetricTypes) -> Metric: """Convert an input metric into a ``Metric`` object following sensible default rules The metric can be specified in three different ways: @@ -102,12 +102,11 @@ def default_metric(metric: MetricTypes, cos_angle_termination= 0.) -> Metric: # If we make it here then the argument should be an array, and we'll assume # that it specifies a static inverse mass matrix. - return gaussian_euclidean(metric, cos_angle_termination= cos_angle_termination) + return gaussian_euclidean(metric) def gaussian_euclidean( inverse_mass_matrix: Array, - cos_angle_termination: float, ) -> Metric: r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum :cite:p:`betancourt2013general`. @@ -187,10 +186,8 @@ def is_turning( # rho = m_sum rho = m_sum - (m_right + m_left) / 2 - cos_distance = lambda v1, v2: jnp.dot(v1, v2) / jnp.sqrt(jnp.dot(v1, v1) * jnp.dot(v2, v2)) - - turning_at_left = cos_distance(velocity_left, rho) <= cos_angle_termination - turning_at_right = cos_distance(velocity_right, rho) <= cos_angle_termination + turning_at_left = jnp.dot(velocity_left, rho) <= 0 + turning_at_right = jnp.dot(velocity_right, rho) <= 0 return turning_at_left | turning_at_right def scale( diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 2740c88e0..73d8aab4c 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -77,7 +77,6 @@ class NUTSInfo(NamedTuple): def build_kernel( integrator: Callable = integrators.velocity_verlet, divergence_threshold: int = 1000, - cos_angle_termination = 0. ): """Build an iterative NUTS kernel. @@ -122,7 +121,7 @@ def kernel( """Generate a new sample with the NUTS kernel.""" - metric = metrics.default_metric(inverse_mass_matrix, cos_angle_termination=cos_angle_termination) + metric = metrics.default_metric(inverse_mass_matrix) symplectic_integrator = integrator(logdensity_fn, metric.kinetic_energy) proposal_generator = iterative_nuts_proposal( symplectic_integrator, @@ -160,7 +159,6 @@ def as_top_level_api( max_num_doublings: int = 10, divergence_threshold: int = 1000, integrator: Callable = integrators.velocity_verlet, - cos_angle_termination: float = 0.0, ) -> SamplingAlgorithm: """Implements the (basic) user interface for the nuts kernel. @@ -216,7 +214,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ - kernel = build_kernel(integrator, divergence_threshold, cos_angle_termination= cos_angle_termination) + kernel = build_kernel(integrator, divergence_threshold) def init_fn(position: ArrayLikeTree, rng_key=None): del rng_key From 2d33a1fa1fad7eef9c131c4cd126f08ccf65a1e7 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Mon, 23 Jun 2025 08:41:19 -0700 Subject: [PATCH 50/63] working branch --- blackjax/__init__.py | 3 +- .../adaptation/adjusted_mclmc_adaptation.py | 28 +++- blackjax/mcmc/hmc.py | 5 +- blackjax/mcmc/integrators.py | 6 +- blackjax/mcmc/mclmc.py | 10 +- blackjax/mcmc/nuts.py | 5 +- blackjax/mcmc/pseudofermion.py | 151 ++++++++++++++++++ 7 files changed, 196 insertions(+), 12 deletions(-) create mode 100644 blackjax/mcmc/pseudofermion.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 6b5671351..66ba98e63 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -23,6 +23,7 @@ from .mcmc import uhmc as _uhmc from .mcmc import malt as _malt from .mcmc import mala as _mala +from .mcmc import pseudofermion as _pseudofermion from .mcmc import marginal_latent_gaussian from .mcmc import mclmc as _mclmc from .mcmc import mchmc as _mchmc @@ -129,7 +130,7 @@ def generate_top_level_api_from(module): elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) - +pseudofermion = generate_top_level_api_from(_pseudofermion) hmc_family = [hmc, nuts] # SMC diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index f3a214e37..00767ce7a 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -268,8 +268,9 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): ) # jax.debug.print("num steps1 {x}",x=num_steps1) + # jax.debug.print("num steps 2 {x}",x=num_steps2) - check_key, rng_key = jax.random.split(rng_key, 2) + # check_key, rng_key = jax.random.split(rng_key, 2) rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) L_step_size_adaptation_keys_pass1 = jax.random.split( @@ -306,23 +307,34 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): variances = x_squared_average - jnp.square(x_average) if max == "max": - contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor + if euclidean: + contract = lambda x: (jnp.sqrt(jnp.max(x) * dim) * tuning_factor) / jnp.sqrt(dim) + + else: + contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor elif max == "avg": - contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor + print("avg") + if euclidean: + + contract = lambda x: (jnp.sqrt(jnp.sum(x)) * tuning_factor) / jnp.sqrt(dim) + else: + contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor else: raise ValueError("max should be either 'max' or 'avg'") new_L = params.L - if euclidean: - new_L /= jnp.sqrt(dim) change = jax.lax.clamp( Lratio_lowerbound, contract(variances) / new_L, Lratio_upperbound, ) + # if euclidean: + # # new_L /= jnp.sqrt(dim) + # change /= jnp.sqrt(dim) + params = params._replace( L=params.L * change, step_size=params.step_size * change @@ -332,6 +344,12 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): params = params._replace(inverse_mass_matrix=variances, L=1.) else: params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) + + # else: + # if euclidean: + # params = params._replace(L = params.L / jnp.sqrt(dim)) + + # jax.debug.print("params L {x}", x=(params.L, contract(variances), jnp.sum(variances), tuning_factor)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 6d5b8ed8e..384c96b28 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -135,7 +135,10 @@ def kernel( key_momentum, key_integrator = jax.random.split(rng_key, 2) - position, logdensity, logdensity_grad = state + # position, logdensity, logdensity_grad = state + position = state.position + logdensity = state.logdensity + logdensity_grad = state.logdensity_grad momentum = metric.sample_momentum(key_momentum, position) integrator_state = integrators.IntegratorState( diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 3cd611140..4023c6245 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -100,7 +100,11 @@ def generalized_two_stage_integrator( """ def one_step(state: IntegratorState, step_size: float): - position, momentum, _, logdensity_grad = state + # position, momentum, _, logdensity_grad = state + position = state.position + momentum = state.momentum + logdensity_grad = state.logdensity_grad + # logdensity = state.logdensity # jax.debug.print("initial state {x}", x=jnp.any(jnp.isnan(momentum))) # auxiliary infomation generated during integration for diagnostics. It is # updated by the operator1 and operator2 at each call. diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 771c542f7..ce0454238 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -62,9 +62,10 @@ def init(position: ArrayLike, logdensity_fn, rng_key): def build_kernel( + + integrator, logdensity_fn, inverse_mass_matrix, - integrator, desired_energy_var_max_ratio=jnp.inf, desired_energy_var=5e-4, ): @@ -94,6 +95,8 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float ) -> tuple[IntegratorState, MCLMCInfo]: + + (position, momentum, logdensity, logdensitygrad), kinetic_change = step( state, step_size, L, rng_key ) @@ -183,10 +186,11 @@ def as_top_level_api( """ kernel = build_kernel( - logdensity_fn, - inverse_mass_matrix, + integrator, desired_energy_var_max_ratio=desired_energy_var_max_ratio, + logdensity_fn=logdensity_fn, + inverse_mass_matrix=inverse_mass_matrix, ) def init_fn(position: ArrayLike, rng_key: PRNGKey): diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 73d8aab4c..75399a1e7 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -133,7 +133,10 @@ def kernel( key_momentum, key_integrator = jax.random.split(rng_key, 2) - position, logdensity, logdensity_grad = state + # position, logdensity, logdensity_grad = state + position = state.position + logdensity = state.logdensity + logdensity_grad = state.logdensity_grad momentum = metric.sample_momentum(key_momentum, position) integrator_state = integrators.IntegratorState( diff --git a/blackjax/mcmc/pseudofermion.py b/blackjax/mcmc/pseudofermion.py new file mode 100644 index 000000000..5a9b00f9e --- /dev/null +++ b/blackjax/mcmc/pseudofermion.py @@ -0,0 +1,151 @@ +from functools import partial +from typing import Any, Callable, NamedTuple + +import jax +import jax.numpy as jnp +import numpy as np + +import blackjax.mcmc.hmc as hmc +import blackjax.mcmc.integrators as integrators +import blackjax.mcmc.metrics as metrics +import blackjax.mcmc.proposal as proposal +import blackjax.mcmc.termination as termination +import blackjax.mcmc.trajectory as trajectory +from blackjax.base import SamplingAlgorithm +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey + + +class GibbsState(NamedTuple): + + # pos : Any + # aux: ArrayTree + position: ArrayTree + logdensity: float + logdensity_grad: ArrayTree + momentum: ArrayTree + temporary_state : Any + fermion_matrix : Any + count : int + +def build_kernel(): + return () + +def init(position, logdensity_fn, fermion_matrix, temporary_state, init_main, rng_key ): + # state = hmc.state + # state.boson_state = position + # fermion_matrix = hmc.theory.get_fermion_matrix(hmc.state) + # temporary_state = hmc.theory.sample_temporary_state(position,hmc.state,fermion_matrix) + position, momentum, logdensity, logdensity_grad = init_main(position, logdensity_fn(fermion_matrix, temporary_state), rng_key ) + return GibbsState( + position=position, + logdensity=logdensity, + logdensity_grad=logdensity_grad, + momentum=momentum, + temporary_state=temporary_state, + fermion_matrix=fermion_matrix, + count=0, + ) + +def as_top_level_api( + kernel_main, + init_main, + logdensity_fn: Callable, + # step_size: float, + # inverse_mass_matrix: metrics.MetricTypes, + *, + max_num_doublings: int = 10, + divergence_threshold: int = 1000, + # integrator: Callable = integrators.velocity_verlet, + get_fermion_matrix_fn: Callable = None, + sample_temporary_state_fn: Callable = None, + # num_integration_steps: int = 1, + # alg1, +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the nuts kernel. + + Examples + -------- + + A new NUTS kernel can be initialized and used with the following code: + + .. code:: + + nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix) + state = nuts.init(position) + new_state, info = nuts.step(rng_key, state) + + We can JIT-compile the step function for more speed: + + .. code:: + + step = jax.jit(nuts.step) + new_state, info = step(rng_key, state) + + You can always use the base kernel should you need to: + + .. code:: + + import blackjax.mcmc.integrators as integrators + + kernel = blackjax.nuts.build_kernel(integrators.yoshida) + state = blackjax.nuts.init(position, logdensity_fn) + state, info = kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix) + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + inverse_mass_matrix + The value to use for the inverse mass matrix when drawing a value for + the momentum and computing the kinetic energy. + max_num_doublings + The maximum number of times we double the length of the trajectory before + returning if no U-turn has been obserbed or no divergence has occured. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + + Returns + ------- + A ``SamplingAlgorithm``. + + """ + # kernel = build_kernel(integrator, divergence_threshold) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + fermion_matrix = get_fermion_matrix_fn(position) + temporary_state = sample_temporary_state_fn(position,fermion_matrix) + return init(position, logdensity_fn, fermion_matrix, temporary_state, init_main, rng_key) + + def step_fn(rng_key: PRNGKey, state): + next_state, info = kernel_main( + rng_key, + state, + logdensity_fn(state.fermion_matrix, state.temporary_state), + # step_size, + # inverse_mass_matrix, + # max_num_doublings, + # num_integration_steps=num_integration_steps, + ) + new_fermion_matrix = get_fermion_matrix_fn(next_state.position) + new_temporary_state = sample_temporary_state_fn(next_state.position, new_fermion_matrix) + full_state = GibbsState( + position=next_state.position, + momentum=None, + # momentum=next_state.momentum, + logdensity=next_state.logdensity, + logdensity_grad=next_state.logdensity_grad, + temporary_state=new_temporary_state, + fermion_matrix=new_fermion_matrix, + count=state.count + info.num_integration_steps, + ) + jax.debug.print("count {x}", x=full_state.count) + return full_state, info + + return SamplingAlgorithm(init_fn, step_fn) From 26f452828b81940bb11679c8990ae09d99997130 Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Mon, 23 Jun 2025 09:50:02 -0700 Subject: [PATCH 51/63] fixed laps --- blackjax/adaptation/ensemble_mclmc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 7b47b2164..b1639522e 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -159,7 +159,7 @@ def while_steps_num(cond): return jnp.argmin(cond) + 1 -def emaus( +def laps( logdensity_fn, sample_init, ndims, @@ -258,7 +258,7 @@ def emaus( ) # scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't count because that gradient can be reused in the next step. if diagonal_preconditioning: - inverse_mass_matrix = jnp.sqrt(final_adaptation_state.inverse_mass_matrix) + inverse_mass_matrix = final_adaptation_state.inverse_mass_matrix # scale the stepsize so that it reflects averag scale change of the preconditioning average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix)) From b65a664a34c396c25893adf2ed4933d86e82f4db Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Fri, 27 Jun 2025 02:42:36 -0700 Subject: [PATCH 52/63] working old version of eca_step with splitR --- blackjax/adaptation/ensemble_mclmc.py | 7 ++- blackjax/adaptation/ensemble_umclmc.py | 19 +++---- blackjax/diagnostics.py | 16 ++++++ blackjax/util.py | 76 +++++++++++++++++++------- 4 files changed, 85 insertions(+), 33 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index b1639522e..1b71a8d95 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -181,6 +181,7 @@ def laps( ensemble_observables=None, diagnostics=True, contract=lambda x: 0.0, + superchain_size= None, ): """ model: the target density object @@ -207,7 +208,7 @@ def laps( # initialize the chains initial_state = umclmc.initialize( - key_init, logdensity_fn, sample_init, num_chains, mesh + key_init, logdensity_fn, sample_init, num_chains, mesh, superchain_size=superchain_size ) # burn-in with the unadjusted method # @@ -233,6 +234,7 @@ def laps( num_steps1, num_chains, mesh, + superchain_size, ensemble_observables, early_stop=early_stop, ) @@ -302,7 +304,8 @@ def laps( num_samples, num_chains, mesh, - ensemble_observables, + superchain_size, + ensemble_observables ) if diagnostics: diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 5feef4b0c..5b8be5b28 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -66,7 +66,7 @@ def sequential_kernel(key, state, adap): return sequential_kernel -def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh): +def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh, superchain_size): """initialize the chains based on the equipartition of the initial condition. We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1. """ @@ -117,14 +117,13 @@ def ensemble_init(key, state, signs): num_chains, mesh, summary_statistics_fn=summary_statistics_fn, + superchain_size= superchain_size ) flat_equi, _ = ravel_pytree(equipartition) signs = -2.0 * (flat_equi < 1.0) + 1.0 - initial_state, _ = ensemble_execute_fn( - ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs - ) + initial_state, _ = ensemble_execute_fn(ensemble_init, key2, num_chains, mesh, x=initial_state, args=signs, superchain_size= superchain_size) return initial_state @@ -298,13 +297,6 @@ def update(self, adaptation_state, Etheta): eps_factor = nan_reject(1 - nans, 0.5, eps_factor) # reduce the stepsize if there were nans - # determine if we want to finish this stage (i.e. if loss is no longer decreassing) - # increasing = history.stopping[0] > history.stopping[-1] # will be false if some elements of history are still nan (have not been filled yet). Do not be tempted to simply change to while_cond = history[0] < history[-1] - # while_cond = ~increasing - while_cond = (fluctuations[0] > self.r_end) | ( - adaptation_state.step_count < self.save_num - ) - info_to_be_stored = { "L": adaptation_state.L, "step_size": adaptation_state.step_size, @@ -315,7 +307,6 @@ def update(self, adaptation_state, Etheta): "bias": true_bias, "r_max": fluctuations[0], "r_avg": fluctuations[1], - "while_cond": while_cond, "entropy": Etheta["entropy"], "observables": Etheta["observables"], } @@ -331,3 +322,7 @@ def update(self, adaptation_state, Etheta): ) return adaptation_state_new, info_to_be_stored + + def while_cond(self, info): + """determine if we want to switch to adjustment""" + return info['r_max'] > self.r_end \ No newline at end of file diff --git a/blackjax/diagnostics.py b/blackjax/diagnostics.py index 257ce759c..fd184d048 100644 --- a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -207,3 +207,19 @@ def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): ess = ess_raw / tau_hat return ess.squeeze() + + +def splitR(position, num_chains, superchain_size, func_for_splitR = jnp.square): + + # combine the chains in super-chains to compute expectation values + func_mk = jax.vmap(func_for_splitR)(position) # shape = (# chains, # func) + func_mk = func_mk.reshape(num_chains // superchain_size, superchain_size, func_mk.shape[-1]) #shape = (# superchains, # chains in superchain, # func) + func_k = jnp.average(func_mk, axis = 1) #shape = (# superchains, # func) + func_sq_k = jnp.average(jnp.square(func_mk), axis = 1) #shape = (# superchains, # func) + W_k = (func_sq_k - jnp.square(func_k)) * superchain_size / (superchain_size - 1) # variance withing k-th superchain + W = jnp.average(W_k, axis = 0) # average within superchain variance + B = jnp.var(func_k, axis = 0, ddof= 1) # between superchain variance + + R = jnp.sqrt(1. + (B/W)) # splitR, shape = (# func) + + return R \ No newline at end of file diff --git a/blackjax/util.py b/blackjax/util.py index 6010bf05d..32025dbba 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -14,6 +14,7 @@ from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.diagnostics import splitR @partial(jit, static_argnames=("precision",), inline=True) @@ -319,13 +320,13 @@ def incremental_value_update( def eca_step( - kernel, summary_statistics_fn, adaptation_update, num_chains, ensemble_info=None + kernel, summary_statistics_fn, adaptation_update, num_chains, superchain_size = None, ensemble_info=None ): """ Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. """ - def _step(state_all, xs): + def step(state_all, xs): """This function operates on a single device.""" ( state, @@ -355,19 +356,37 @@ def _step(state_all, xs): return (state, adaptation_state), info_to_be_stored - if ensemble_info is not None: + return add_ensemble_info(add_splitR(step, num_chains, superchain_size), ensemble_info) - def step(state_all, xs): - (state, adaptation_state), info_to_be_stored = _step(state_all, xs) - return (state, adaptation_state), ( - info_to_be_stored, - vmap(ensemble_info)(state.position), - ) - return step +def add_splitR(step, num_chains, superchain_size): + + def _step(state_all, xs): + + state_all, info_to_be_stored = step(state_all, xs) + + state, adaptation_state = state_all + + R = splitR(state.position, num_chains, superchain_size) + split_bavg = jnp.average(jnp.square(R) - 1) + split_bmax = jnp.max(jnp.square(R) - 1) + + info_to_be_stored['R_avg'] = split_bavg + info_to_be_stored['R_max'] = split_bmax + + return (state, adaptation_state), info_to_be_stored + + return _step if superchain_size is not None else step + + +def add_ensemble_info(step, ensemble_info): + + def _step(state_all, xs): + (state, adaptation_state), info_to_be_stored = step(state_all, xs) + return (state, adaptation_state), (info_to_be_stored, vmap(ensemble_info)(state.position)) + + return _step if ensemble_info is not None else step - else: - return _step def run_eca( @@ -378,6 +397,7 @@ def run_eca( num_steps, num_chains, mesh, + superchain_size= None, ensemble_info=None, early_stop=False, ): @@ -405,6 +425,7 @@ def run_eca( adaptation.summary_statistics_fn, adaptation.update, num_chains, + superchain_size, ensemble_info, ) @@ -431,10 +452,12 @@ def all_steps(initial_state, keys_sampling, keys_adaptation): observables = jnp.zeros((num_steps,)) r_avg = jnp.zeros((num_steps,)) r_max = jnp.zeros((num_steps,)) + R_avg = jnp.zeros((num_steps,)) + R_max = jnp.zeros((num_steps,)) step_size = jnp.zeros((num_steps,)) def step_while(a): - x, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, step_size = a + x, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, R_avg, R_max, step_size = a auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) @@ -450,15 +473,19 @@ def step_while(a): new_observables = observables.at[i].set(info.get("observables")) new_r_avg = r_avg.at[i].set(info.get("r_avg")) new_r_max = r_max.at[i].set(info.get("r_max")) + new_R_avg = R_avg.at[i].set(info.get("R_avg")) + new_R_max = R_max.at[i].set(info.get("R_max")) new_step_size = step_size.at[i].set(info.get("step_size")) - return (output, i + 1, info.get("while_cond"), new_EEVPD, new_EEVPD_wanted, new_L, new_entropy, new_equi_diag, new_equi_full, new_bias0, new_bias1, new_observables, new_r_avg, new_r_max, new_step_size) + return (output, i + 1, + True, #info.get("r_max") > adaptation.r_end,x + new_EEVPD, new_EEVPD_wanted, new_L, new_entropy, new_equi_diag, new_equi_full, new_bias0, new_bias1, new_observables, new_r_avg, new_r_max, new_R_avg, new_R_max, new_step_size) if early_stop: - final_state_all, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, step_size = lax.while_loop( + final_state_all, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, R_avg, R_max, step_size = lax.while_loop( lambda a: ((a[1] < num_steps) & a[2]), step_while, - (initial_state_all, 0, True, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, step_size), + (initial_state_all, 0, True, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, R_avg, R_max, step_size), ) steps_done = i info_history = { @@ -473,6 +500,8 @@ def step_while(a): "observables": observables, "r_avg": r_avg, "r_max": r_max, + "R_avg": R_avg, + "R_max": R_max, "step_size": step_size, "steps_done": steps_done, } @@ -524,6 +553,7 @@ def ensemble_execute_fn( x=None, args=None, summary_statistics_fn=lambda y: 0.0, + superchain_size = None ): """Given a sequential function func(rng_key, x, args) = y, @@ -563,8 +593,16 @@ def F(x, keys): F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False ) - keys = device_put( - split(rng_key, num_chains), NamedSharding(mesh, p) - ) # random keys, distributed across devices + if superchain_size == None: + _keys = split(rng_key, num_chains) + + else: + _keys = jnp.repeat(split(rng_key, num_chains // superchain_size), superchain_size) + + + keys = device_put(_keys, NamedSharding(mesh, p)) # random keys, distributed across devices + # apply F in parallel return parallel_execute(X, keys) + + From 28cabf07f1b3c673b8ecd424280a595141e0fa35 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Mon, 30 Jun 2025 12:39:52 -0700 Subject: [PATCH 53/63] new adaptation --- blackjax/__init__.py | 6 + blackjax/adaptation/__init__.py | 4 + blackjax/adaptation/adjusted_abla.py | 139 +++++++++ .../adaptation/adjusted_mclmc_adaptation.py | 1 + blackjax/adaptation/mclmc_adaptation.py | 5 + blackjax/adaptation/unadjusted_alba.py | 278 ++++++++++++++++++ blackjax/adaptation/unadjusted_step_size.py | 46 +++ blackjax/adaptation/window_adaptation.py | 1 + blackjax/mcmc/adjusted_mclmc_dynamic.py | 18 +- blackjax/mcmc/dynamic_hmc.py | 5 +- blackjax/mcmc/dynamic_malt.py | 10 +- blackjax/mcmc/mchmc.py | 4 +- blackjax/mcmc/mclmc.py | 22 +- blackjax/mcmc/uhmc.py | 10 +- blackjax/mcmc/underdamped_langevin.py | 27 +- 15 files changed, 529 insertions(+), 47 deletions(-) create mode 100644 blackjax/adaptation/adjusted_abla.py create mode 100644 blackjax/adaptation/unadjusted_alba.py create mode 100644 blackjax/adaptation/unadjusted_step_size.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 66ba98e63..de085bc62 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -9,6 +9,9 @@ from .adaptation.meads_adaptation import meads_adaptation from .adaptation.pathfinder_adaptation import pathfinder_adaptation from .adaptation.window_adaptation import window_adaptation +from .adaptation.unadjusted_alba import unadjusted_alba +from .adaptation.unadjusted_step_size import robnik_step_size_tuning +from .adaptation.adjusted_abla import alba_adjusted from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat @@ -180,4 +183,7 @@ def generate_top_level_api_from(module): "adjusted_mclmc_find_L_and_step_size", # adjusted mclmc adaptation "ess", # diagnostics "rhat", + "unadjusted_alba", + "robnik_step_size_tuning", + "alba_adjusted", ] diff --git a/blackjax/adaptation/__init__.py b/blackjax/adaptation/__init__.py index 53d5fe2b6..bdade7fff 100644 --- a/blackjax/adaptation/__init__.py +++ b/blackjax/adaptation/__init__.py @@ -4,6 +4,8 @@ meads_adaptation, pathfinder_adaptation, window_adaptation, + unadjusted_alba, + unadjusted_step_size, ) __all__ = [ @@ -12,4 +14,6 @@ "window_adaptation", "pathfinder_adaptation", "mclmc_adaptation", + "unadjusted_alba", + "robnik_step_size_tuning", ] diff --git a/blackjax/adaptation/adjusted_abla.py b/blackjax/adaptation/adjusted_abla.py new file mode 100644 index 000000000..ccee18379 --- /dev/null +++ b/blackjax/adaptation/adjusted_abla.py @@ -0,0 +1,139 @@ +from blackjax.adaptation.step_size import ( + dual_averaging_adaptation, +) +from blackjax.mcmc.adjusted_mclmc_dynamic import rescale +from blackjax.base import AdaptationAlgorithm +from blackjax.types import ArrayLikeTree, PRNGKey +import jax +import jax.numpy as jnp +from typing import Callable +import blackjax +from blackjax.adaptation.unadjusted_alba import unadjusted_alba + + + +def make_random_trajectory_length_fn(random_trajectory_length : bool): + if random_trajectory_length: + integration_steps_fn = lambda avg_num_integration_steps: lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ).astype('int32') + else: + integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil( + avg_num_integration_steps + ).astype('int32') + return integration_steps_fn + +def da_adaptation( + algorithm, + logdensity_fn: Callable, + integration_steps_fn: Callable, + inverse_mass_matrix, + initial_step_size: float = 1.0, + target_acceptance_rate: float = 0.80, + integrator=blackjax.mcmc.integrators.velocity_verlet, +): + + da_init, da_update, da_final = dual_averaging_adaptation(target_acceptance_rate) + kernel = algorithm.build_kernel(integrator=integrator) + + def step(state, key): + + adaptation_state, kernel_state = state + + new_kernel_state, info = kernel( + rng_key=key, + state=kernel_state, + logdensity_fn=logdensity_fn, + step_size=jnp.exp(adaptation_state.log_step_size), + inverse_mass_matrix=inverse_mass_matrix, + integration_steps_fn=integration_steps_fn, + ) + + new_adaptation_state = da_update( + adaptation_state, + info.acceptance_rate, + ) + + return ( + (new_adaptation_state, new_kernel_state), + info, + ) + + def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): + + + init_key, rng_key = jax.random.split(rng_key) + + init_kernel_state = algorithm.init(position=position, logdensity_fn=logdensity_fn, random_generator_arg=init_key) + + keys = jax.random.split(rng_key, num_steps) + init_state = da_init(initial_step_size), init_kernel_state + (adaptation_state, kernel_state), info = jax.lax.scan( + step, + init_state, + keys, + ) + step_size = da_final(adaptation_state) + return ( + kernel_state, + { + "step_size": step_size, + "inverse_mass_matrix": inverse_mass_matrix, + }, + info, + ) + + return AdaptationAlgorithm(run) + + +def alba_adjusted( + unadjusted_algorithm, + logdensity_fn: Callable, + target_eevpd, + v, + adjusted_algorithm, + num_dimensions: int, + integrator, + target_acceptance_rate: float = 0.80, + num_alba_steps: int = 500, + alba_factor: float = 0.4, + **extra_parameters, + ): + + unadjusted_warmup = unadjusted_alba( + algorithm= unadjusted_algorithm, + logdensity_fn=logdensity_fn, + target_eevpd=target_eevpd, + v=v, + integrator=integrator, + num_alba_steps=num_alba_steps, + alba_factor=alba_factor, **extra_parameters) + + def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): + + unadjusted_warmup_key, adjusted_warmup_key = jax.random.split(rng_key) + + (state, params), adaptation_info = unadjusted_warmup.run(unadjusted_warmup_key, position, num_steps) + + avg_num_integration_steps = params["L"] / params["step_size"] + + integration_steps_fn = lambda k: jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ) + + adjusted_warmup = da_adaptation( + algorithm=adjusted_algorithm, + logdensity_fn=logdensity_fn, + integration_steps_fn=integration_steps_fn, + initial_step_size=params["step_size"], + target_acceptance_rate=target_acceptance_rate, + inverse_mass_matrix=params["inverse_mass_matrix"], + integrator=integrator, **extra_parameters) + + state, params, adaptation_info = adjusted_warmup.run(adjusted_warmup_key, state.position, num_steps) + params["L"] = adaptation_info.num_integration_steps.mean()*params["step_size"] + return state, params, adaptation_info + + return AdaptationAlgorithm(run) + + \ No newline at end of file diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 00767ce7a..474bfc5cc 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -369,6 +369,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): params = params._replace(step_size=final_da(dual_avg_state)) + return state, params, eigenvector, num_tuning_integrator_steps return L_step_size_adaptation diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 48ccaec1f..4be09f9ec 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -123,6 +123,7 @@ def mclmc_find_L_and_step_size( ) + # jax.debug.print("params {x}", x=(params, euclidean)) part1_key, part2_key = jax.random.split(rng_key, 2) @@ -326,8 +327,10 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): L = jnp.sqrt(jnp.sum(variances)) # lmc: should be jnp.mean if euclidean: L /= jnp.sqrt(dim) + # jax.debug.print("foo {x}", x=(L, euclidean, variances)) if diagonal_preconditioning: + # jax.debug.print("bar") inverse_mass_matrix = variances params = params._replace(inverse_mass_matrix=inverse_mass_matrix) L = jnp.sqrt(dim) # lmc: 1 @@ -341,6 +344,8 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) + # jax.debug.print("params {x}", x=(MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix), euclidean)) + return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) return L_step_size_adaptation diff --git a/blackjax/adaptation/unadjusted_alba.py b/blackjax/adaptation/unadjusted_alba.py new file mode 100644 index 000000000..b782baf85 --- /dev/null +++ b/blackjax/adaptation/unadjusted_alba.py @@ -0,0 +1,278 @@ +import jax +import jax.numpy as jnp + +from typing import Callable, NamedTuple + +import blackjax.mcmc as mcmc +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info +from blackjax.adaptation.mass_matrix import ( + MassMatrixAdaptationState, + mass_matrix_adaptation, +) +from blackjax.base import AdaptationAlgorithm +from blackjax.progress_bar import gen_scan_fn +from blackjax.types import Array, ArrayLikeTree, PRNGKey +from blackjax.util import pytree_size +from blackjax.adaptation.window_adaptation import build_schedule +from jax.flatten_util import ravel_pytree +from blackjax.diagnostics import effective_sample_size +from blackjax.adaptation.unadjusted_step_size import robnik_step_size_tuning, RobnikStepSizeTuningState + +class AlbaAdaptationState(NamedTuple): + ss_state: RobnikStepSizeTuningState # step size + imm_state: MassMatrixAdaptationState # inverse mass matrix + step_size: float + inverse_mass_matrix: Array + L : float + +def base( + is_mass_matrix_diagonal: bool, + v, + target_eevpd, +) -> tuple[Callable, Callable, Callable]: + + mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal) + + # step_size_init, step_size_update, step_size_final = dual_averaging_adaptation(target_eevpd) + step_size_init, step_size_update, step_size_final = robnik_step_size_tuning(desired_energy_var=target_eevpd) + + def init( + position: ArrayLikeTree, + ) -> AlbaAdaptationState: + + num_dimensions = pytree_size(position) + imm_state = mm_init(num_dimensions) + + ss_state = step_size_init(initial_step_size=jnp.sqrt(num_dimensions)/5, num_dimensions=num_dimensions) + + return AlbaAdaptationState( + ss_state, + imm_state, + ss_state.step_size, + imm_state.inverse_mass_matrix, + L = jnp.sqrt(num_dimensions)/v + ) + + def fast_update( + position: ArrayLikeTree, + value: float, + warmup_state: AlbaAdaptationState, + ) -> AlbaAdaptationState: + """Update the adaptation state when in a "fast" window. + + Only the step size is adapted in fast windows. "Fast" refers to the fact + that the optimization algorithms are relatively fast to converge + compared to the covariance estimation with Welford's algorithm + + """ + + del position + + + new_ss_state = step_size_update(warmup_state.ss_state, value) + new_step_size = new_ss_state.step_size # jnp.exp(new_ss_state.log_step_size) + + return AlbaAdaptationState( + new_ss_state, + warmup_state.imm_state, + new_step_size, + warmup_state.inverse_mass_matrix, + L = warmup_state.L + ) + + def slow_update( + position: ArrayLikeTree, + value: float, + warmup_state: AlbaAdaptationState, + ) -> AlbaAdaptationState: + + new_imm_state = mm_update(warmup_state.imm_state, position) + new_ss_state = step_size_update(warmup_state.ss_state, value) + new_step_size = new_ss_state.step_size # jnp.exp(new_ss_state.log_step_size) + + return AlbaAdaptationState( + new_ss_state, new_imm_state, new_step_size, warmup_state.inverse_mass_matrix, L = warmup_state.L + ) + + def slow_final(warmup_state: AlbaAdaptationState) -> AlbaAdaptationState: + + new_imm_state = mm_final(warmup_state.imm_state) + new_ss_state = step_size_init(step_size_final(warmup_state.ss_state), warmup_state.ss_state.num_dimensions) + new_step_size = new_ss_state.step_size # jnp.exp(new_ss_state.log_step_size) + + new_L = jnp.sqrt(warmup_state.ss_state.num_dimensions)/v # + + return AlbaAdaptationState( + new_ss_state, + new_imm_state, + new_step_size, + new_imm_state.inverse_mass_matrix, + L = new_L + ) + + def update( + adaptation_state: AlbaAdaptationState, + adaptation_stage: tuple, + position: ArrayLikeTree, + value: float, + ) -> AlbaAdaptationState: + """Update the adaptation state and parameter values. + + Parameters + ---------- + adaptation_state + Current adptation state. + adaptation_stage + The current stage of the warmup: whether this is a slow window, + a fast window and if we are at the last step of a slow window. + position + Current value of the model parameters. + value + Value of the acceptance rate for the last mcmc step. + + Returns + ------- + The updated adaptation state. + + """ + stage, is_middle_window_end = adaptation_stage + + warmup_state = jax.lax.switch( + stage, + (fast_update, slow_update), + position, + value, + adaptation_state, + ) + + warmup_state = jax.lax.cond( + is_middle_window_end, + slow_final, + lambda x: x, + warmup_state, + ) + + return warmup_state + + def final(warmup_state: AlbaAdaptationState) -> tuple[float, Array]: + """Return the final values for the step size and mass matrix.""" + step_size = warmup_state.ss_state.step_size + # step_size = jnp.exp(warmup_state.ss_state.log_step_size_avg) + inverse_mass_matrix = warmup_state.imm_state.inverse_mass_matrix + L = warmup_state.L + return step_size, L, inverse_mass_matrix + + return init, update, final + +def unadjusted_alba( + algorithm, + logdensity_fn: Callable, + target_eevpd, + v, + is_mass_matrix_diagonal: bool = True, + progress_bar: bool = False, + adaptation_info_fn: Callable = return_all_adapt_info, + integrator=mcmc.integrators.velocity_verlet, + num_alba_steps: int = 500, + alba_factor: float = 0.4, + **extra_parameters, +) -> AdaptationAlgorithm: + + + mcmc_kernel = algorithm.build_kernel(integrator) + + adapt_init, adapt_step, adapt_final = base( + is_mass_matrix_diagonal, + target_eevpd=target_eevpd, + v=v, + ) + + def one_step(carry, xs): + _, rng_key, adaptation_stage = xs + state, adaptation_state = carry + + new_state, info = mcmc_kernel( + rng_key=rng_key, + state=state, + logdensity_fn=logdensity_fn, + step_size=adaptation_state.step_size, + inverse_mass_matrix=adaptation_state.inverse_mass_matrix, + L=adaptation_state.L, + **extra_parameters, + ) + new_adaptation_state = adapt_step( + adaptation_state, + adaptation_stage, + new_state.position, + info.energy_change, + ) + + return ( + (new_state, new_adaptation_state), + adaptation_info_fn(new_state, info, new_adaptation_state), + ) + + def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): + init_key, rng_key, alba_key = jax.random.split(rng_key, 3) + init_state = algorithm.init(position=position, logdensity_fn=logdensity_fn, random_generator_arg=init_key) + init_adaptation_state = adapt_init(position) + + if progress_bar: + print("Running window adaptation") + scan_fn = gen_scan_fn(num_steps-num_alba_steps, progress_bar=progress_bar) + start_state = (init_state, init_adaptation_state) + keys = jax.random.split(rng_key, num_steps-num_alba_steps) + schedule = build_schedule(num_steps-num_alba_steps) + last_state, info = scan_fn( + one_step, + start_state, + (jnp.arange(num_steps-num_alba_steps), keys, schedule), + ) + + last_chain_state, last_warmup_state, *_ = last_state + step_size, L, inverse_mass_matrix = adapt_final(last_warmup_state) + + ### + ### ALBA TUNING + ### + keys = jax.random.split(alba_key, num_alba_steps) + mcmc_kernel = algorithm.build_kernel(integrator) + def step(state, key): + next_state, _ = mcmc_kernel( + rng_key=key, + state=state, + logdensity_fn=logdensity_fn, + L=L, + step_size=step_size, + inverse_mass_matrix=inverse_mass_matrix, + ) + + return next_state, next_state.position + + if num_alba_steps > 0: + _, samples = jax.lax.scan(step, last_chain_state, keys) + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) + ess = effective_sample_size(flat_samples[None, ...]) + + L=alba_factor * step_size * jnp.mean(num_alba_steps / ess) + + + parameters = { + "step_size": step_size, + "inverse_mass_matrix": inverse_mass_matrix, + "L": L, + **extra_parameters, + } + + return ( + AdaptationResults( + last_chain_state, + parameters, + ), + info, + ) + + return AdaptationAlgorithm(run) + + + diff --git a/blackjax/adaptation/unadjusted_step_size.py b/blackjax/adaptation/unadjusted_step_size.py new file mode 100644 index 000000000..3306caa79 --- /dev/null +++ b/blackjax/adaptation/unadjusted_step_size.py @@ -0,0 +1,46 @@ +import jax.numpy as jnp +from typing import NamedTuple + + +class RobnikStepSizeTuningState(NamedTuple): + time : jnp.ndarray + step_size: float + x_average: float + step_size_max: float + num_dimensions: int + +def robnik_step_size_tuning(desired_energy_var, trust_in_estimate=1.5, num_effective_samples=150, step_size_max=jnp.inf): + + decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) + + def init(initial_step_size, num_dimensions): + return RobnikStepSizeTuningState(time=0.0, x_average=0.0, step_size=initial_step_size, step_size_max=step_size_max, num_dimensions=num_dimensions) + + def update(robnik_state, energy_change): + + xi = ( + jnp.square(energy_change) / (robnik_state.num_dimensions * desired_energy_var) + ) + 1e-8 # 1e-8 is added to avoid divergences in log xi + weight = jnp.exp( + -0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate)) + ) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. + + x_average = decay_rate * robnik_state.x_average + weight * ( + xi / jnp.power(robnik_state.step_size, 6.0) + ) + + time = decay_rate * robnik_state.time + weight + step_size = jnp.power( + x_average / time, -1.0 / 6.0 + ) # We use the Var[E] = O(eps^6) relation here. + step_size = (step_size < robnik_state.step_size_max) * step_size + ( + step_size > robnik_state.step_size_max + ) * robnik_state.step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences + + return RobnikStepSizeTuningState(time=time, x_average=x_average, step_size=step_size, step_size_max=step_size_max, num_dimensions=robnik_state.num_dimensions) + + + def final(robnik_state): + return robnik_state.step_size + + return init, update, final diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 69a098325..dba78b0bf 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -164,6 +164,7 @@ def slow_update( """ new_imm_state = mm_update(warmup_state.imm_state, position) + # new_imm_state = warmup_state.imm_state new_ss_state = da_update(warmup_state.ss_state, acceptance_rate) new_step_size = jnp.exp(new_ss_state.log_step_size) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index 1a69e1a28..844a723a5 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -34,11 +34,9 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: def build_kernel( - integration_steps_fn, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - inverse_mass_matrix=1.0, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -66,6 +64,8 @@ def kernel( state: DynamicHMCState, logdensity_fn: Callable, step_size: float, + integration_steps_fn, + inverse_mass_matrix=1.0, L_proposal_factor: float = jnp.inf, ) -> tuple[DynamicHMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" @@ -142,10 +142,8 @@ def as_top_level_api( """ kernel = build_kernel( - integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, - inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, ) @@ -154,11 +152,13 @@ def init_fn(position: ArrayLikeTree, rng_key: Array): def update_fn(rng_key: PRNGKey, state): return kernel( - rng_key, - state, - logdensity_fn, - step_size, - L_proposal_factor, + rng_key=rng_key, + state=state, + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=integration_steps_fn, + inverse_mass_matrix=inverse_mass_matrix, + L_proposal_factor=L_proposal_factor, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] diff --git a/blackjax/mcmc/dynamic_hmc.py b/blackjax/mcmc/dynamic_hmc.py index cdf323805..2916d0262 100644 --- a/blackjax/mcmc/dynamic_hmc.py +++ b/blackjax/mcmc/dynamic_hmc.py @@ -56,7 +56,6 @@ def build_kernel( integrator: Callable = integrators.velocity_verlet, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), ): """Build a Dynamic HMC kernel where the number of integration steps is chosen randomly. @@ -87,6 +86,7 @@ def kernel( logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Array, + integration_steps_fn: Callable, **integration_steps_kwargs, ) -> tuple[DynamicHMCState, HMCInfo]: """Generate a new sample with the HMC kernel.""" @@ -155,7 +155,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ kernel = build_kernel( - integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn + integrator, divergence_threshold, next_random_arg_fn, ) def init_fn(position: ArrayLikeTree, rng_key: Array): @@ -171,6 +171,7 @@ def step_fn(rng_key: PRNGKey, state): logdensity_fn, step_size, inverse_mass_matrix, + integration_steps_fn ) return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/dynamic_malt.py b/blackjax/mcmc/dynamic_malt.py index 3260e6ca8..e048a8406 100644 --- a/blackjax/mcmc/dynamic_malt.py +++ b/blackjax/mcmc/dynamic_malt.py @@ -44,7 +44,6 @@ def build_kernel( integrator: Callable = integrators.velocity_verlet, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), L_proposal_factor: float = jnp.inf, ): """Build a Dynamic HMC kernel where the number of integration steps is chosen randomly. @@ -68,6 +67,7 @@ def build_kernel( information about the transition. """ + hmc_base = build_static_hmc_kernel(integrator, divergence_threshold, L_proposal_factor) def kernel( @@ -76,12 +76,15 @@ def kernel( logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Array, + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), **integration_steps_kwargs, ) -> tuple[DynamicHMCState, HMCInfo]: """Generate a new sample with the HMC kernel.""" + + num_integration_steps = integration_steps_fn( state.random_generator_arg, **integration_steps_kwargs - ) + ).astype(int) hmc_state = HMCState(state.position, state.logdensity, state.logdensity_grad) hmc_proposal, info = hmc_base( rng_key, @@ -149,7 +152,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. """ kernel = build_kernel( - integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn, L_proposal_factor + integrator, divergence_threshold, next_random_arg_fn, L_proposal_factor ) def init_fn(position: ArrayLikeTree, rng_key: Array): @@ -165,6 +168,7 @@ def step_fn(rng_key: PRNGKey, state): logdensity_fn, step_size, inverse_mass_matrix, + integration_steps_fn, ) return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/mchmc.py b/blackjax/mcmc/mchmc.py index 544eb79c8..105eea6ae 100644 --- a/blackjax/mcmc/mchmc.py +++ b/blackjax/mcmc/mchmc.py @@ -53,7 +53,7 @@ class MCLMCInfo(NamedTuple): energy_change: float -def init(position: ArrayLike, logdensity_fn, rng_key): +def init(position: ArrayLike, logdensity_fn, random_generator_arg): if pytree_size(position) < 2: raise ValueError( "The target distribution must have more than 1 dimension for MCLMC." @@ -62,7 +62,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): return MCHMCState( position=position, - momentum=generate_unit_vector(rng_key, position), + momentum=generate_unit_vector(random_generator_arg, position), logdensity=l, logdensity_grad=g, steps_until_refresh=0, diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index e64a58061..959ee8fba 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -46,7 +46,7 @@ class MCLMCInfo(NamedTuple): energy_change: float -def init(position: ArrayLike, logdensity_fn, rng_key): +def init(position: ArrayLike, logdensity_fn, random_generator_arg): if pytree_size(position) < 2: raise ValueError( "The target distribution must have more than 1 dimension for MCLMC." @@ -55,7 +55,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): return IntegratorState( position=position, - momentum=generate_unit_vector(rng_key, position), + momentum=generate_unit_vector(random_generator_arg, position), logdensity=l, logdensity_grad=g, ) @@ -63,10 +63,7 @@ def init(position: ArrayLike, logdensity_fn, rng_key): def build_kernel( - integrator, - logdensity_fn, - inverse_mass_matrix, desired_energy_var_max_ratio=jnp.inf, desired_energy_var=5e-4, ): @@ -89,13 +86,14 @@ def build_kernel( """ - step = with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) - ) def kernel( - rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float + rng_key: PRNGKey, state: IntegratorState, logdensity_fn, L: float, step_size: float, + inverse_mass_matrix, ) -> tuple[IntegratorState, MCLMCInfo]: + step = with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) + ) (position, momentum, logdensity, logdensitygrad), kinetic_change = step( @@ -190,14 +188,14 @@ def as_top_level_api( integrator, desired_energy_var_max_ratio=desired_energy_var_max_ratio, - logdensity_fn=logdensity_fn, - inverse_mass_matrix=inverse_mass_matrix, + # logdensity_fn=logdensity_fn, + # inverse_mass_matrix=inverse_mass_matrix, ) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) def update_fn(rng_key, state): - return kernel(rng_key, state, L, step_size) + return kernel(rng_key, state, logdensity_fn, L, step_size, inverse_mass_matrix) return SamplingAlgorithm(init_fn, update_fn) diff --git a/blackjax/mcmc/uhmc.py b/blackjax/mcmc/uhmc.py index f0eb14444..a912e7a7a 100644 --- a/blackjax/mcmc/uhmc.py +++ b/blackjax/mcmc/uhmc.py @@ -61,13 +61,15 @@ class LangevinInfo(NamedTuple): energy_change: float -def init(position: ArrayLike, logdensity_fn, metric, rng_key): +def init(position: ArrayLike, logdensity_fn, random_generator_arg): l, g = jax.value_and_grad(logdensity_fn)(position) + metric = metrics.default_metric(jnp.ones_like(position)) + return UHMCState( position=position, - momentum = metric.sample_momentum(rng_key, position), + momentum = metric.sample_momentum(random_generator_arg, position), logdensity=l, logdensity_grad=g, steps_until_refresh=0, @@ -169,10 +171,10 @@ def as_top_level_api( desired_energy_var_max_ratio=desired_energy_var_max_ratio, desired_energy_var=desired_energy_var, ) - metric = metrics.default_metric(inverse_mass_matrix) + # metric = metrics.default_metric(inverse_mass_matrix) def init_fn(position: ArrayLike, rng_key: PRNGKey): - return init(position, logdensity_fn, metric, rng_key) + return init(position, logdensity_fn, rng_key) def update_fn(rng_key, state): return kernel(rng_key, state, L, step_size) diff --git a/blackjax/mcmc/underdamped_langevin.py b/blackjax/mcmc/underdamped_langevin.py index 272e79105..1db5700e0 100644 --- a/blackjax/mcmc/underdamped_langevin.py +++ b/blackjax/mcmc/underdamped_langevin.py @@ -47,21 +47,21 @@ class LangevinInfo(NamedTuple): energy_change: float -def init(position: ArrayLike, logdensity_fn, metric, rng_key): +def init(position: ArrayLike, logdensity_fn, random_generator_arg): l, g = jax.value_and_grad(logdensity_fn)(position) + metric = metrics.default_metric(jnp.ones_like(position)) + return IntegratorState( position=position, - momentum = metric.sample_momentum(rng_key, position), + momentum = metric.sample_momentum(random_generator_arg, position), logdensity=l, logdensity_grad=g, ) def build_kernel( - logdensity_fn, - inverse_mass_matrix, integrator, desired_energy_var_max_ratio=jnp.inf, desired_energy_var=5e-4,): @@ -84,12 +84,13 @@ def build_kernel( """ - metric = metrics.default_metric(inverse_mass_matrix) - step = with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy,inverse_mass_matrix) def kernel( - rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float + rng_key: PRNGKey, state: IntegratorState, logdensity_fn, L: float, step_size: float, inverse_mass_matrix, ) -> tuple[IntegratorState, LangevinInfo]: + metric = metrics.default_metric(inverse_mass_matrix) + step = with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy,inverse_mass_matrix) + (position, momentum, logdensity, logdensitygrad), (kinetic_change, energy_error) = step( state, step_size, L, rng_key ) @@ -171,24 +172,20 @@ def as_top_level_api( We also add the general kernel and state generator as an attribute to this class so users only need to pass `blackjax.langevin` to SMC, adaptation, etc. algorithms. - - - """ kernel = build_kernel( - logdensity_fn, - inverse_mass_matrix, integrator, desired_energy_var_max_ratio=desired_energy_var_max_ratio, desired_energy_var=desired_energy_var, ) - metric = metrics.default_metric(inverse_mass_matrix) + # metric = metrics.default_metric(inverse_mass_matrix) def init_fn(position: ArrayLike, rng_key: PRNGKey): - return init(position, logdensity_fn, metric, rng_key) + return init(position, logdensity_fn, rng_key) def update_fn(rng_key, state): - return kernel(rng_key, state, L, step_size) + return kernel( + rng_key=rng_key, state=state, logdensity_fn=logdensity_fn, L=L, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix) return SamplingAlgorithm(init_fn, update_fn) From 9fdc5000980a536300fb86d7d7ff402ba3b6447b Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Sat, 23 Aug 2025 15:02:28 -0700 Subject: [PATCH 54/63] checkpoint --- blackjax/__init__.py | 7 +- blackjax/adaptation/adjusted_abla.py | 54 ++- blackjax/adaptation/mclmc_adaptation.py | 5 +- blackjax/adaptation/unadjusted_alba.py | 76 ++- blackjax/adaptation/unadjusted_step_size.py | 30 +- blackjax/adaptation/window_adaptation.py | 6 + blackjax/mcmc/adjusted_mclmc.py | 485 ++++++++++---------- blackjax/mcmc/adjusted_mclmc_dynamic.py | 15 +- blackjax/mcmc/dynamic_malt.py | 4 +- blackjax/mcmc/integrators.py | 3 +- blackjax/mcmc/malt.py | 6 +- blackjax/mcmc/mchmc.py | 75 +-- blackjax/mcmc/mclmc.py | 84 +++- blackjax/mcmc/pseudofermion.py | 2 +- blackjax/mcmc/trajectory.py | 3 + blackjax/mcmc/uhmc.py | 71 +-- blackjax/mcmc/underdamped_langevin.py | 106 +++-- blackjax/util.py | 2 + 18 files changed, 558 insertions(+), 476 deletions(-) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index de085bc62..8cd77c644 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -3,9 +3,7 @@ from blackjax._version import __version__ -from .adaptation.adjusted_mclmc_adaptation import adjusted_mclmc_find_L_and_step_size from .adaptation.chees_adaptation import chees_adaptation -from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size from .adaptation.meads_adaptation import meads_adaptation from .adaptation.pathfinder_adaptation import pathfinder_adaptation from .adaptation.window_adaptation import window_adaptation @@ -15,7 +13,6 @@ from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat -from .mcmc import adjusted_mclmc as _adjusted_mclmc from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc @@ -129,7 +126,7 @@ def generate_top_level_api_from(module): mchmc = generate_top_level_api_from(_mchmc) langevin = generate_top_level_api_from(_langevin) adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic) -adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) +# adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) @@ -179,8 +176,6 @@ def generate_top_level_api_from(module): "meads_adaptation", "chees_adaptation", "pathfinder_adaptation", - "mclmc_find_L_and_step_size", # mclmc adaptation - "adjusted_mclmc_find_L_and_step_size", # adjusted mclmc adaptation "ess", # diagnostics "rhat", "unadjusted_alba", diff --git a/blackjax/adaptation/adjusted_abla.py b/blackjax/adaptation/adjusted_abla.py index ccee18379..f9fc77075 100644 --- a/blackjax/adaptation/adjusted_abla.py +++ b/blackjax/adaptation/adjusted_abla.py @@ -14,9 +14,10 @@ def make_random_trajectory_length_fn(random_trajectory_length : bool): if random_trajectory_length: - integration_steps_fn = lambda avg_num_integration_steps: lambda k: jnp.ceil( + integration_steps_fn = lambda avg_num_integration_steps: lambda k: jnp.where(jnp.ceil( jax.random.uniform(k) * rescale(avg_num_integration_steps) - ).astype('int32') + )==0, 1, jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps))).astype('int32') else: integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil( avg_num_integration_steps @@ -30,32 +31,37 @@ def da_adaptation( inverse_mass_matrix, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.80, + initial_L: float = 1.0, integrator=blackjax.mcmc.integrators.velocity_verlet, + L_proposal_factor=jnp.inf, ): da_init, da_update, da_final = dual_averaging_adaptation(target_acceptance_rate) - kernel = algorithm.build_kernel(integrator=integrator) + kernel = algorithm.build_kernel(integrator=integrator, L_proposal_factor=L_proposal_factor) + + # initial_L = jnp.clip(initial_L, min=initial_step_size+0.01) + def step(state, key): - adaptation_state, kernel_state = state - + (adaptation_state, kernel_state), L = state new_kernel_state, info = kernel( rng_key=key, state=kernel_state, logdensity_fn=logdensity_fn, step_size=jnp.exp(adaptation_state.log_step_size), inverse_mass_matrix=inverse_mass_matrix, - integration_steps_fn=integration_steps_fn, + integration_steps_fn=integration_steps_fn(L/jnp.exp(adaptation_state.log_step_size)), ) - + new_adaptation_state = da_update( adaptation_state, info.acceptance_rate, ) + return ( - (new_adaptation_state, new_kernel_state), + ((new_adaptation_state, new_kernel_state), L), info, ) @@ -68,10 +74,11 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): keys = jax.random.split(rng_key, num_steps) init_state = da_init(initial_step_size), init_kernel_state - (adaptation_state, kernel_state), info = jax.lax.scan( + ((adaptation_state, kernel_state), L), info = jax.lax.scan( step, - init_state, + (init_state, initial_L), keys, + ) step_size = da_final(adaptation_state) return ( @@ -79,6 +86,7 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): { "step_size": step_size, "inverse_mass_matrix": inverse_mass_matrix, + "L": L, }, info, ) @@ -92,11 +100,12 @@ def alba_adjusted( target_eevpd, v, adjusted_algorithm, - num_dimensions: int, integrator, target_acceptance_rate: float = 0.80, num_alba_steps: int = 500, alba_factor: float = 0.4, + preconditioning: bool = True, + L_proposal_factor=jnp.inf, **extra_parameters, ): @@ -107,31 +116,38 @@ def alba_adjusted( v=v, integrator=integrator, num_alba_steps=num_alba_steps, - alba_factor=alba_factor, **extra_parameters) + alba_factor=alba_factor, + preconditioning=preconditioning, + **extra_parameters) def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): unadjusted_warmup_key, adjusted_warmup_key = jax.random.split(rng_key) - (state, params), adaptation_info = unadjusted_warmup.run(unadjusted_warmup_key, position, num_steps) + num_unadjusted_steps = 20000 - avg_num_integration_steps = params["L"] / params["step_size"] + (state, params), adaptation_info = unadjusted_warmup.run(unadjusted_warmup_key, position, num_unadjusted_steps) - integration_steps_fn = lambda k: jnp.ceil( - jax.random.uniform(k) * rescale(avg_num_integration_steps) - ) + jax.debug.print("unadjusted params: {params}", params=(params["L"], params["step_size"])) + + integration_steps_fn = make_random_trajectory_length_fn(random_trajectory_length=True) adjusted_warmup = da_adaptation( algorithm=adjusted_algorithm, logdensity_fn=logdensity_fn, integration_steps_fn=integration_steps_fn, + initial_L=params["L"], initial_step_size=params["step_size"], target_acceptance_rate=target_acceptance_rate, inverse_mass_matrix=params["inverse_mass_matrix"], - integrator=integrator, **extra_parameters) + integrator=integrator, L_proposal_factor=L_proposal_factor, **extra_parameters) + + state, params, adaptation_info = adjusted_warmup.run(adjusted_warmup_key, state.position, num_steps) - params["L"] = adaptation_info.num_integration_steps.mean()*params["step_size"] + jax.debug.print("adjusted params: {params}", params=(params["L"], params["step_size"])) + # raise Exception("stop") + # return None return state, params, adaptation_info return AdaptationAlgorithm(run) diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index 4be09f9ec..e5c1c900f 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -403,6 +403,8 @@ def handle_nans( # new_momentum = euclidean*0.0 + (1-euclidean)*generate_unit_vector(key, previous_state.position) new_momentum = generate_unit_vector(key, previous_state.position) + # new_momentum = euclidean*metric.sample_momentum(key, previous_state.position) + (1-euclidean)*generate_unit_vector(key, previous_state.position) + state = jax.lax.cond( jnp.isnan(next_state.logdensity), lambda: state._replace( @@ -418,8 +420,7 @@ def handle_nans( def handle_high_energy( previous_state, next_state, energy_change, key, inverse_mass_matrix, cutoff, euclidean=False ): - """if there are nans, let's reduce the stepsize, and not update the state. The - function returns the old state in this case.""" + metric = metrics.default_metric(inverse_mass_matrix) diff --git a/blackjax/adaptation/unadjusted_alba.py b/blackjax/adaptation/unadjusted_alba.py index b782baf85..10e9746a8 100644 --- a/blackjax/adaptation/unadjusted_alba.py +++ b/blackjax/adaptation/unadjusted_alba.py @@ -29,9 +29,15 @@ def base( is_mass_matrix_diagonal: bool, v, target_eevpd, + preconditioning: bool = True, ) -> tuple[Callable, Callable, Callable]: mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal) + # if not preconditioning: + + # mm_update = lambda x, y: x + + # mm_final = lambda x: x # step_size_init, step_size_update, step_size_final = dual_averaging_adaptation(target_eevpd) step_size_init, step_size_update, step_size_final = robnik_step_size_tuning(desired_energy_var=target_eevpd) @@ -55,7 +61,7 @@ def init( def fast_update( position: ArrayLikeTree, - value: float, + info, warmup_state: AlbaAdaptationState, ) -> AlbaAdaptationState: """Update the adaptation state when in a "fast" window. @@ -69,38 +75,65 @@ def fast_update( del position - new_ss_state = step_size_update(warmup_state.ss_state, value) + new_ss_state = step_size_update(warmup_state.ss_state, info) new_step_size = new_ss_state.step_size # jnp.exp(new_ss_state.log_step_size) + new_inverse_mass_matrix = jax.lax.cond( + preconditioning, + lambda: warmup_state.inverse_mass_matrix, + lambda: jnp.ones_like(warmup_state.inverse_mass_matrix), + ) + return AlbaAdaptationState( new_ss_state, warmup_state.imm_state, new_step_size, - warmup_state.inverse_mass_matrix, + new_inverse_mass_matrix, L = warmup_state.L ) def slow_update( position: ArrayLikeTree, - value: float, + info, warmup_state: AlbaAdaptationState, ) -> AlbaAdaptationState: + + # raise Exception new_imm_state = mm_update(warmup_state.imm_state, position) - new_ss_state = step_size_update(warmup_state.ss_state, value) + # jax.debug.print("imm state {x}", x=new_imm_state.inverse_mass_matrix[:3]) + # jax.debug.print("warmup_state.ss_state: {x}", x=(warmup_state.ss_state.step_size)) + new_ss_state = step_size_update(warmup_state.ss_state, info) + # new_ss_state = warmup_state.ss_state + # jax.debug.print("old then new: {new_ss_state}", new_ss_state=(warmup_state.ss_state.step_size, new_ss_state.step_size)) new_step_size = new_ss_state.step_size # jnp.exp(new_ss_state.log_step_size) + new_inverse_mass_matrix = jax.lax.cond( + preconditioning, + lambda: warmup_state.inverse_mass_matrix, + lambda: jnp.ones_like(warmup_state.inverse_mass_matrix), + ) + return AlbaAdaptationState( - new_ss_state, new_imm_state, new_step_size, warmup_state.inverse_mass_matrix, L = warmup_state.L + new_ss_state, new_imm_state, new_step_size, new_inverse_mass_matrix, L = warmup_state.L ) def slow_final(warmup_state: AlbaAdaptationState) -> AlbaAdaptationState: new_imm_state = mm_final(warmup_state.imm_state) - new_ss_state = step_size_init(step_size_final(warmup_state.ss_state), warmup_state.ss_state.num_dimensions) + new_ss_state = warmup_state.ss_state + # ._replace(step_size=step_size_final(warmup_state.ss_state)) + # step_size_init(step_size_final(warmup_state.ss_state), warmup_state.ss_state.num_dimensions) new_step_size = new_ss_state.step_size # jnp.exp(new_ss_state.log_step_size) + # jax.debug.print("new_ss_state: {new_ss_state}", new_ss_state=(new_ss_state.step_size)) - new_L = jnp.sqrt(warmup_state.ss_state.num_dimensions)/v # + new_L = jax.lax.cond( + preconditioning, + lambda: jnp.sqrt(warmup_state.ss_state.num_dimensions)/v, + lambda: jnp.sqrt(jnp.sum(new_imm_state.inverse_mass_matrix)), + ) + + # jax.debug.print("new_L: {x}", x=(warmup_state.L, new_L)) return AlbaAdaptationState( new_ss_state, @@ -114,7 +147,7 @@ def update( adaptation_state: AlbaAdaptationState, adaptation_stage: tuple, position: ArrayLikeTree, - value: float, + info, ) -> AlbaAdaptationState: """Update the adaptation state and parameter values. @@ -141,7 +174,7 @@ def update( stage, (fast_update, slow_update), position, - value, + info, adaptation_state, ) @@ -156,7 +189,7 @@ def update( def final(warmup_state: AlbaAdaptationState) -> tuple[float, Array]: """Return the final values for the step size and mass matrix.""" - step_size = warmup_state.ss_state.step_size + step_size = step_size_final(warmup_state.ss_state) # step_size = jnp.exp(warmup_state.ss_state.log_step_size_avg) inverse_mass_matrix = warmup_state.imm_state.inverse_mass_matrix L = warmup_state.L @@ -169,6 +202,7 @@ def unadjusted_alba( logdensity_fn: Callable, target_eevpd, v, + preconditioning: bool = True, is_mass_matrix_diagonal: bool = True, progress_bar: bool = False, adaptation_info_fn: Callable = return_all_adapt_info, @@ -182,9 +216,10 @@ def unadjusted_alba( mcmc_kernel = algorithm.build_kernel(integrator) adapt_init, adapt_step, adapt_final = base( - is_mass_matrix_diagonal, + is_mass_matrix_diagonal=is_mass_matrix_diagonal, target_eevpd=target_eevpd, v=v, + preconditioning=preconditioning ) def one_step(carry, xs): @@ -204,8 +239,9 @@ def one_step(carry, xs): adaptation_state, adaptation_stage, new_state.position, - info.energy_change, + info, ) + # jax.debug.print("step sizes: {x}", x=(adaptation_state.step_size, new_adaptation_state.step_size)) return ( (new_state, new_adaptation_state), @@ -232,6 +268,8 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): last_chain_state, last_warmup_state, *_ = last_state step_size, L, inverse_mass_matrix = adapt_final(last_warmup_state) + jax.debug.print("unadjusted L before alba: {params}", params=(L, step_size)) + ### ### ALBA TUNING ### @@ -253,17 +291,25 @@ def step(state, key): _, samples = jax.lax.scan(step, last_chain_state, keys) flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) ess = effective_sample_size(flat_samples[None, ...]) + print(jnp.mean(ess), num_alba_steps, "\n\ness (blackjax internal)\n") + # print("L etc", L, step_size, jnp.mean(ess), num_alba_steps, jnp.mean(num_alba_steps / ess)) L=alba_factor * step_size * jnp.mean(num_alba_steps / ess) - + # print("new L", L) + # raise Exception("stop") + max_num_steps = 500 + + # jax.debug.print("L: {x}", x=step_size*50.) parameters = { "step_size": step_size, "inverse_mass_matrix": inverse_mass_matrix, - "L": L, + "L": jnp.clip(L, max=step_size*max_num_steps), **extra_parameters, } + # jax.debug.print("parameters {x}", x=parameters) + return ( AdaptationResults( last_chain_state, diff --git a/blackjax/adaptation/unadjusted_step_size.py b/blackjax/adaptation/unadjusted_step_size.py index 3306caa79..f13aa7533 100644 --- a/blackjax/adaptation/unadjusted_step_size.py +++ b/blackjax/adaptation/unadjusted_step_size.py @@ -1,6 +1,6 @@ import jax.numpy as jnp from typing import NamedTuple - +import jax class RobnikStepSizeTuningState(NamedTuple): time : jnp.ndarray @@ -9,14 +9,18 @@ class RobnikStepSizeTuningState(NamedTuple): step_size_max: float num_dimensions: int -def robnik_step_size_tuning(desired_energy_var, trust_in_estimate=1.5, num_effective_samples=150, step_size_max=jnp.inf): +def robnik_step_size_tuning(desired_energy_var, trust_in_estimate=1.5, num_effective_samples=150, step_size_max=jnp.inf, step_size_reduction_factor=0.8): decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) def init(initial_step_size, num_dimensions): return RobnikStepSizeTuningState(time=0.0, x_average=0.0, step_size=initial_step_size, step_size_max=step_size_max, num_dimensions=num_dimensions) - def update(robnik_state, energy_change): + def update(robnik_state, info): + + + energy_change = info.energy_change + xi = ( jnp.square(energy_change) / (robnik_state.num_dimensions * desired_energy_var) @@ -37,7 +41,25 @@ def update(robnik_state, energy_change): step_size > robnik_state.step_size_max ) * robnik_state.step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences - return RobnikStepSizeTuningState(time=time, x_average=x_average, step_size=step_size, step_size_max=step_size_max, num_dimensions=robnik_state.num_dimensions) + # old_robnik_state = robnik_state + + # old_robnik_step_size = robnik_state.step_size + # jax.debug.print("step_size: {x}", x=(old_robnik_step_size, old_robnik_step_size * step_size_reduction_factor, step_size)) + # jax.debug.print("stuff {x}", x=(x_average, time, robnik_state.time, time)) + old_robnik_state = robnik_state + + + + robnik_state = jax.lax.cond( + info.nonans, + lambda: RobnikStepSizeTuningState(time=time, x_average=x_average, step_size=step_size, step_size_max=step_size_max, num_dimensions=robnik_state.num_dimensions), + lambda: robnik_state._replace(step_size=robnik_state.step_size * step_size_reduction_factor), + ) + # jax.debug.print("robnik_state: {robnik_state}", robnik_state=(robnik_state.step_size, info.nonans, robnik_state.step_size * step_size_reduction_factor)) + # jax.debug.print("robnik_state: {x}", x=(old_robnik_state.step_size, robnik_state.step_size)) + return robnik_state + + # return RobnikStepSizeTuningState(time=time, x_average=x_average, step_size=step_size, step_size_max=step_size_max, num_dimensions=robnik_state.num_dimensions) def final(robnik_state): diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index dba78b0bf..18dd76cbb 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -45,6 +45,7 @@ class WindowAdaptationState(NamedTuple): def base( is_mass_matrix_diagonal: bool, target_acceptance_rate: float = 0.80, + preconditioning = True, ) -> tuple[Callable, Callable, Callable]: """Warmup scheme for sampling procedures based on euclidean manifold HMC. The schedule and algorithms used match Stan's :cite:p:`stan_hmc_param` as closely as possible. @@ -100,6 +101,9 @@ def base( """ mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal) + if not preconditioning: + mm_update = lambda x, y: x + mm_final = lambda x: x da_init, da_update, da_final = dual_averaging_adaptation(target_acceptance_rate) def init( @@ -252,6 +256,7 @@ def window_adaptation( progress_bar: bool = False, adaptation_info_fn: Callable = return_all_adapt_info, integrator=mcmc.integrators.velocity_verlet, + preconditioning = True, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of @@ -302,6 +307,7 @@ def window_adaptation( adapt_init, adapt_step, adapt_final = base( is_mass_matrix_diagonal, target_acceptance_rate=target_acceptance_rate, + preconditioning=preconditioning, ) def one_step(carry, xs): diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index f390402f2..94e02bf2d 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -1,242 +1,243 @@ -# Copyright 2020- The Blackjax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin". - -NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module, which is primarily intended for use in parallelized versions of the algorithm. - -""" -from typing import Callable, Union - -import jax -import jax.numpy as jnp - -import blackjax.mcmc.integrators as integrators -from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.hmc import HMCInfo, HMCState -from blackjax.mcmc.proposal import static_binomial_sampling -from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -from blackjax.util import generate_unit_vector - -__all__ = ["init", "build_kernel", "as_top_level_api"] - - -def init(position: ArrayLikeTree, logdensity_fn: Callable): - logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) - return HMCState(position, logdensity, logdensity_grad) - - -def build_kernel( - logdensity_fn: Callable, - integrator: Callable = integrators.isokinetic_mclachlan, - divergence_threshold: float = 1000, - inverse_mass_matrix=1.0, -): - """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. - - Parameters - ---------- - integrator - The integrator to use to integrate the Hamiltonian dynamics. - divergence_threshold - Value of the difference in energy above which we consider that the transition is divergent. - next_random_arg_fn - Function that generates the next `random_generator_arg` from its previous value. - integration_steps_fn - Function that generates the next pseudo or quasi-random number of integration steps in the - sequence, given the current `random_generator_arg`. Needs to return an `int`. - - Returns - ------- - A kernel that takes a rng_key and a Pytree that contains the current state - of the chain and that returns a new state of the chain along with - information about the transition. - """ - - def kernel( - rng_key: PRNGKey, - state: HMCState, - step_size: float, - num_integration_steps: int, - L_proposal_factor: float = jnp.inf, - ) -> tuple[HMCState, HMCInfo]: - """Generate a new sample with the MHMCHMC kernel.""" - - key_momentum, key_integrator = jax.random.split(rng_key, 2) - momentum = generate_unit_vector(key_momentum, state.position) - proposal, info, _ = adjusted_mclmc_proposal( - integrator=integrators.with_isokinetic_maruyama( - integrator( - logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix - ) - ), - step_size=step_size, - L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), - num_integration_steps=num_integration_steps, - divergence_threshold=divergence_threshold, - )( - key_integrator, - integrators.IntegratorState( - state.position, momentum, state.logdensity, state.logdensity_grad - ), - ) - - return ( - HMCState( - proposal.position, - proposal.logdensity, - proposal.logdensity_grad, - ), - info, - ) - - return kernel - - -def as_top_level_api( - logdensity_fn: Callable, - step_size: float, - L_proposal_factor: float = jnp.inf, - inverse_mass_matrix=1.0, - *, - divergence_threshold: int = 1000, - integrator: Callable = integrators.isokinetic_mclachlan, - num_integration_steps, -) -> SamplingAlgorithm: - """Implements the (basic) user interface for the MHMCHMC kernel. - - Parameters - ---------- - logdensity_fn - The log-density function we wish to draw samples from. - step_size - The value to use for the step size in the symplectic integrator. - divergence_threshold - The absolute value of the difference in energy between two states above - which we say that the transition is divergent. The default value is - commonly found in other libraries, and yet is arbitrary. - integrator - (algorithm parameter) The symplectic integrator to use to integrate the trajectory. - next_random_arg_fn - Function that generates the next `random_generator_arg` from its previous value. - integration_steps_fn - Function that generates the next pseudo or quasi-random number of integration steps in the - sequence, given the current `random_generator_arg`. - - - Returns - ------- - A ``SamplingAlgorithm``. - """ - - kernel = build_kernel( - logdensity_fn=logdensity_fn, - integrator=integrator, - inverse_mass_matrix=inverse_mass_matrix, - divergence_threshold=divergence_threshold, - ) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return init(position, logdensity_fn) - - def update_fn(rng_key: PRNGKey, state): - return kernel( - rng_key=rng_key, - state=state, - step_size=step_size, - num_integration_steps=num_integration_steps, - L_proposal_factor=L_proposal_factor, - ) - - return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] - - -def adjusted_mclmc_proposal( - integrator: Callable, - step_size: Union[float, ArrayLikeTree], - L_proposal_factor: float, - num_integration_steps: int = 1, - divergence_threshold: float = 1000, - *, - sample_proposal: Callable = static_binomial_sampling, -) -> Callable: - """Vanilla MHMCHMC algorithm. - - The algorithm integrates the trajectory applying a integrator - `num_integration_steps` times in one direction to get a proposal and uses a - Metropolis-Hastings acceptance step to either reject or accept this - proposal. This is what people usually refer to when they talk about "the - HMC algorithm". - - Parameters - ---------- - integrator - integrator used to build the trajectory step by step. - kinetic_energy - Function that computes the kinetic energy. - step_size - Size of the integration step. - num_integration_steps - Number of times we run the integrator to build the trajectory - divergence_threshold - Threshold above which we say that there is a divergence. - - Returns - ------- - A kernel that generates a new chain state and information about the transition. - - """ - - def step(i, vars): - state, kinetic_energy, rng_key = vars - rng_key, next_rng_key = jax.random.split(rng_key) - next_state, next_kinetic_energy = integrator( - state, step_size, L_proposal_factor, rng_key - ) - - return next_state, kinetic_energy + next_kinetic_energy, next_rng_key - - def build_trajectory(state, num_integration_steps, rng_key): - return jax.lax.fori_loop( - 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) - ) - - def generate( - rng_key, state: integrators.IntegratorState - ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: - """Generate a new chain state.""" - end_state, kinetic_energy, rng_key = build_trajectory( - state, num_integration_steps, rng_key - ) - - new_energy = -end_state.logdensity - delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy - delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) - is_diverging = -delta_energy > divergence_threshold - sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) - do_accept, p_accept, other_proposal_info = info - - info = HMCInfo( - state.momentum, - p_accept, - do_accept, - is_diverging, - new_energy, - end_state, - num_integration_steps, - ) - - return sampled_state, info, other_proposal_info - - return generate +# # Copyright 2020- The Blackjax Authors. +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. +# """Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin". + +# NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module, which is primarily intended for use in parallelized versions of the algorithm. + +# """ +# from typing import Callable, Union + +# import jax +# import jax.numpy as jnp + +# import blackjax.mcmc.integrators as integrators +# from blackjax.base import SamplingAlgorithm +# from blackjax.mcmc.hmc import HMCInfo, HMCState +# from blackjax.mcmc.proposal import static_binomial_sampling +# from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey +# from blackjax.util import generate_unit_vector + +# __all__ = ["init", "build_kernel", "as_top_level_api"] + + +# def init(position: ArrayLikeTree, logdensity_fn: Callable): +# logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) +# return HMCState(position, logdensity, logdensity_grad) + + +# def build_kernel( +# logdensity_fn: Callable, +# integrator: Callable = integrators.isokinetic_mclachlan, +# divergence_threshold: float = 1000, +# inverse_mass_matrix=1.0, +# ): +# """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. + +# Parameters +# ---------- +# integrator +# The integrator to use to integrate the Hamiltonian dynamics. +# divergence_threshold +# Value of the difference in energy above which we consider that the transition is divergent. +# next_random_arg_fn +# Function that generates the next `random_generator_arg` from its previous value. +# integration_steps_fn +# Function that generates the next pseudo or quasi-random number of integration steps in the +# sequence, given the current `random_generator_arg`. Needs to return an `int`. + +# Returns +# ------- +# A kernel that takes a rng_key and a Pytree that contains the current state +# of the chain and that returns a new state of the chain along with +# information about the transition. +# """ + +# def kernel( +# rng_key: PRNGKey, +# state: HMCState, +# step_size: float, +# num_integration_steps: int, +# L_proposal_factor: float = jnp.inf, +# ) -> tuple[HMCState, HMCInfo]: +# """Generate a new sample with the MHMCHMC kernel.""" + +# key_momentum, key_integrator = jax.random.split(rng_key, 2) +# momentum = generate_unit_vector(key_momentum, state.position) +# proposal, info, _ = adjusted_mclmc_proposal( +# integrator=integrators.with_isokinetic_maruyama( +# integrator( +# logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix +# ) +# ), +# step_size=step_size, +# L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), +# num_integration_steps=num_integration_steps, +# divergence_threshold=divergence_threshold, +# )( +# key_integrator, +# integrators.IntegratorState( +# state.position, momentum, state.logdensity, state.logdensity_grad +# ), +# ) + +# new_state = HMCState( +# proposal.position, +# proposal.logdensity, +# proposal.logdensity_grad) + +# new_state, info = handle_nans(state, new_state, info, nan_key) + + + +# return kernel + + +# def as_top_level_api( +# logdensity_fn: Callable, +# step_size: float, +# L_proposal_factor: float = jnp.inf, +# inverse_mass_matrix=1.0, +# *, +# divergence_threshold: int = 1000, +# integrator: Callable = integrators.isokinetic_mclachlan, +# num_integration_steps, +# ) -> SamplingAlgorithm: +# """Implements the (basic) user interface for the MHMCHMC kernel. + +# Parameters +# ---------- +# logdensity_fn +# The log-density function we wish to draw samples from. +# step_size +# The value to use for the step size in the symplectic integrator. +# divergence_threshold +# The absolute value of the difference in energy between two states above +# which we say that the transition is divergent. The default value is +# commonly found in other libraries, and yet is arbitrary. +# integrator +# (algorithm parameter) The symplectic integrator to use to integrate the trajectory. +# next_random_arg_fn +# Function that generates the next `random_generator_arg` from its previous value. +# integration_steps_fn +# Function that generates the next pseudo or quasi-random number of integration steps in the +# sequence, given the current `random_generator_arg`. + + +# Returns +# ------- +# A ``SamplingAlgorithm``. +# """ + +# kernel = build_kernel( +# logdensity_fn=logdensity_fn, +# integrator=integrator, +# inverse_mass_matrix=inverse_mass_matrix, +# divergence_threshold=divergence_threshold, +# ) + +# def init_fn(position: ArrayLikeTree, rng_key=None): +# del rng_key +# return init(position, logdensity_fn) + +# def update_fn(rng_key: PRNGKey, state): +# return kernel( +# rng_key=rng_key, +# state=state, +# step_size=step_size, +# num_integration_steps=num_integration_steps, +# L_proposal_factor=L_proposal_factor, +# ) + +# return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +# def adjusted_mclmc_proposal( +# integrator: Callable, +# step_size: Union[float, ArrayLikeTree], +# L_proposal_factor: float, +# num_integration_steps: int = 1, +# divergence_threshold: float = 1000, +# *, +# sample_proposal: Callable = static_binomial_sampling, +# ) -> Callable: +# """Vanilla MHMCHMC algorithm. + +# The algorithm integrates the trajectory applying a integrator +# `num_integration_steps` times in one direction to get a proposal and uses a +# Metropolis-Hastings acceptance step to either reject or accept this +# proposal. This is what people usually refer to when they talk about "the +# HMC algorithm". + +# Parameters +# ---------- +# integrator +# integrator used to build the trajectory step by step. +# kinetic_energy +# Function that computes the kinetic energy. +# step_size +# Size of the integration step. +# num_integration_steps +# Number of times we run the integrator to build the trajectory +# divergence_threshold +# Threshold above which we say that there is a divergence. + +# Returns +# ------- +# A kernel that generates a new chain state and information about the transition. + +# """ + +# def step(i, vars): +# state, kinetic_energy, rng_key = vars +# rng_key, next_rng_key = jax.random.split(rng_key) +# next_state, next_kinetic_energy = integrator( +# state, step_size, L_proposal_factor, rng_key +# ) + +# return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + +# def build_trajectory(state, num_integration_steps, rng_key): +# return jax.lax.fori_loop( +# 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) +# ) + +# def generate( +# rng_key, state: integrators.IntegratorState +# ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: +# """Generate a new chain state.""" +# end_state, kinetic_energy, rng_key = build_trajectory( +# state, num_integration_steps, rng_key +# ) + +# new_energy = -end_state.logdensity +# delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy +# delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) +# is_diverging = -delta_energy > divergence_threshold +# sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) +# do_accept, p_accept, other_proposal_info = info + +# info = HMCInfo( +# state.momentum, +# p_accept, +# do_accept, +# is_diverging, +# new_energy, +# end_state, +# num_integration_steps, +# nonans=True +# ) + +# return sampled_state, info, other_proposal_info + +# return generate diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index 844a723a5..aaf171a44 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -37,6 +37,7 @@ def build_kernel( integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + L_proposal_factor: float = jnp.inf, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. @@ -66,7 +67,6 @@ def kernel( step_size: float, integration_steps_fn, inverse_mass_matrix=1.0, - L_proposal_factor: float = jnp.inf, ) -> tuple[DynamicHMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" @@ -145,6 +145,7 @@ def as_top_level_api( integrator=integrator, next_random_arg_fn=next_random_arg_fn, divergence_threshold=divergence_threshold, + L_proposal_factor=L_proposal_factor, ) def init_fn(position: ArrayLikeTree, rng_key: Array): @@ -158,7 +159,6 @@ def update_fn(rng_key: PRNGKey, state): step_size=step_size, integration_steps_fn=integration_steps_fn, inverse_mass_matrix=inverse_mass_matrix, - L_proposal_factor=L_proposal_factor, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] @@ -257,3 +257,14 @@ def rescale(mu): def trajectory_length(t, mu): s = rescale(mu) return jnp.rint(0.5 + halton_sequence(t) * s) + +def make_random_trajectory_length_fn(random_trajectory_length : bool): + if random_trajectory_length: + integration_steps_fn = lambda avg_num_integration_steps: lambda k: (jnp.clip(jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), min=1)).astype('int32') + else: + integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.clip(jnp.ceil( + avg_num_integration_steps + ), min=1).astype('int32') + return integration_steps_fn \ No newline at end of file diff --git a/blackjax/mcmc/dynamic_malt.py b/blackjax/mcmc/dynamic_malt.py index e048a8406..cd2fbd0de 100644 --- a/blackjax/mcmc/dynamic_malt.py +++ b/blackjax/mcmc/dynamic_malt.py @@ -85,6 +85,7 @@ def kernel( num_integration_steps = integration_steps_fn( state.random_generator_arg, **integration_steps_kwargs ).astype(int) + hmc_state = HMCState(state.position, state.logdensity, state.logdensity_grad) hmc_proposal, info = hmc_base( rng_key, @@ -95,9 +96,6 @@ def kernel( num_integration_steps, ) - # jax.debug.print("logdensity {x}", x=hmc_proposal.logdensity) - # jax.debug.print("acceptance {x}", x=info) - next_random_arg = next_random_arg_fn(state.random_generator_arg) return ( DynamicHMCState( diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 4023c6245..fdbcaa0f8 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -473,7 +473,7 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L, inverse_mass_mat # dim = m.shape[0] c1 = jnp.exp(-step_size/L) c2 = jnp.sqrt((1-c1**2)) - z = normal(rng_key, shape=m.shape, dtype=m.dtype) + # z = normal(rng_key, shape=m.shape, dtype=m.dtype) metric = default_metric(inverse_mass_matrix) z = metric.sample_momentum(rng_key, m) # normal(rng_key, shape=m.shape, dtype=m.dtype) @@ -531,7 +531,6 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): ) ) # jax.debug.print("state 1.5 {x}",x=state) - # state = init_state # TODO: add noise back! # one step of the deterministic dynamics new_state = integrator(state, step_size) # jax.debug.print("state 2 {x}",x=state) diff --git a/blackjax/mcmc/malt.py b/blackjax/mcmc/malt.py index 572264589..90b007e9d 100644 --- a/blackjax/mcmc/malt.py +++ b/blackjax/mcmc/malt.py @@ -77,8 +77,6 @@ def kernel( L = num_integration_steps * step_size L_proposal = L_proposal_factor * L - # jax.debug.print("L_proposal {x}",x=L_proposal) - key_trajectory, key_momentum, key_integrator = jax.random.split(rng_key, 3) metric = metrics.default_metric(inverse_mass_matrix) symplectic_integrator = lambda state, step_size, rng_key: integrators.with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy, inverse_mass_matrix)(state, step_size, L_proposal=L_proposal, rng_key=rng_key) @@ -94,8 +92,6 @@ def kernel( position, logdensity, logdensity_grad = state momentum = metric.sample_momentum(key_momentum, position) - # import jax.numpy as jnp - # jax.debug.print("momentum nan? {x}",x=jnp.any(jnp.isnan(momentum))) integrator_state = integrators.IntegratorState( position, momentum, logdensity, logdensity_grad @@ -265,11 +261,11 @@ def generate( delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) # jax.debug.print("delta_energy_trajectory {x}",x=(delta_energy_trajectory, delta_energy)) - # jax.debug.print("delta_energy {x}",x=delta_energy) is_diverging = -delta_energy > divergence_threshold # jax.debug.print("is_diverging {x}",x=is_diverging) sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) do_accept, p_accept, other_proposal_info = info + # jax.debug.print("delta_energy {x}",x=delta_energy) # jax.debug.print("p_accept {p_accept}", p_accept=(p_accept, delta_energy)) info = HMCInfo( diff --git a/blackjax/mcmc/mchmc.py b/blackjax/mcmc/mchmc.py index 105eea6ae..06bda7097 100644 --- a/blackjax/mcmc/mchmc.py +++ b/blackjax/mcmc/mchmc.py @@ -21,12 +21,11 @@ from blackjax.mcmc.integrators import ( IntegratorState, isokinetic_mclachlan, - with_isokinetic_maruyama, ) from blackjax.types import ArrayLike, PRNGKey from blackjax.util import generate_unit_vector, pytree_size -from blackjax.mcmc.adjusted_mclmc_dynamic import rescale - +from blackjax.mcmc.mclmc import handle_high_energy, handle_nans, MCLMCInfo +from blackjax.mcmc.adjusted_mclmc_dynamic import make_random_trajectory_length_fn __all__ = ["MCLMCInfo", "init", "build_kernel", "as_top_level_api"] class MCHMCState(NamedTuple): @@ -36,21 +35,6 @@ class MCHMCState(NamedTuple): logdensity_grad: ArrayLike steps_until_refresh: int -class MCLMCInfo(NamedTuple): - """ - Additional information on the MCLMC transition. - - logdensity - The log-density of the distribution at the current step of the MCLMC chain. - kinetic_change - The difference in kinetic energy between the current and previous step. - energy_change - The difference in energy between the current and previous step. - """ - - logdensity: float - kinetic_change: float - energy_change: float def init(position: ArrayLike, logdensity_fn, random_generator_arg): @@ -79,8 +63,6 @@ def integrator_state(state: MCHMCState) -> IntegratorState: def build_kernel( # integration_steps_fn, - logdensity_fn, - inverse_mass_matrix, integrator, desired_energy_var_max_ratio=jnp.inf, desired_energy_var=5e-4, @@ -88,60 +70,41 @@ def build_kernel( """ """ - step = integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) def kernel( - rng_key: PRNGKey, state: MCHMCState, L: float, step_size: float + rng_key: PRNGKey, state: MCHMCState, logdensity_fn, L: float, step_size: float, inverse_mass_matrix: ArrayLike ) -> tuple[MCHMCState, MCLMCInfo]: + step = integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) (position, momentum, logdensity, logdensitygrad), kinetic_change = step( integrator_state(state), step_size ) - # num_integration_steps = integration_steps_fn(state.random_generator_arg) - jitter_key, refresh_key = jax.random.split(rng_key) + randomization_key, refresh_key, energy_cutoff_key, nan_key = jax.random.split(rng_key, 4) - num_steps_per_traj = jnp.ceil(L/step_size).astype(int) - - - num_steps_per_traj = jnp.ceil( - jax.random.uniform(jitter_key) * rescale(num_steps_per_traj) - ).astype(int) + num_steps_per_traj = make_random_trajectory_length_fn(True)(L/step_size)(randomization_key).astype(jnp.int64) - energy_error = kinetic_change - logdensity + state.logdensity + energy_change = kinetic_change - logdensity + state.logdensity eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var ndims = pytree_size(position) momentum=(state.steps_until_refresh==0) * generate_unit_vector(refresh_key, state.position) + (state.steps_until_refresh>0) * momentum - # new_state = new_state._replace(momentum=generate_unit_vector(refresh_key, new_state.position)) steps_until_refresh = (state.steps_until_refresh==0) * num_steps_per_traj + (state.steps_until_refresh>0) * (state.steps_until_refresh - 1) - # jax.debug.print("steps_until_refresh: {x}", x=steps_until_refresh) - - new_state, new_info = jax.lax.cond( - energy_error > jnp.sqrt(ndims * eev_max_per_dim), - lambda: ( - state, - MCLMCInfo( - logdensity=state.logdensity, - energy_change=0.0, - kinetic_change=0.0, - ), - ), - lambda: ( - MCHMCState(position, momentum, logdensity, logdensitygrad, steps_until_refresh), - MCLMCInfo( - logdensity=logdensity, - energy_change=energy_error, - kinetic_change=kinetic_change, - ), - ), - ) - return new_state, new_info + new_state, info = handle_high_energy(state, MCHMCState(position, momentum, logdensity, logdensitygrad, steps_until_refresh), MCLMCInfo( + logdensity=logdensity, + energy_change=energy_change, + kinetic_change=kinetic_change, + nonans=True + ), energy_cutoff_key, cutoff = jnp.sqrt(ndims * eev_max_per_dim)) + + new_state, info = handle_nans(state, new_state, info, nan_key) + + return new_state, info return kernel @@ -158,8 +121,6 @@ def as_top_level_api( """ kernel = build_kernel( - logdensity_fn, - inverse_mass_matrix, integrator, desired_energy_var_max_ratio=desired_energy_var_max_ratio, ) @@ -168,6 +129,6 @@ def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) def update_fn(rng_key, state): - return kernel(rng_key, state, L, step_size) + return kernel(rng_key, state, logdensity_fn, L, step_size, inverse_mass_matrix) return SamplingAlgorithm(init_fn, update_fn) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 959ee8fba..ca245fe52 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -25,6 +25,7 @@ ) from blackjax.types import ArrayLike, PRNGKey from blackjax.util import generate_unit_vector, pytree_size +from jax.flatten_util import ravel_pytree __all__ = ["MCLMCInfo", "init", "build_kernel", "as_top_level_api"] @@ -44,6 +45,7 @@ class MCLMCInfo(NamedTuple): logdensity: float kinetic_change: float energy_change: float + nonans: bool def init(position: ArrayLike, logdensity_fn, random_generator_arg): @@ -94,38 +96,28 @@ def kernel( step = with_isokinetic_maruyama( integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) ) + + kernel_key, energy_cutoff_key, nan_key = jax.random.split(rng_key, 3) (position, momentum, logdensity, logdensitygrad), kinetic_change = step( - state, step_size, L, rng_key + state, step_size, L, kernel_key ) - energy_error = kinetic_change - logdensity + state.logdensity - + energy_change = kinetic_change - logdensity + state.logdensity eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var ndims = pytree_size(position) - new_state, new_info = jax.lax.cond( - energy_error > jnp.sqrt(ndims * eev_max_per_dim), - lambda: ( - state, - MCLMCInfo( - logdensity=state.logdensity, - energy_change=0.0, - kinetic_change=0.0, - ), - ), - lambda: ( - IntegratorState(position, momentum, logdensity, logdensitygrad), - MCLMCInfo( - logdensity=logdensity, - energy_change=energy_error, - kinetic_change=kinetic_change, - ), - ), - ) + new_state, info = handle_high_energy(state, IntegratorState(position, momentum, logdensity, logdensitygrad), MCLMCInfo( + logdensity=logdensity, + energy_change=energy_change, + kinetic_change=kinetic_change, + nonans=True + ), energy_cutoff_key, cutoff = jnp.sqrt(ndims * eev_max_per_dim)) + + new_state, info = handle_nans(state, new_state, info, nan_key) - return new_state, new_info + return new_state, info return kernel @@ -188,8 +180,6 @@ def as_top_level_api( integrator, desired_energy_var_max_ratio=desired_energy_var_max_ratio, - # logdensity_fn=logdensity_fn, - # inverse_mass_matrix=inverse_mass_matrix, ) def init_fn(position: ArrayLike, rng_key: PRNGKey): @@ -199,3 +189,47 @@ def update_fn(rng_key, state): return kernel(rng_key, state, logdensity_fn, L, step_size, inverse_mass_matrix) return SamplingAlgorithm(init_fn, update_fn) + + +def handle_nans( + previous_state, next_state, info, key +): + + new_momentum = generate_unit_vector(key, previous_state.position) + + nonans = jnp.logical_and(jnp.all(jnp.isfinite(next_state.position)), jnp.all(jnp.isfinite(next_state.momentum))) + + state, info = jax.lax.cond( + nonans, + lambda: (next_state, info), + lambda: (previous_state._replace( + momentum=new_momentum, + ), MCLMCInfo( + logdensity=previous_state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + nonans=nonans + )), + ) + + return state, info +def handle_high_energy( + previous_state, next_state, info, key, cutoff +): + + new_momentum = generate_unit_vector(key, previous_state.position) + + state, info = jax.lax.cond( + jnp.abs(info.energy_change) > cutoff, + lambda: (previous_state._replace( + momentum=new_momentum, + ), MCLMCInfo( + logdensity=previous_state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + nonans=info.nonans + )), + lambda: (next_state, info), + ) + + return state, info \ No newline at end of file diff --git a/blackjax/mcmc/pseudofermion.py b/blackjax/mcmc/pseudofermion.py index 5a9b00f9e..77f44ca4a 100644 --- a/blackjax/mcmc/pseudofermion.py +++ b/blackjax/mcmc/pseudofermion.py @@ -145,7 +145,7 @@ def step_fn(rng_key: PRNGKey, state): fermion_matrix=new_fermion_matrix, count=state.count + info.num_integration_steps, ) - jax.debug.print("count {x}", x=full_state.count) + # jax.debug.print("count {x}", x=full_state.count) return full_state, info return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 14851c03c..25260e777 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -141,10 +141,13 @@ def integrate( lambda step_size: direction * step_size, step_size ) + # jax.debug.print("num_integration_steps {x}", x=num_integration_steps) + def one_step(_, accum): state, delta_energy, rng_key = accum rng_key, key2 = jax.random.split(rng_key) new_state, (_, delta_energy_new) = integrator(state, directed_step_size, rng_key) + # jax.debug.print("delta_energy 0 {x}", x=delta_energy_new) return (new_state, delta_energy+delta_energy_new, key2) state, delta_energy, _ = jax.lax.fori_loop(0, num_integration_steps, one_step, (initial_state, 0.0, rng_key)) diff --git a/blackjax/mcmc/uhmc.py b/blackjax/mcmc/uhmc.py index a912e7a7a..4e6f6daea 100644 --- a/blackjax/mcmc/uhmc.py +++ b/blackjax/mcmc/uhmc.py @@ -15,18 +15,17 @@ from typing import Callable, NamedTuple import jax - +from blackjax.mcmc.adjusted_mclmc_dynamic import make_random_trajectory_length_fn from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import ( IntegratorState, - with_maruyama, velocity_verlet, ) from blackjax.types import ArrayLike, PRNGKey import blackjax.mcmc.metrics as metrics import jax.numpy as jnp from blackjax.util import pytree_size, generate_unit_vector -from blackjax.adaptation.mclmc_adaptation import handle_high_energy +from blackjax.mcmc.underdamped_langevin import handle_high_energy, handle_nans, LangevinInfo __all__ = ["LangevinInfo", "init", "build_kernel", "as_top_level_api"] class UHMCState(NamedTuple): @@ -44,22 +43,6 @@ def integrator_state(state: UHMCState) -> IntegratorState: logdensity_grad=state.logdensity_grad, ) -class LangevinInfo(NamedTuple): - """ - Additional information on the Langevin transition. - - logdensity - The log-density of the distribution at the current step of the Langevin chain. - kinetic_change - The difference in kinetic energy between the current and previous step. - energy_change - The difference in energy between the current and previous step. - """ - - logdensity: float - kinetic_change: float - energy_change: float - def init(position: ArrayLike, logdensity_fn, random_generator_arg): @@ -77,10 +60,8 @@ def init(position: ArrayLike, logdensity_fn, random_generator_arg): def build_kernel( - logdensity_fn, - inverse_mass_matrix, integrator, - desired_energy_var_max_ratio=jnp.inf, + desired_energy_var_max_ratio=1e3, desired_energy_var=5e-4,): """Build a HMC kernel. @@ -101,45 +82,43 @@ def build_kernel( """ - metric = metrics.default_metric(inverse_mass_matrix) - step = with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy,inverse_mass_matrix) def kernel( - rng_key: PRNGKey, state: UHMCState, L: float, step_size: float + rng_key: PRNGKey, state: UHMCState, logdensity_fn, L: float, step_size: float, inverse_mass_matrix, ) -> tuple[UHMCState, LangevinInfo]: + metric = metrics.default_metric(inverse_mass_matrix) + step = integrator(logdensity_fn, metric.kinetic_energy) - refresh_key, energy_key, run_key = jax.random.split(rng_key, 3) + refresh_key, energy_cutoff_key, nan_key, randomization_key = jax.random.split(rng_key, 4) - (position, momentum, logdensity, logdensitygrad), (kinetic_change, energy_error) = step( - integrator_state(state), step_size, jnp.inf, run_key + (position, momentum, logdensity, logdensitygrad) = step( + integrator_state(state), step_size ) + + kinetic_change = - metric.kinetic_energy(state.momentum) + metric.kinetic_energy(momentum) + energy_change = kinetic_change - logdensity + state.logdensity + num_steps_per_traj = make_random_trajectory_length_fn(True)(L/step_size)(randomization_key).astype(jnp.int64) - num_steps_per_traj = jnp.ceil(L/step_size).astype(jnp.int64) momentum = (state.steps_until_refresh==0) * metric.sample_momentum(refresh_key, position) + (state.steps_until_refresh>0) * momentum + steps_until_refresh = (state.steps_until_refresh==0) * num_steps_per_traj + (state.steps_until_refresh>0) * (state.steps_until_refresh - 1) eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var ndims = pytree_size(position) - energy, new_integrator_state = handle_high_energy( - previous_state=integrator_state(state), - next_state=IntegratorState(position, momentum, logdensity, logdensitygrad), - energy_change=energy_error, - key=energy_key, - inverse_mass_matrix=inverse_mass_matrix, - cutoff=jnp.sqrt(ndims * eev_max_per_dim), - euclidean=True - ) - return UHMCState(new_integrator_state.position, new_integrator_state.momentum, new_integrator_state.logdensity, new_integrator_state.logdensity_grad, steps_until_refresh), LangevinInfo( + new_state, info = handle_high_energy(state, UHMCState(position, momentum, logdensity, logdensitygrad, steps_until_refresh), LangevinInfo( logdensity=logdensity, - energy_change=energy, - kinetic_change=kinetic_change - ) - + energy_change=energy_change, + kinetic_change=kinetic_change, + nonans=True + ), energy_cutoff_key, cutoff = jnp.sqrt(ndims * eev_max_per_dim), inverse_mass_matrix=inverse_mass_matrix) + new_state, info = handle_nans(state, new_state, info, nan_key, inverse_mass_matrix) + return new_state, info + return kernel @@ -165,8 +144,6 @@ def as_top_level_api( """ kernel = build_kernel( - logdensity_fn, - inverse_mass_matrix, integrator, desired_energy_var_max_ratio=desired_energy_var_max_ratio, desired_energy_var=desired_energy_var, @@ -177,6 +154,8 @@ def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) def update_fn(rng_key, state): - return kernel(rng_key, state, L, step_size) + return kernel(rng_key, state, logdensity_fn, L, step_size, inverse_mass_matrix) return SamplingAlgorithm(init_fn, update_fn) + + diff --git a/blackjax/mcmc/underdamped_langevin.py b/blackjax/mcmc/underdamped_langevin.py index 1db5700e0..6bf01591a 100644 --- a/blackjax/mcmc/underdamped_langevin.py +++ b/blackjax/mcmc/underdamped_langevin.py @@ -45,7 +45,7 @@ class LangevinInfo(NamedTuple): logdensity: float kinetic_change: float energy_change: float - + nonans : bool def init(position: ArrayLike, logdensity_fn, random_generator_arg): @@ -63,7 +63,7 @@ def init(position: ArrayLike, logdensity_fn, random_generator_arg): def build_kernel( integrator, - desired_energy_var_max_ratio=jnp.inf, + desired_energy_var_max_ratio=1e3, desired_energy_var=5e-4,): """Build a HMC kernel. @@ -94,64 +94,30 @@ def kernel( (position, momentum, logdensity, logdensitygrad), (kinetic_change, energy_error) = step( state, step_size, L, rng_key ) - - # jax.debug.print("energy change {x}", x=energy_change) - - - # kinetic_change = - momentum@momentum/2 + state.momentum@state.momentum/2 - # return IntegratorState( - # position, momentum, logdensity, logdensitygrad - # ), LangevinInfo( - # logdensity=logdensity, - # energy_change=energy_change, - # kinetic_change=kinetic_change - # ) - eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var ndims = pytree_size(position) - # jax.debug.print("diagnostics {x}", x=(eev_max_per_dim, jnp.abs(energy_error), jnp.abs(energy_error) > jnp.sqrt(ndims * eev_max_per_dim))) - energy_key, rng_key = jax.random.split(rng_key) + energy_key, nan_key = jax.random.split(rng_key) - energy, new_state = handle_high_energy( + new_state, info = handle_high_energy( previous_state=state, next_state=IntegratorState(position, momentum, logdensity, logdensitygrad), - energy_change=energy_error, + info=LangevinInfo( + logdensity=logdensity, + energy_change=energy_error, + kinetic_change=kinetic_change, + nonans=True + ), key=energy_key, inverse_mass_matrix=inverse_mass_matrix, cutoff=jnp.sqrt(ndims * eev_max_per_dim), - euclidean=True ) - return new_state, LangevinInfo( - logdensity=new_state.logdensity, - energy_change=energy, - kinetic_change=kinetic_change - ) + new_state, info = handle_nans(state, new_state, info, nan_key, inverse_mass_matrix) + return new_state, info - # new_state, new_info = jax.lax.cond( - # jnp.abs(energy_error) > jnp.sqrt(ndims * eev_max_per_dim), - # lambda: ( - # state, - # LangevinInfo( - # logdensity=state.logdensity, - # energy_change=0.0, - # kinetic_change=0.0, - # ), - # ), - # lambda: ( - # IntegratorState(position, momentum, logdensity, logdensitygrad), - # LangevinInfo( - # logdensity=logdensity, - # energy_change=energy_error, - # kinetic_change=kinetic_change, - # ), - # ), - # ) - - # return new_state, new_info return kernel @@ -179,7 +145,6 @@ def as_top_level_api( desired_energy_var_max_ratio=desired_energy_var_max_ratio, desired_energy_var=desired_energy_var, ) - # metric = metrics.default_metric(inverse_mass_matrix) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) @@ -189,3 +154,50 @@ def update_fn(rng_key, state): rng_key=rng_key, state=state, logdensity_fn=logdensity_fn, L=L, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix) return SamplingAlgorithm(init_fn, update_fn) + + +def handle_nans( + previous_state, next_state, info, key, inverse_mass_matrix +): + + metric = metrics.default_metric(inverse_mass_matrix) + new_momentum = metric.sample_momentum(key, previous_state.position) + + nonans = jnp.logical_and(jnp.all(jnp.isfinite(next_state.position)), jnp.all(jnp.isfinite(next_state.momentum))) + + state, info = jax.lax.cond( + nonans, + lambda: (next_state, info), + lambda: (previous_state._replace( + momentum=new_momentum, + ), LangevinInfo( + logdensity=previous_state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + nonans=nonans + )), + ) + + return state, info + +def handle_high_energy( + previous_state, next_state, info, key, cutoff, inverse_mass_matrix +): + + metric = metrics.default_metric(inverse_mass_matrix) + new_momentum = metric.sample_momentum(key, previous_state.position) + + state, info = jax.lax.cond( + jnp.abs(info.energy_change) > cutoff, + lambda: (previous_state._replace( + momentum=new_momentum, + ), LangevinInfo( + logdensity=previous_state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + nonans=False + )), + lambda: (next_state, info), + ) + + return state, info \ No newline at end of file diff --git a/blackjax/util.py b/blackjax/util.py index 32025dbba..bfdab8730 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -606,3 +606,5 @@ def F(x, keys): return parallel_execute(X, keys) + + From 3ee1c141a8d786cc93b8d78957f3267e566b922b Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Wed, 27 Aug 2025 08:25:11 -0700 Subject: [PATCH 55/63] fix laps --- blackjax/adaptation/ensemble_mclmc.py | 10 +- blackjax/adaptation/ensemble_umclmc.py | 4 +- blackjax/mcmc/adjusted_mclmc.py | 515 +++++++++++++----------- blackjax/mcmc/adjusted_mclmc_dynamic.py | 1 + 4 files changed, 281 insertions(+), 249 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 1b71a8d95..58536ba38 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -50,15 +50,16 @@ class AdaptationState(NamedTuple): build_kernel = lambda logdensity_fn, integrator, inverse_mass_matrix: lambda key, state, adap: build_kernel_malt( - logdensity_fn=logdensity_fn, + # logdensity_fn=logdensity_fn, integrator=integrator, - inverse_mass_matrix=inverse_mass_matrix, + L_proposal_factor=1.25, )( rng_key=key, state=state, + logdensity_fn=logdensity_fn, step_size=adap.step_size, - num_integration_steps=adap.steps_per_sample, - L_proposal_factor=1.25, + integration_steps_fn=lambda k:adap.steps_per_sample, + inverse_mass_matrix=inverse_mass_matrix, ) @@ -277,6 +278,7 @@ def laps( initial_state = HMCState( final_state.position, final_state.logdensity, final_state.logdensity_grad + # jax.random.key(0) ) num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample) num_adaptation_samples = ( diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 5b8be5b28..6d69c7124 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -50,8 +50,8 @@ def build_kernel(logdensity_fn): def sequential_kernel(key, state, adap): new_state, info = mclmc.build_kernel( - logdensity_fn=logdensity_fn, integrator=isokinetic_velocity_verlet, inverse_mass_matrix= jnp.ones(adap.inverse_mass_matrix.shape) - )(key, state, adap.L, adap.step_size) + integrator=isokinetic_velocity_verlet, + )(key, state,logdensity_fn, adap.L, adap.step_size,jnp.ones(adap.inverse_mass_matrix.shape)) # reject the new state if there were nans nonans = no_nans(new_state) diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index 94e02bf2d..747bedc4c 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -1,243 +1,272 @@ -# # Copyright 2020- The Blackjax Authors. -# # -# # Licensed under the Apache License, Version 2.0 (the "License"); -# # you may not use this file except in compliance with the License. -# # You may obtain a copy of the License at -# # -# # http://www.apache.org/licenses/LICENSE-2.0 -# # -# # Unless required by applicable law or agreed to in writing, software -# # distributed under the License is distributed on an "AS IS" BASIS, -# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# # See the License for the specific language governing permissions and -# # limitations under the License. -# """Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin". - -# NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module, which is primarily intended for use in parallelized versions of the algorithm. - -# """ -# from typing import Callable, Union - -# import jax -# import jax.numpy as jnp - -# import blackjax.mcmc.integrators as integrators -# from blackjax.base import SamplingAlgorithm -# from blackjax.mcmc.hmc import HMCInfo, HMCState -# from blackjax.mcmc.proposal import static_binomial_sampling -# from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -# from blackjax.util import generate_unit_vector - -# __all__ = ["init", "build_kernel", "as_top_level_api"] - - -# def init(position: ArrayLikeTree, logdensity_fn: Callable): -# logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) -# return HMCState(position, logdensity, logdensity_grad) - - -# def build_kernel( -# logdensity_fn: Callable, -# integrator: Callable = integrators.isokinetic_mclachlan, -# divergence_threshold: float = 1000, -# inverse_mass_matrix=1.0, -# ): -# """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. - -# Parameters -# ---------- -# integrator -# The integrator to use to integrate the Hamiltonian dynamics. -# divergence_threshold -# Value of the difference in energy above which we consider that the transition is divergent. -# next_random_arg_fn -# Function that generates the next `random_generator_arg` from its previous value. -# integration_steps_fn -# Function that generates the next pseudo or quasi-random number of integration steps in the -# sequence, given the current `random_generator_arg`. Needs to return an `int`. - -# Returns -# ------- -# A kernel that takes a rng_key and a Pytree that contains the current state -# of the chain and that returns a new state of the chain along with -# information about the transition. -# """ - -# def kernel( -# rng_key: PRNGKey, -# state: HMCState, -# step_size: float, -# num_integration_steps: int, -# L_proposal_factor: float = jnp.inf, -# ) -> tuple[HMCState, HMCInfo]: -# """Generate a new sample with the MHMCHMC kernel.""" - -# key_momentum, key_integrator = jax.random.split(rng_key, 2) -# momentum = generate_unit_vector(key_momentum, state.position) -# proposal, info, _ = adjusted_mclmc_proposal( -# integrator=integrators.with_isokinetic_maruyama( -# integrator( -# logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix -# ) -# ), -# step_size=step_size, -# L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), -# num_integration_steps=num_integration_steps, -# divergence_threshold=divergence_threshold, -# )( -# key_integrator, -# integrators.IntegratorState( -# state.position, momentum, state.logdensity, state.logdensity_grad -# ), -# ) - -# new_state = HMCState( -# proposal.position, -# proposal.logdensity, -# proposal.logdensity_grad) - -# new_state, info = handle_nans(state, new_state, info, nan_key) - - - -# return kernel - - -# def as_top_level_api( -# logdensity_fn: Callable, -# step_size: float, -# L_proposal_factor: float = jnp.inf, -# inverse_mass_matrix=1.0, -# *, -# divergence_threshold: int = 1000, -# integrator: Callable = integrators.isokinetic_mclachlan, -# num_integration_steps, -# ) -> SamplingAlgorithm: -# """Implements the (basic) user interface for the MHMCHMC kernel. - -# Parameters -# ---------- -# logdensity_fn -# The log-density function we wish to draw samples from. -# step_size -# The value to use for the step size in the symplectic integrator. -# divergence_threshold -# The absolute value of the difference in energy between two states above -# which we say that the transition is divergent. The default value is -# commonly found in other libraries, and yet is arbitrary. -# integrator -# (algorithm parameter) The symplectic integrator to use to integrate the trajectory. -# next_random_arg_fn -# Function that generates the next `random_generator_arg` from its previous value. -# integration_steps_fn -# Function that generates the next pseudo or quasi-random number of integration steps in the -# sequence, given the current `random_generator_arg`. - - -# Returns -# ------- -# A ``SamplingAlgorithm``. -# """ - -# kernel = build_kernel( -# logdensity_fn=logdensity_fn, -# integrator=integrator, -# inverse_mass_matrix=inverse_mass_matrix, -# divergence_threshold=divergence_threshold, -# ) - -# def init_fn(position: ArrayLikeTree, rng_key=None): -# del rng_key -# return init(position, logdensity_fn) - -# def update_fn(rng_key: PRNGKey, state): -# return kernel( -# rng_key=rng_key, -# state=state, -# step_size=step_size, -# num_integration_steps=num_integration_steps, -# L_proposal_factor=L_proposal_factor, -# ) - -# return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] - - -# def adjusted_mclmc_proposal( -# integrator: Callable, -# step_size: Union[float, ArrayLikeTree], -# L_proposal_factor: float, -# num_integration_steps: int = 1, -# divergence_threshold: float = 1000, -# *, -# sample_proposal: Callable = static_binomial_sampling, -# ) -> Callable: -# """Vanilla MHMCHMC algorithm. - -# The algorithm integrates the trajectory applying a integrator -# `num_integration_steps` times in one direction to get a proposal and uses a -# Metropolis-Hastings acceptance step to either reject or accept this -# proposal. This is what people usually refer to when they talk about "the -# HMC algorithm". - -# Parameters -# ---------- -# integrator -# integrator used to build the trajectory step by step. -# kinetic_energy -# Function that computes the kinetic energy. -# step_size -# Size of the integration step. -# num_integration_steps -# Number of times we run the integrator to build the trajectory -# divergence_threshold -# Threshold above which we say that there is a divergence. - -# Returns -# ------- -# A kernel that generates a new chain state and information about the transition. - -# """ - -# def step(i, vars): -# state, kinetic_energy, rng_key = vars -# rng_key, next_rng_key = jax.random.split(rng_key) -# next_state, next_kinetic_energy = integrator( -# state, step_size, L_proposal_factor, rng_key -# ) - -# return next_state, kinetic_energy + next_kinetic_energy, next_rng_key - -# def build_trajectory(state, num_integration_steps, rng_key): -# return jax.lax.fori_loop( -# 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) -# ) - -# def generate( -# rng_key, state: integrators.IntegratorState -# ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: -# """Generate a new chain state.""" -# end_state, kinetic_energy, rng_key = build_trajectory( -# state, num_integration_steps, rng_key -# ) - -# new_energy = -end_state.logdensity -# delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy -# delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) -# is_diverging = -delta_energy > divergence_threshold -# sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) -# do_accept, p_accept, other_proposal_info = info - -# info = HMCInfo( -# state.momentum, -# p_accept, -# do_accept, -# is_diverging, -# new_energy, -# end_state, -# num_integration_steps, -# nonans=True -# ) - -# return sampled_state, info, other_proposal_info - -# return generate +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" +from typing import Callable, Union + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.hmc import HMCInfo +from blackjax.mcmc.proposal import static_binomial_sampling +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.util import generate_unit_vector + +__all__ = ["init", "build_kernel", "as_top_level_api"] + + +def init(position: ArrayLikeTree, logdensity_fn: Callable): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return HMCState(position, logdensity, logdensity_grad) + + +def build_kernel( + integrator: Callable = integrators.isokinetic_mclachlan, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + L_proposal_factor: float = jnp.inf, +): + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. + """ + + def kernel( + rng_key: PRNGKey, + state: HMCState, + logdensity_fn: Callable, + step_size: float, + integration_steps_fn, + inverse_mass_matrix=1.0, + ) -> tuple[HMCState, HMCInfo]: + """Generate a new sample with the MHMCHMC kernel.""" + + + # num_integration_steps = integration_steps_fn(state.random_generator_arg) + num_integration_steps = integration_steps_fn(None) + + key_momentum, key_integrator = jax.random.split(rng_key, 2) + momentum = generate_unit_vector(key_momentum, state.position) + proposal, info, _ = adjusted_mclmc_proposal( + integrator=integrators.with_isokinetic_maruyama( + integrator( + logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix + ) + ), + step_size=step_size, + L_proposal_factor=L_proposal_factor * (num_integration_steps * step_size), + num_integration_steps=num_integration_steps, + divergence_threshold=divergence_threshold, + )( + key_integrator, + integrators.IntegratorState( + state.position, momentum, state.logdensity, state.logdensity_grad + ), + ) + + return ( + HMCState( + proposal.position, + proposal.logdensity, + proposal.logdensity_grad, + # next_random_arg_fn(state.random_generator_arg), + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + L_proposal_factor: float = jnp.inf, + inverse_mass_matrix=1.0, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.isokinetic_mclachlan, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel( + integrator=integrator, + next_random_arg_fn=next_random_arg_fn, + divergence_threshold=divergence_threshold, + L_proposal_factor=L_proposal_factor, + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key: PRNGKey, state): + return kernel( + rng_key=rng_key, + state=state, + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=integration_steps_fn, + inverse_mass_matrix=inverse_mass_matrix, + ) + + return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] + + +def adjusted_mclmc_proposal( + integrator: Callable, + step_size: Union[float, ArrayLikeTree], + L_proposal_factor: float, + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, +) -> Callable: + """Vanilla MHMCHMC algorithm. + + The algorithm integrates the trajectory applying a integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. + + """ + + def step(i, vars): + state, kinetic_energy, rng_key = vars + rng_key, next_rng_key = jax.random.split(rng_key) + next_state, next_kinetic_energy = integrator( + state, step_size, L_proposal_factor, rng_key + ) + + return next_state, kinetic_energy + next_kinetic_energy, next_rng_key + + def build_trajectory(state, num_integration_steps, rng_key): + return jax.lax.fori_loop( + 0 * num_integration_steps, num_integration_steps, step, (state, 0, rng_key) + ) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, kinetic_energy, rng_key = build_trajectory( + state, num_integration_steps, rng_key + ) + + new_energy = -end_state.logdensity + delta_energy = -state.logdensity + end_state.logdensity - kinetic_energy + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + is_diverging = -delta_energy > divergence_threshold + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + + info = HMCInfo( + state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. + """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) + +def make_random_trajectory_length_fn(random_trajectory_length : bool): + if random_trajectory_length: + integration_steps_fn = lambda avg_num_integration_steps: lambda k: (jnp.clip(jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), min=1)).astype('int32') + else: + integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.clip(jnp.ceil( + avg_num_integration_steps + ), min=1).astype('int32') + return integration_steps_fn \ No newline at end of file diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index aaf171a44..dff277738 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -70,6 +70,7 @@ def kernel( ) -> tuple[DynamicHMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" + num_integration_steps = integration_steps_fn(state.random_generator_arg) key_momentum, key_integrator = jax.random.split(rng_key, 2) From ce2fe53fcae3aa788780bc14f293a73340e69293 Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Wed, 27 Aug 2025 23:26:39 -0700 Subject: [PATCH 56/63] tryinmg to fix (overwrite later) --- blackjax/adaptation/ensemble_mclmc.py | 8 ++++---- blackjax/util.py | 28 ++++++++++++++++++++++----- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 1b71a8d95..b8dba0090 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -181,7 +181,7 @@ def laps( ensemble_observables=None, diagnostics=True, contract=lambda x: 0.0, - superchain_size= None, + superchain_size= 1, ): """ model: the target density object @@ -203,12 +203,12 @@ def laps( ensemble_observables: observable calculated over the ensemble (for diagnostic use) diagnostics: whether to return diagnostics """ - + key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3) # initialize the chains initial_state = umclmc.initialize( - key_init, logdensity_fn, sample_init, num_chains, mesh, superchain_size=superchain_size + key_init, logdensity_fn, sample_init, num_chains, mesh, superchain_size ) # burn-in with the unadjusted method # @@ -220,7 +220,7 @@ def laps( bias_type=3, save_num=save_num, C=C, - power=3.0 / 8.0, + power= 3.0 / 8.0, r_end=r_end, observables_for_bias=observables_for_bias, contract=contract, diff --git a/blackjax/util.py b/blackjax/util.py index 32025dbba..df680be15 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -356,12 +356,14 @@ def step(state_all, xs): return (state, adaptation_state), info_to_be_stored + return add_ensemble_info(add_splitR(step, num_chains, superchain_size), ensemble_info) def add_splitR(step, num_chains, superchain_size): - def _step(state_all, xs): + + def _step_with_R(state_all, xs): state_all, info_to_be_stored = step(state_all, xs) @@ -376,9 +378,25 @@ def _step(state_all, xs): return (state, adaptation_state), info_to_be_stored - return _step if superchain_size is not None else step + def _step_with_R_1(state_all, xs): + state_all, info_to_be_stored = step(state_all, xs) + + info_to_be_stored['R_avg'] = 0. + info_to_be_stored['R_max'] = 0. + return state_all, info_to_be_stored + + if superchain_size == None: + return step + + if superchain_size == 1: + return _step_with_R_1 + + else: + return _step_with_R + + def add_ensemble_info(step, ensemble_info): def _step(state_all, xs): @@ -425,8 +443,8 @@ def run_eca( adaptation.summary_statistics_fn, adaptation.update, num_chains, - superchain_size, - ensemble_info, + superchain_size= superchain_size, + ensemble_info = ensemble_info, ) def all_steps(initial_state, keys_sampling, keys_adaptation): @@ -593,7 +611,7 @@ def F(x, keys): F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False ) - if superchain_size == None: + if superchain_size == 1: _keys = split(rng_key, num_chains) else: From 649a3169066f52d859d18d6c32a04cafc579806e Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Tue, 16 Sep 2025 12:13:10 -0700 Subject: [PATCH 57/63] las --- blackjax/__init__.py | 2 + blackjax/adaptation/adjusted_abla.py | 5 +- blackjax/adaptation/las.py | 127 ++++++++++++++++++++ blackjax/adaptation/mass_matrix.py | 2 +- blackjax/adaptation/unadjusted_alba.py | 14 ++- blackjax/adaptation/unadjusted_step_size.py | 6 + blackjax/mcmc/dynamic_malt.py | 2 + blackjax/mcmc/malt.py | 2 + blackjax/mcmc/mclmc.py | 11 ++ blackjax/util.py | 6 +- 10 files changed, 168 insertions(+), 9 deletions(-) create mode 100644 blackjax/adaptation/las.py diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 8cd77c644..8c8aa0ad9 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -10,6 +10,7 @@ from .adaptation.unadjusted_alba import unadjusted_alba from .adaptation.unadjusted_step_size import robnik_step_size_tuning from .adaptation.adjusted_abla import alba_adjusted +from .adaptation.las import las from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat @@ -181,4 +182,5 @@ def generate_top_level_api_from(module): "unadjusted_alba", "robnik_step_size_tuning", "alba_adjusted", + "las", ] diff --git a/blackjax/adaptation/adjusted_abla.py b/blackjax/adaptation/adjusted_abla.py index f9fc77075..0294eb053 100644 --- a/blackjax/adaptation/adjusted_abla.py +++ b/blackjax/adaptation/adjusted_abla.py @@ -128,7 +128,8 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): (state, params), adaptation_info = unadjusted_warmup.run(unadjusted_warmup_key, position, num_unadjusted_steps) - jax.debug.print("unadjusted params: {params}", params=(params["L"], params["step_size"])) + # jax.debug.print("unadjusted params: {params}", params=(params["L"], params["step_size"])) + # jax.debug.print("unadjusted params: {params}", params=params) integration_steps_fn = make_random_trajectory_length_fn(random_trajectory_length=True) @@ -145,7 +146,7 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): state, params, adaptation_info = adjusted_warmup.run(adjusted_warmup_key, state.position, num_steps) - jax.debug.print("adjusted params: {params}", params=(params["L"], params["step_size"])) + # jax.debug.print("adjusted params: {params}", params=(params["L"], params["step_size"])) # raise Exception("stop") # return None return state, params, adaptation_info diff --git a/blackjax/adaptation/las.py b/blackjax/adaptation/las.py new file mode 100644 index 000000000..6fe9150a9 --- /dev/null +++ b/blackjax/adaptation/las.py @@ -0,0 +1,127 @@ +import jax +import jax.numpy as jnp +import blackjax +from blackjax.util import run_inference_algorithm +import blackjax + +from blackjax.adaptation.unadjusted_alba import unadjusted_alba +from blackjax.adaptation.unadjusted_step_size import robnik_step_size_tuning +from blackjax.adaptation.unadjusted_alba import unadjusted_alba +import math +from blackjax.mcmc.adjusted_mclmc_dynamic import make_random_trajectory_length_fn +from functools import partial + +def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_preconditioning=True): + + # begin by running unadjusted alba tuning for umclmc + + + init_key, tune_key, run_key = jax.random.split(key, 3) + initial_position = jax.random.normal(init_key, (ndims,)) + + integrator = blackjax.mcmc.integrators.isokinetic_mclachlan + + num_alba_steps = 10000 + warmup = unadjusted_alba( + algorithm=blackjax.mclmc, + logdensity_fn=logdensity_fn, integrator=integrator, + target_eevpd=5e-4, + v=jnp.sqrt(ndims), + num_alba_steps=num_alba_steps, + preconditioning=diagonal_preconditioning, + alba_factor=0.4, + ) + + + # run warmup + (blackjax_state_after_tuning, blackjax_mclmc_sampler_params), adaptation_info = warmup.run(tune_key, initial_position, 20000) + + ess_per_sample = blackjax_mclmc_sampler_params['ESS'] + print(ess_per_sample, "ESS") + # get_final_sample = lambda state, info: (model.default_event_space_bijector(state.position), info) + + num_steps = math.ceil(200 // ess_per_sample) + + alg = blackjax.mclmc( + logdensity_fn=logdensity_fn, + L=blackjax_mclmc_sampler_params['L'], + step_size=blackjax_mclmc_sampler_params['step_size'], + inverse_mass_matrix=blackjax_mclmc_sampler_params['inverse_mass_matrix'], + integrator=integrator, + ) + + final_output, history = run_inference_algorithm( + rng_key=key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=(lambda a, b: a), + progress_bar=False, + ) + + samples = history.position + + print(samples.shape) + + subsamples = samples[::math.ceil(1/ess_per_sample)] + + integration_steps_fn = make_random_trajectory_length_fn(True) + + num_steps_per_traj = blackjax_mclmc_sampler_params['L'] / blackjax_mclmc_sampler_params['step_size'] + + + + # initial_states = blackjax.adjusted_mclmc_dynamic.init( + # position=history.position[-1], + # logdensity_fn=logdensity_fn, + # random_generator_arg=jax.random.key(0), + # ) + + initial_states = jax.lax.map(lambda x: blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), xs=subsamples) + + print(initial_states, "initial_states") + + def f(step_size_and_positions): + + return jax.lax.map(lambda x: run_inference_algorithm( + rng_key=key, + initial_state=blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), + inference_algorithm=blackjax.adjusted_mclmc_dynamic( + logdensity_fn=logdensity_fn, + step_size=step_size_and_positions[0], + integration_steps_fn=integration_steps_fn(num_steps_per_traj), + integrator=blackjax.mcmc.integrators.isokinetic_velocity_verlet, + inverse_mass_matrix=blackjax_mclmc_sampler_params['inverse_mass_matrix'], + L_proposal_factor=jnp.inf, + ), + num_steps=1, + transform=(lambda a, b: (a,b)), + progress_bar=False, + ), xs=step_size_and_positions[1]) + g = lambda x: (blackjax_mclmc_sampler_params['step_size'],x[0].position) + + + + step_size, position = feedback(f,g, 10, (blackjax_mclmc_sampler_params['step_size'], subsamples)) + + # results = f((1.0, initial_states)) + # step_size = g(results) + # print(step_size, "step_size") + + # history, final_output = results + + # print(history[0].position.shape) + # print(history[1].acceptance_rate.mean()) + + return position + +# a ~ (stepsize, position), b ~ (results) +# type: forall a, b: (a -> b) -> (b -> a) -> Int -> (a -> b) +def feedback(f,g, n, state_a): + for i in range(n): + state_b = f(state_a) + # print(state_b, "state_b") + state_a = g(state_b) + # print(state_a, "state_a") + return state_a + diff --git a/blackjax/adaptation/mass_matrix.py b/blackjax/adaptation/mass_matrix.py index dc0730161..bb74fa2ee 100644 --- a/blackjax/adaptation/mass_matrix.py +++ b/blackjax/adaptation/mass_matrix.py @@ -132,7 +132,7 @@ def final(mm_state: MassMatrixAdaptationState) -> MassMatrixAdaptationState: """Final iteration of the mass matrix adaptation. In this step we compute the mass matrix from the covariance matrix computed - by the Welford algorithm, and re-initialize the later. + by the Welford algorithm, and re-initialize the latter. """ _, wc_state = mm_state diff --git a/blackjax/adaptation/unadjusted_alba.py b/blackjax/adaptation/unadjusted_alba.py index 10e9746a8..af4781f84 100644 --- a/blackjax/adaptation/unadjusted_alba.py +++ b/blackjax/adaptation/unadjusted_alba.py @@ -33,11 +33,11 @@ def base( ) -> tuple[Callable, Callable, Callable]: mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal) - # if not preconditioning: + if not preconditioning: - # mm_update = lambda x, y: x + mm_update = lambda x, y: x - # mm_final = lambda x: x + mm_final = lambda x: x # step_size_init, step_size_update, step_size_final = dual_averaging_adaptation(target_eevpd) step_size_init, step_size_update, step_size_final = robnik_step_size_tuning(desired_energy_var=target_eevpd) @@ -241,6 +241,7 @@ def one_step(carry, xs): new_state.position, info, ) + # jax.debug.print("info: {x}", x=(new_adaptation_state.step_size, info.energy_change)) # jax.debug.print("step sizes: {x}", x=(adaptation_state.step_size, new_adaptation_state.step_size)) return ( @@ -268,7 +269,7 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): last_chain_state, last_warmup_state, *_ = last_state step_size, L, inverse_mass_matrix = adapt_final(last_warmup_state) - jax.debug.print("unadjusted L before alba: {params}", params=(L, step_size)) + # jax.debug.print("unadjusted L before alba: {params}", params=(L, step_size)) ### ### ALBA TUNING @@ -288,10 +289,12 @@ def step(state, key): return next_state, next_state.position if num_alba_steps > 0: + print("params before alba tuning", L, step_size) _, samples = jax.lax.scan(step, last_chain_state, keys) flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) ess = effective_sample_size(flat_samples[None, ...]) - print(jnp.mean(ess), num_alba_steps, "\n\ness (blackjax internal)\n") + # print(num_alba_steps/jnp.mean(ess), "ess (blackjax internal)\n") + # print( effective_sample_size(flat_samples[None, ...]), "ess (blackjax internal)\n") # print("L etc", L, step_size, jnp.mean(ess), num_alba_steps, jnp.mean(num_alba_steps / ess)) L=alba_factor * step_size * jnp.mean(num_alba_steps / ess) @@ -305,6 +308,7 @@ def step(state, key): "step_size": step_size, "inverse_mass_matrix": inverse_mass_matrix, "L": jnp.clip(L, max=step_size*max_num_steps), + "ESS": jnp.mean(ess)/num_alba_steps, **extra_parameters, } diff --git a/blackjax/adaptation/unadjusted_step_size.py b/blackjax/adaptation/unadjusted_step_size.py index f13aa7533..c0753345b 100644 --- a/blackjax/adaptation/unadjusted_step_size.py +++ b/blackjax/adaptation/unadjusted_step_size.py @@ -18,6 +18,11 @@ def init(initial_step_size, num_dimensions): def update(robnik_state, info): + # jax.debug.print("robnik state: {x}", x=robnik_state) + # jax.debug.print("info: {x}", x=(info.energy_change, info.nonans)) + # raise Exception("Stop here") + + energy_change = info.energy_change @@ -40,6 +45,7 @@ def update(robnik_state, info): step_size = (step_size < robnik_state.step_size_max) * step_size + ( step_size > robnik_state.step_size_max ) * robnik_state.step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences + # jax.debug.print("new step_size: {x}", x=(step_size)) # old_robnik_state = robnik_state diff --git a/blackjax/mcmc/dynamic_malt.py b/blackjax/mcmc/dynamic_malt.py index cd2fbd0de..91c1e592f 100644 --- a/blackjax/mcmc/dynamic_malt.py +++ b/blackjax/mcmc/dynamic_malt.py @@ -86,6 +86,8 @@ def kernel( state.random_generator_arg, **integration_steps_kwargs ).astype(int) + # jax.debug.print("num_integration_steps {x}", x=(num_integration_steps, step_size, inverse_mass_matrix)) + hmc_state = HMCState(state.position, state.logdensity, state.logdensity_grad) hmc_proposal, info = hmc_base( rng_key, diff --git a/blackjax/mcmc/malt.py b/blackjax/mcmc/malt.py index 90b007e9d..57854ee26 100644 --- a/blackjax/mcmc/malt.py +++ b/blackjax/mcmc/malt.py @@ -77,6 +77,8 @@ def kernel( L = num_integration_steps * step_size L_proposal = L_proposal_factor * L + # jax.debug.print("L_proposal {L_proposal}",L_proposal=L_proposal) + key_trajectory, key_momentum, key_integrator = jax.random.split(rng_key, 3) metric = metrics.default_metric(inverse_mass_matrix) symplectic_integrator = lambda state, step_size, rng_key: integrators.with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy, inverse_mass_matrix)(state, step_size, L_proposal=L_proposal, rng_key=rng_key) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index ca245fe52..9ad52c001 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -16,6 +16,7 @@ import jax import jax.numpy as jnp +import time from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import ( @@ -93,10 +94,16 @@ def kernel( rng_key: PRNGKey, state: IntegratorState, logdensity_fn, L: float, step_size: float, inverse_mass_matrix, ) -> tuple[IntegratorState, MCLMCInfo]: + # tic = time.time() step = with_isokinetic_maruyama( integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) ) + + # jax.debug.print("L {x}", x=L) + # jax.debug.print("step_size {x}", x=step_size) + # jax.debug.print("inverse_mass_matrix {x}", x=inverse_mass_matrix) + kernel_key, energy_cutoff_key, nan_key = jax.random.split(rng_key, 3) @@ -108,6 +115,9 @@ def kernel( eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var ndims = pytree_size(position) + # jax.debug.print("kinetic_change: {x}", x=kinetic_change) + # jax.debug.print("potential energy_change: {x}", x=state.logdensity - logdensity) + new_state, info = handle_high_energy(state, IntegratorState(position, momentum, logdensity, logdensitygrad), MCLMCInfo( logdensity=logdensity, energy_change=energy_change, @@ -117,6 +127,7 @@ def kernel( new_state, info = handle_nans(state, new_state, info, nan_key) + # jax.debug.print("Time taken in chain: {x}", x=time.time() - tic) return new_state, info return kernel diff --git a/blackjax/util.py b/blackjax/util.py index 920198286..839f48d6e 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -15,7 +15,8 @@ from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.diagnostics import splitR - +import time +import jax @partial(jit, static_argnames=("precision",), inline=True) def linear_map(diag_or_dense_a, b, *, precision="highest"): @@ -199,6 +200,7 @@ def run_inference_algorithm( keys = split(rng_key, num_steps) + def one_step(state, xs): _, rng_key = xs state, info = inference_algorithm.step(rng_key, state) @@ -207,7 +209,9 @@ def one_step(state, xs): scan_fn = gen_scan_fn(num_steps, progress_bar) xs = jnp.arange(num_steps), keys + # tic = time.time() final_state, history = scan_fn(one_step, initial_state, xs) + # toc = time.time() return final_state, history From 18209ca1da291e63b91937958fea4201932188e3 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Wed, 17 Sep 2025 15:26:42 -0700 Subject: [PATCH 58/63] las --- blackjax/adaptation/las.py | 96 +++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 43 deletions(-) diff --git a/blackjax/adaptation/las.py b/blackjax/adaptation/las.py index 6fe9150a9..4bde449a3 100644 --- a/blackjax/adaptation/las.py +++ b/blackjax/adaptation/las.py @@ -11,10 +11,12 @@ from blackjax.mcmc.adjusted_mclmc_dynamic import make_random_trajectory_length_fn from functools import partial -def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_preconditioning=True): +# unbelievable that this is not in the standard library +def compose(f, g): + return lambda x: f(g(x)) - # begin by running unadjusted alba tuning for umclmc +def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_preconditioning=True): init_key, tune_key, run_key = jax.random.split(key, 3) initial_position = jax.random.normal(init_key, (ndims,)) @@ -31,14 +33,11 @@ def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_ preconditioning=diagonal_preconditioning, alba_factor=0.4, ) - # run warmup (blackjax_state_after_tuning, blackjax_mclmc_sampler_params), adaptation_info = warmup.run(tune_key, initial_position, 20000) ess_per_sample = blackjax_mclmc_sampler_params['ESS'] - print(ess_per_sample, "ESS") - # get_final_sample = lambda state, info: (model.default_event_space_bijector(state.position), info) num_steps = math.ceil(200 // ess_per_sample) @@ -60,63 +59,74 @@ def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_ ) samples = history.position - - print(samples.shape) - subsamples = samples[::math.ceil(1/ess_per_sample)] integration_steps_fn = make_random_trajectory_length_fn(True) - num_steps_per_traj = blackjax_mclmc_sampler_params['L'] / blackjax_mclmc_sampler_params['step_size'] - - - - # initial_states = blackjax.adjusted_mclmc_dynamic.init( - # position=history.position[-1], - # logdensity_fn=logdensity_fn, - # random_generator_arg=jax.random.key(0), - # ) initial_states = jax.lax.map(lambda x: blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), xs=subsamples) print(initial_states, "initial_states") - def f(step_size_and_positions): + def make_mams_step(key): + def mams_step(step_size_positions_info): + + step_size, positions, info = step_size_positions_info + num_steps_per_traj = blackjax_mclmc_sampler_params['L'] / step_size + + alg = blackjax.adjusted_mclmc_dynamic( + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=integration_steps_fn(num_steps_per_traj), + integrator=blackjax.mcmc.integrators.isokinetic_velocity_verlet, + inverse_mass_matrix=blackjax_mclmc_sampler_params['inverse_mass_matrix'], + L_proposal_factor=jnp.inf, + ) + + new_states, infos = jax.lax.map(lambda x: alg.step( + rng_key=key, + state=blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, key), + ), xs=positions) + return (step_size, new_states.position, infos) + + return mams_step - return jax.lax.map(lambda x: run_inference_algorithm( - rng_key=key, - initial_state=blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), - inference_algorithm=blackjax.adjusted_mclmc_dynamic( - logdensity_fn=logdensity_fn, - step_size=step_size_and_positions[0], - integration_steps_fn=integration_steps_fn(num_steps_per_traj), - integrator=blackjax.mcmc.integrators.isokinetic_velocity_verlet, - inverse_mass_matrix=blackjax_mclmc_sampler_params['inverse_mass_matrix'], - L_proposal_factor=jnp.inf, - ), - num_steps=1, - transform=(lambda a, b: (a,b)), - progress_bar=False, - ), xs=step_size_and_positions[1]) - g = lambda x: (blackjax_mclmc_sampler_params['step_size'],x[0].position) + + def tuning_step(old_step_size_positions_info): + + old_step_size, old_positions, old_infos = old_step_size_positions_info + acc_rate = old_infos.acceptance_rate.mean() + + new_step_size = jax.lax.cond(acc_rate < 0.8, lambda: old_step_size * 0.5, lambda: old_step_size * 2.0) + + return (new_step_size, old_positions, old_infos) + + step = lambda key: compose(tuning_step, make_mams_step(key)) + + _, _, infos = make_mams_step(jax.random.key(0))((blackjax_mclmc_sampler_params['step_size'], subsamples, None)) + + positions = subsamples + step_size = blackjax_mclmc_sampler_params['step_size'] + + (step_size, position, infos), (step_sizes, positions, infos) = jax.lax.scan(lambda state, key: (step(key)(state), step(key)(state)), (step_size, subsamples, infos), jax.random.split(jax.random.key(0), 10)) + print(position.shape, "position") + print(step_sizes.shape, "step_sizes") + print(positions.shape, "positions") - step_size, position = feedback(f,g, 10, (blackjax_mclmc_sampler_params['step_size'], subsamples)) + + # for i in range(10): + # step_size, positions = step(jax.random.key(i))((step_size, positions)) - # results = f((1.0, initial_states)) - # step_size = g(results) - # print(step_size, "step_size") - # history, final_output = results - # print(history[0].position.shape) - # print(history[1].acceptance_rate.mean()) + # step_size, position = feedback(mams_step,tuning_step, 10, (blackjax_mclmc_sampler_params['step_size'], subsamples)) - return position + return samples, positions, infos, num_steps -# a ~ (stepsize, position), b ~ (results) # type: forall a, b: (a -> b) -> (b -> a) -> Int -> (a -> b) +# e.g.: a ~ (stepsize, position), b ~ (state) def feedback(f,g, n, state_a): for i in range(n): state_b = f(state_a) From c6e99cee73d762879b036a4984e79ded3649c32c Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Wed, 17 Sep 2025 15:28:35 -0700 Subject: [PATCH 59/63] las --- blackjax/adaptation/las.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/blackjax/adaptation/las.py b/blackjax/adaptation/las.py index 4bde449a3..65a178408 100644 --- a/blackjax/adaptation/las.py +++ b/blackjax/adaptation/las.py @@ -66,8 +66,6 @@ def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_ initial_states = jax.lax.map(lambda x: blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), xs=subsamples) - print(initial_states, "initial_states") - def make_mams_step(key): def mams_step(step_size_positions_info): @@ -108,21 +106,8 @@ def tuning_step(old_step_size_positions_info): positions = subsamples step_size = blackjax_mclmc_sampler_params['step_size'] - (step_size, position, infos), (step_sizes, positions, infos) = jax.lax.scan(lambda state, key: (step(key)(state), step(key)(state)), (step_size, subsamples, infos), jax.random.split(jax.random.key(0), 10)) - print(position.shape, "position") - print(step_sizes.shape, "step_sizes") - print(positions.shape, "positions") - - - # for i in range(10): - # step_size, positions = step(jax.random.key(i))((step_size, positions)) - - - - # step_size, position = feedback(mams_step,tuning_step, 10, (blackjax_mclmc_sampler_params['step_size'], subsamples)) - return samples, positions, infos, num_steps # type: forall a, b: (a -> b) -> (b -> a) -> Int -> (a -> b) From 029c78b57ee89c1d1fea3994ecb71fd1bb091fd6 Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Wed, 17 Sep 2025 23:20:17 -0700 Subject: [PATCH 60/63] las comments --- blackjax/adaptation/las.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/blackjax/adaptation/las.py b/blackjax/adaptation/las.py index 65a178408..c8b4e12e4 100644 --- a/blackjax/adaptation/las.py +++ b/blackjax/adaptation/las.py @@ -21,8 +21,11 @@ def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_ init_key, tune_key, run_key = jax.random.split(key, 3) initial_position = jax.random.normal(init_key, (ndims,)) + ### Phase 1: unadjusted ### + integrator = blackjax.mcmc.integrators.isokinetic_mclachlan + # burn-in and adaptation num_alba_steps = 10000 warmup = unadjusted_alba( algorithm=blackjax.mclmc, @@ -34,9 +37,9 @@ def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_ alba_factor=0.4, ) - # run warmup (blackjax_state_after_tuning, blackjax_mclmc_sampler_params), adaptation_info = warmup.run(tune_key, initial_position, 20000) + # sampling ess_per_sample = blackjax_mclmc_sampler_params['ESS'] num_steps = math.ceil(200 // ess_per_sample) @@ -56,14 +59,16 @@ def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_ num_steps=num_steps, transform=(lambda a, b: a), progress_bar=False, - ) - + ) samples = history.position + + + ### Phase 2: adjusted ### + subsamples = samples[::math.ceil(1/ess_per_sample)] integration_steps_fn = make_random_trajectory_length_fn(True) - initial_states = jax.lax.map(lambda x: blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), xs=subsamples) def make_mams_step(key): From 35ba3135691eeb2664a9f36a934d025cfbe24952 Mon Sep 17 00:00:00 2001 From: "jakob.robnik@gmail.com" Date: Fri, 3 Oct 2025 02:20:45 -0700 Subject: [PATCH 61/63] submission --- blackjax/adaptation/ensemble_mclmc.py | 7 ++++--- blackjax/adaptation/ensemble_umclmc.py | 6 +++--- blackjax/util.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py index 0d93110c1..eb15cf908 100644 --- a/blackjax/adaptation/ensemble_mclmc.py +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -160,6 +160,7 @@ def while_steps_num(cond): return jnp.argmin(cond) + 1 + def laps( logdensity_fn, sample_init, @@ -173,7 +174,8 @@ def laps( save_frac=0.2, C=0.1, early_stop=True, - r_end=5e-3, + r_end= 0.01, + bias_type= 3, diagonal_preconditioning=True, integrator_coefficients=None, steps_per_sample=15, @@ -218,10 +220,9 @@ def laps( adap = umclmc.Adaptation( ndims, alpha=alpha, - bias_type=3, + bias_type=bias_type, save_num=save_num, C=C, - power= 3.0 / 8.0, r_end=r_end, observables_for_bias=observables_for_bias, contract=contract, diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py index 6d69c7124..46b3c8346 100644 --- a/blackjax/adaptation/ensemble_umclmc.py +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -204,7 +204,6 @@ def __init__( ndims, alpha=1.0, C=0.1, - power=3.0 / 8.0, r_end=0.01, bias_type=0, save_num=10, @@ -215,7 +214,6 @@ def __init__( self.ndims = ndims self.alpha = alpha self.C = C - self.power = power self.r_end = r_end self.observables = observables self.observables_for_bias = observables_for_bias @@ -290,7 +288,9 @@ def update(self, adaptation_state, Etheta): # hyperparameter adaptation # estimate bias bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[self.bias_type] # r_max, r_avg, equi_full, equi_diag - EEVPD_wanted = self.C * jnp.power(bias, self.power) + EEVPD_wanted = self.C * jnp.power(bias, 3./8.) + # bias_asym_wanted = self.C * bias + # EEVPD_wanted = 4 * jnp.power(bias_asym_wanted, 3./2.) / jnp.square(1 + jnp.sqrt(bias_asym_wanted)) # phi function from Robnik et. al., Blackbox Unadjusted Hamiltonian Monte Carlo eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1.0 / 6.0) eps_factor = jnp.clip(eps_factor, 0.3, 3.0) diff --git a/blackjax/util.py b/blackjax/util.py index 839f48d6e..b8a00f9d6 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -500,7 +500,7 @@ def step_while(a): new_step_size = step_size.at[i].set(info.get("step_size")) return (output, i + 1, - True, #info.get("r_max") > adaptation.r_end,x + (info.get("r_max") > adaptation.r_end) | (i < adaptation.save_num), # while is run while this is True new_EEVPD, new_EEVPD_wanted, new_L, new_entropy, new_equi_diag, new_equi_full, new_bias0, new_bias1, new_observables, new_r_avg, new_r_max, new_R_avg, new_R_max, new_step_size) if early_stop: From 5e28ec53cdc1aefd3319da8d400aedeebbcb56f7 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Fri, 3 Oct 2025 13:36:00 -0700 Subject: [PATCH 62/63] pseudofermion --- blackjax/__init__.py | 4 +- .../{adjusted_abla.py => adjusted_alba.py} | 4 +- blackjax/adaptation/las.py | 67 ++++-- blackjax/adaptation/unadjusted_alba.py | 44 ++-- blackjax/mcmc/adjusted_mclmc_dynamic.py | 33 ++- blackjax/mcmc/integrators.py | 2 +- blackjax/mcmc/malt.py | 24 ++- blackjax/mcmc/mclmc.py | 11 + blackjax/mcmc/pseudofermion.py | 197 ++++++------------ 9 files changed, 198 insertions(+), 188 deletions(-) rename blackjax/adaptation/{adjusted_abla.py => adjusted_alba.py} (99%) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index 8c8aa0ad9..1a078cc91 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -9,7 +9,7 @@ from .adaptation.window_adaptation import window_adaptation from .adaptation.unadjusted_alba import unadjusted_alba from .adaptation.unadjusted_step_size import robnik_step_size_tuning -from .adaptation.adjusted_abla import alba_adjusted +from .adaptation.adjusted_alba import adjusted_alba from .adaptation.las import las from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess @@ -181,6 +181,6 @@ def generate_top_level_api_from(module): "rhat", "unadjusted_alba", "robnik_step_size_tuning", - "alba_adjusted", + "adjusted_alba", "las", ] diff --git a/blackjax/adaptation/adjusted_abla.py b/blackjax/adaptation/adjusted_alba.py similarity index 99% rename from blackjax/adaptation/adjusted_abla.py rename to blackjax/adaptation/adjusted_alba.py index 0294eb053..a04da52a6 100644 --- a/blackjax/adaptation/adjusted_abla.py +++ b/blackjax/adaptation/adjusted_alba.py @@ -62,7 +62,7 @@ def step(state, key): return ( ((new_adaptation_state, new_kernel_state), L), - info, + None, ) def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): @@ -94,7 +94,7 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): return AdaptationAlgorithm(run) -def alba_adjusted( +def adjusted_alba( unadjusted_algorithm, logdensity_fn: Callable, target_eevpd, diff --git a/blackjax/adaptation/las.py b/blackjax/adaptation/las.py index 65a178408..613e80761 100644 --- a/blackjax/adaptation/las.py +++ b/blackjax/adaptation/las.py @@ -10,24 +10,26 @@ import math from blackjax.mcmc.adjusted_mclmc_dynamic import make_random_trajectory_length_fn from functools import partial +from blackjax.adaptation.step_size import bisection_monotonic_fn # unbelievable that this is not in the standard library def compose(f, g): return lambda x: f(g(x)) -def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_preconditioning=True): +def las(logdensity_fn, num_chains, key, ndims, num_adjusted_steps, diagonal_preconditioning=True, target_acceptance_rate=0.8): - init_key, tune_key, run_key = jax.random.split(key, 3) + init_key, tune_key, unadjusted_key, adjusted_key = jax.random.split(key, 4) initial_position = jax.random.normal(init_key, (ndims,)) integrator = blackjax.mcmc.integrators.isokinetic_mclachlan - num_alba_steps = 10000 + num_alba_steps = 1000 warmup = unadjusted_alba( algorithm=blackjax.mclmc, logdensity_fn=logdensity_fn, integrator=integrator, target_eevpd=5e-4, + # target_acceptance_rate=target_acceptance_rate, v=jnp.sqrt(ndims), num_alba_steps=num_alba_steps, preconditioning=diagonal_preconditioning, @@ -35,11 +37,11 @@ def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_ ) # run warmup - (blackjax_state_after_tuning, blackjax_mclmc_sampler_params), adaptation_info = warmup.run(tune_key, initial_position, 20000) + (blackjax_state_after_tuning, blackjax_mclmc_sampler_params), adaptation_info = warmup.run(tune_key, initial_position, 2000) ess_per_sample = blackjax_mclmc_sampler_params['ESS'] - num_steps = math.ceil(200 // ess_per_sample) + num_steps = math.ceil(num_chains // ess_per_sample) alg = blackjax.mclmc( logdensity_fn=logdensity_fn, @@ -50,7 +52,7 @@ def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_ ) final_output, history = run_inference_algorithm( - rng_key=key, + rng_key=unadjusted_key, initial_state=blackjax_state_after_tuning, inference_algorithm=alg, num_steps=num_steps, @@ -64,13 +66,16 @@ def las(logdensity_fn, key, ndims, num_steps1, num_steps2, num_chains, diagonal_ integration_steps_fn = make_random_trajectory_length_fn(True) - initial_states = jax.lax.map(lambda x: blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), xs=subsamples) + # initial_states = jax.lax.map(lambda x: blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), xs=subsamples) def make_mams_step(key): - def mams_step(step_size_positions_info): + def mams_step(inp): + # init_key, run_key = jax.random.split(key, 2) - step_size, positions, info = step_size_positions_info - num_steps_per_traj = blackjax_mclmc_sampler_params['L'] / step_size + step_size, positions, info, step_size_adaptation_state = inp + keys = jax.random.split(key, positions.shape[0]) + # num_steps_per_traj = blackjax_mclmc_sampler_params['L'] / step_size + num_steps_per_traj = 1 alg = blackjax.adjusted_mclmc_dynamic( logdensity_fn=logdensity_fn, @@ -81,34 +86,52 @@ def mams_step(step_size_positions_info): L_proposal_factor=jnp.inf, ) - new_states, infos = jax.lax.map(lambda x: alg.step( - rng_key=key, - state=blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, key), - ), xs=positions) - return (step_size, new_states.position, infos) + # run_keys = jax.random.split(run_key, positions.shape[0]) + + def step_fn(pos_key): + pos, key = pos_key + init_key, run_key = jax.random.split(key, 2) + return alg.step( + rng_key=run_key, + state=blackjax.adjusted_mclmc_dynamic.init(pos, logdensity_fn, init_key), + ) + + new_states, infos = jax.lax.map(step_fn, xs=(positions,keys)) + # jax.debug.print("infos adaptation step: {infos}", infos=jnp.sum(infos.is_accepted)) + return (step_size, new_states.position, infos, step_size_adaptation_state) return mams_step + epsadap_update = bisection_monotonic_fn(target_acceptance_rate) + step_size_adaptation_state_initial = (jnp.array([-jnp.inf, jnp.inf]), False) - def tuning_step(old_step_size_positions_info): + def tuning_step(inp): - old_step_size, old_positions, old_infos = old_step_size_positions_info + old_step_size, old_positions, old_infos, step_size_adaptation_state = inp acc_rate = old_infos.acceptance_rate.mean() + - new_step_size = jax.lax.cond(acc_rate < 0.8, lambda: old_step_size * 0.5, lambda: old_step_size * 2.0) + step_size_adaptation_state, new_step_size = epsadap_update( + step_size_adaptation_state, + old_step_size, + acc_rate, + ) - return (new_step_size, old_positions, old_infos) + return (new_step_size, old_positions, old_infos, step_size_adaptation_state) + # return (10.0, old_positions, old_infos, step_size_adaptation_state) step = lambda key: compose(tuning_step, make_mams_step(key)) - _, _, infos = make_mams_step(jax.random.key(0))((blackjax_mclmc_sampler_params['step_size'], subsamples, None)) + initial_adjusted_key, adjusted_key = jax.random.split(adjusted_key, 2) + _, _, infos, _ = make_mams_step(initial_adjusted_key)((blackjax_mclmc_sampler_params['step_size'], subsamples, None, step_size_adaptation_state_initial)) + positions = subsamples step_size = blackjax_mclmc_sampler_params['step_size'] - (step_size, position, infos), (step_sizes, positions, infos) = jax.lax.scan(lambda state, key: (step(key)(state), step(key)(state)), (step_size, subsamples, infos), jax.random.split(jax.random.key(0), 10)) + (step_size, position, infos, step_size_adaptation_state), (step_sizes, positions, infos, step_size_adaptation_state) = jax.lax.scan(lambda state, key: (step(key)(state), step(key)(state)), (step_size, subsamples, infos, step_size_adaptation_state_initial), jax.random.split(adjusted_key, num_adjusted_steps)) - return samples, positions, infos, num_steps + return samples, positions, infos, num_steps, step_size_adaptation_state # type: forall a, b: (a -> b) -> (b -> a) -> Int -> (a -> b) # e.g.: a ~ (stepsize, position), b ~ (state) diff --git a/blackjax/adaptation/unadjusted_alba.py b/blackjax/adaptation/unadjusted_alba.py index af4781f84..9164ee562 100644 --- a/blackjax/adaptation/unadjusted_alba.py +++ b/blackjax/adaptation/unadjusted_alba.py @@ -17,6 +17,7 @@ from jax.flatten_util import ravel_pytree from blackjax.diagnostics import effective_sample_size from blackjax.adaptation.unadjusted_step_size import robnik_step_size_tuning, RobnikStepSizeTuningState +import math class AlbaAdaptationState(NamedTuple): ss_state: RobnikStepSizeTuningState # step size @@ -205,7 +206,7 @@ def unadjusted_alba( preconditioning: bool = True, is_mass_matrix_diagonal: bool = True, progress_bar: bool = False, - adaptation_info_fn: Callable = return_all_adapt_info, + adaptation_info_fn: Callable = lambda x, y, z : None, integrator=mcmc.integrators.velocity_verlet, num_alba_steps: int = 500, alba_factor: float = 0.4, @@ -274,33 +275,50 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): ### ### ALBA TUNING ### - keys = jax.random.split(alba_key, num_alba_steps) + + jax.debug.print("num_alba_steps: {x}", x=num_alba_steps) + + max_num_steps = 200 + thinning_rate = math.ceil(num_alba_steps / max_num_steps) + jax.debug.print("thinning_rate: {x}", x=thinning_rate) + new_num_alba_steps = math.ceil(num_alba_steps / thinning_rate) + jax.debug.print("new_num_alba_steps: {x}", x=new_num_alba_steps) + + keys = jax.random.split(alba_key, new_num_alba_steps) mcmc_kernel = algorithm.build_kernel(integrator) + def step(state, key): - next_state, _ = mcmc_kernel( - rng_key=key, - state=state, - logdensity_fn=logdensity_fn, - L=L, - step_size=step_size, - inverse_mass_matrix=inverse_mass_matrix, - ) + next_state=state + for i in range(thinning_rate): + key = jax.random.fold_in(key, i) + next_state, _ = mcmc_kernel( + rng_key=key, + state=state, + logdensity_fn=logdensity_fn, + L=L, + step_size=step_size, + inverse_mass_matrix=inverse_mass_matrix, + ) return next_state, next_state.position - if num_alba_steps > 0: - print("params before alba tuning", L, step_size) + if new_num_alba_steps > 0: + jax.debug.print("params before alba tuning {x}", x=(L, step_size)) _, samples = jax.lax.scan(step, last_chain_state, keys) flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) ess = effective_sample_size(flat_samples[None, ...]) + jax.debug.print("ess after alba {x}", x=(L, step_size, ess)) # print(num_alba_steps/jnp.mean(ess), "ess (blackjax internal)\n") # print( effective_sample_size(flat_samples[None, ...]), "ess (blackjax internal)\n") # print("L etc", L, step_size, jnp.mean(ess), num_alba_steps, jnp.mean(num_alba_steps / ess)) - L=alba_factor * step_size * jnp.mean(num_alba_steps / ess) + L=alba_factor * step_size * jnp.mean(new_num_alba_steps / ess) # print("new L", L) # raise Exception("stop") + # else: + # ess = 0.0 + max_num_steps = 500 # jax.debug.print("L: {x}", x=step_size*50.) diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index dff277738..b51a4d551 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin".""" -from typing import Callable, Union +from typing import Callable, NamedTuple, Union import jax import jax.numpy as jnp @@ -20,13 +20,22 @@ import blackjax.mcmc.integrators as integrators from blackjax.base import SamplingAlgorithm from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence -from blackjax.mcmc.hmc import HMCInfo +# from blackjax.mcmc.hmc import HMCInfo from blackjax.mcmc.proposal import static_binomial_sampling from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_unit_vector __all__ = ["init", "build_kernel", "as_top_level_api"] +class MAMSInfo(NamedTuple): + + # momentum: ArrayTree + acceptance_rate: float + is_accepted: bool + is_divergent: bool + energy: float + # proposal: integrators.IntegratorState + num_integration_steps: int def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) @@ -67,12 +76,22 @@ def kernel( step_size: float, integration_steps_fn, inverse_mass_matrix=1.0, - ) -> tuple[DynamicHMCState, HMCInfo]: + ) -> tuple[DynamicHMCState, MAMSInfo]: """Generate a new sample with the MHMCHMC kernel.""" + # return state, HMCInfo( + # state.position, + # 0.0, + # False, + # False, + # state.logdensity, + # None, + # 0, + # ) num_integration_steps = integration_steps_fn(state.random_generator_arg) + key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( @@ -217,7 +236,7 @@ def build_trajectory(state, num_integration_steps, rng_key): def generate( rng_key, state: integrators.IntegratorState - ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + ) -> tuple[integrators.IntegratorState, MAMSInfo, ArrayTree]: """Generate a new chain state.""" end_state, kinetic_energy, rng_key = build_trajectory( state, num_integration_steps, rng_key @@ -230,13 +249,13 @@ def generate( sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) do_accept, p_accept, other_proposal_info = info - info = HMCInfo( - state.momentum, + info = MAMSInfo( + # state.momentum, p_accept, do_accept, is_diverging, new_energy, - end_state, + # end_state, num_integration_steps, ) diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index fdbcaa0f8..13987b98f 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -346,7 +346,7 @@ def update( """ del is_last_call - logdensity_grad = logdensity_grad + # logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) flatten_grads = flatten_grads * sqrt_inverse_mass_matrix flatten_momentum, _ = ravel_pytree(momentum) diff --git a/blackjax/mcmc/malt.py b/blackjax/mcmc/malt.py index 57854ee26..395655bce 100644 --- a/blackjax/mcmc/malt.py +++ b/blackjax/mcmc/malt.py @@ -23,16 +23,26 @@ from blackjax.mcmc.proposal import safe_energy_diff, static_binomial_sampling from blackjax.mcmc.trajectory import hmc_energy from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey -from blackjax.mcmc.hmc import HMCState, HMCInfo +from blackjax.mcmc.hmc import HMCState __all__ = [ "HMCState", - "HMCInfo", "init", "build_kernel", "as_top_level_api", ] +class MALTInfo(NamedTuple): + + # momentum: ArrayTree + acceptance_rate: float + is_accepted: bool + is_divergent: bool + energy: float + # proposal: integrators.IntegratorState + num_integration_steps: int + + def init(position: ArrayLikeTree, logdensity_fn: Callable): @@ -71,7 +81,7 @@ def kernel( inverse_mass_matrix: metrics.MetricTypes, num_integration_steps: int, - ) -> tuple[HMCState, HMCInfo]: + ) -> tuple[HMCState, MALTInfo]: """Generate a new sample with the HMC kernel.""" L = num_integration_steps * step_size @@ -253,7 +263,7 @@ def hmc_proposal( def generate( rng_key, state: integrators.IntegratorState - ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + ) -> tuple[integrators.IntegratorState, MALTInfo, ArrayTree]: """Generate a new chain state.""" end_state, delta_energy = build_trajectory(state, step_size, num_integration_steps) end_state = flip_momentum(end_state) @@ -270,13 +280,13 @@ def generate( # jax.debug.print("delta_energy {x}",x=delta_energy) # jax.debug.print("p_accept {p_accept}", p_accept=(p_accept, delta_energy)) - info = HMCInfo( - state.momentum, + info = MALTInfo( + # state.momentum, p_accept, do_accept, is_diverging, new_energy, - end_state, + # end_state, num_integration_steps, ) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index 9ad52c001..0b4730786 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -99,6 +99,17 @@ def kernel( integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) ) + # return 0, None + + # state, MCLMCInfo( + + # logdensity=state.logdensity, + # energy_change=0.0, + # kinetic_change=0.0, + # nonans=True + # ) + # raise Exception + # jax.debug.print("L {x}", x=L) # jax.debug.print("step_size {x}", x=step_size) diff --git a/blackjax/mcmc/pseudofermion.py b/blackjax/mcmc/pseudofermion.py index 77f44ca4a..1880b6ec4 100644 --- a/blackjax/mcmc/pseudofermion.py +++ b/blackjax/mcmc/pseudofermion.py @@ -13,139 +13,68 @@ import blackjax.mcmc.trajectory as trajectory from blackjax.base import SamplingAlgorithm from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey - - -class GibbsState(NamedTuple): - - # pos : Any - # aux: ArrayTree - position: ArrayTree - logdensity: float - logdensity_grad: ArrayTree - momentum: ArrayTree - temporary_state : Any - fermion_matrix : Any - count : int - -def build_kernel(): - return () - -def init(position, logdensity_fn, fermion_matrix, temporary_state, init_main, rng_key ): - # state = hmc.state - # state.boson_state = position - # fermion_matrix = hmc.theory.get_fermion_matrix(hmc.state) - # temporary_state = hmc.theory.sample_temporary_state(position,hmc.state,fermion_matrix) - position, momentum, logdensity, logdensity_grad = init_main(position, logdensity_fn(fermion_matrix, temporary_state), rng_key ) - return GibbsState( - position=position, - logdensity=logdensity, - logdensity_grad=logdensity_grad, - momentum=momentum, - temporary_state=temporary_state, - fermion_matrix=fermion_matrix, - count=0, - ) - -def as_top_level_api( - kernel_main, - init_main, - logdensity_fn: Callable, - # step_size: float, - # inverse_mass_matrix: metrics.MetricTypes, - *, - max_num_doublings: int = 10, - divergence_threshold: int = 1000, - # integrator: Callable = integrators.velocity_verlet, - get_fermion_matrix_fn: Callable = None, - sample_temporary_state_fn: Callable = None, - # num_integration_steps: int = 1, - # alg1, -) -> SamplingAlgorithm: - """Implements the (basic) user interface for the nuts kernel. - - Examples - -------- - - A new NUTS kernel can be initialized and used with the following code: - - .. code:: - - nuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix) - state = nuts.init(position) - new_state, info = nuts.step(rng_key, state) - - We can JIT-compile the step function for more speed: - - .. code:: - - step = jax.jit(nuts.step) - new_state, info = step(rng_key, state) - - You can always use the base kernel should you need to: - - .. code:: - - import blackjax.mcmc.integrators as integrators - - kernel = blackjax.nuts.build_kernel(integrators.yoshida) - state = blackjax.nuts.init(position, logdensity_fn) - state, info = kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix) - - Parameters - ---------- - logdensity_fn - The log-density function we wish to draw samples from. - step_size - The value to use for the step size in the symplectic integrator. - inverse_mass_matrix - The value to use for the inverse mass matrix when drawing a value for - the momentum and computing the kinetic energy. - max_num_doublings - The maximum number of times we double the length of the trajectory before - returning if no U-turn has been obserbed or no divergence has occured. - divergence_threshold - The absolute value of the difference in energy between two states above - which we say that the transition is divergent. The default value is - commonly found in other libraries, and yet is arbitrary. - integrator - (algorithm parameter) The symplectic integrator to use to integrate the trajectory. - - Returns - ------- - A ``SamplingAlgorithm``. - - """ - # kernel = build_kernel(integrator, divergence_threshold) - - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - fermion_matrix = get_fermion_matrix_fn(position) - temporary_state = sample_temporary_state_fn(position,fermion_matrix) - return init(position, logdensity_fn, fermion_matrix, temporary_state, init_main, rng_key) - - def step_fn(rng_key: PRNGKey, state): - next_state, info = kernel_main( - rng_key, - state, - logdensity_fn(state.fermion_matrix, state.temporary_state), - # step_size, - # inverse_mass_matrix, - # max_num_doublings, - # num_integration_steps=num_integration_steps, - ) - new_fermion_matrix = get_fermion_matrix_fn(next_state.position) - new_temporary_state = sample_temporary_state_fn(next_state.position, new_fermion_matrix) - full_state = GibbsState( - position=next_state.position, - momentum=None, - # momentum=next_state.momentum, - logdensity=next_state.logdensity, - logdensity_grad=next_state.logdensity_grad, - temporary_state=new_temporary_state, - fermion_matrix=new_fermion_matrix, - count=state.count + info.num_integration_steps, +import blackjax + + +def build_kernel(kernel_1, kernel_2, logdensity_fn): + def kernel(state, rng_key): + next_x_state, info = kernel_1( + rng_key=rng_key, + state=state['x'], + step_size=1e-3, + inverse_mass_matrix=jnp.ones(state['x'].position.shape[0]), + num_integration_steps=1, + logdensity_fn=partial(logdensity_fn, pseudofermion=state['y']), ) - # jax.debug.print("count {x}", x=full_state.count) - return full_state, info + next_y_state = kernel_2(state['y']) + return {'x': next_x_state, 'y': next_y_state}, info + return kernel + +def init(position, logdensity_fn, pseudofermion, init_1, init_2, rng_key ): + key_1, key_2 = jax.random.split(rng_key, 2) + state_b = init_1(position, partial(logdensity_fn, pseudofermion=pseudofermion) ) + state_pf = init_2(pseudofermion) + + return {'x': state_b, 'y': state_pf} + + +# def as_top_level_api( +# kernel_1, +# kernel_2, +# init_1, +# init_2, +# logdensity_fn: Callable, +# ) -> SamplingAlgorithm: + +# # kernel = build_kernel(integrator, divergence_threshold) + +# def init_fn(position: ArrayLikeTree, pseudofermion, rng_key=None): +# del rng_key +# return init(position, logdensity_fn, pseudofermion, init_1, init_2, rng_key) + +# def step_fn(rng_key: PRNGKey, state): +# next_state, info = kernel_1( +# rng_key, +# state, +# logdensity_fn(state.pseudofermion), +# # step_size, +# # inverse_mass_matrix, +# # max_num_doublings, +# # num_integration_steps=num_integration_steps, +# ) +# new_pseudofermion = kernel_2(next_state.position, state.pseudofermion) +# # new_fermion_matrix = get_fermion_matrix_fn(next_state.position) +# full_state = GibbsState( +# position=next_state.position, +# momentum=None, +# # momentum=next_state.momentum, +# logdensity=next_state.logdensity, +# logdensity_grad=next_state.logdensity_grad, +# pseudofermion=new_pseudofermion, +# ) +# # jax.debug.print("count {x}", x=full_state.count) +# return full_state, info + +# return SamplingAlgorithm(init_fn, step_fn) + - return SamplingAlgorithm(init_fn, step_fn) From bffe235a244db88d296c1b60209b7a844dc31e17 Mon Sep 17 00:00:00 2001 From: Reuben Harry Cohn-Gordon Date: Sat, 4 Oct 2025 08:50:57 -0700 Subject: [PATCH 63/63] fix --- blackjax/mcmc/pseudofermion.py | 45 ++++++++-------------------------- 1 file changed, 10 insertions(+), 35 deletions(-) diff --git a/blackjax/mcmc/pseudofermion.py b/blackjax/mcmc/pseudofermion.py index 1880b6ec4..2bae3f6f4 100644 --- a/blackjax/mcmc/pseudofermion.py +++ b/blackjax/mcmc/pseudofermion.py @@ -38,43 +38,18 @@ def init(position, logdensity_fn, pseudofermion, init_1, init_2, rng_key ): return {'x': state_b, 'y': state_pf} -# def as_top_level_api( -# kernel_1, -# kernel_2, -# init_1, -# init_2, -# logdensity_fn: Callable, -# ) -> SamplingAlgorithm: +def as_top_level_api( + kernel_1, + kernel_2, + init_1, + init_2, + logdensity_fn: Callable, +) -> SamplingAlgorithm: -# # kernel = build_kernel(integrator, divergence_threshold) + # kernel = build_kernel(integrator, divergence_threshold) -# def init_fn(position: ArrayLikeTree, pseudofermion, rng_key=None): -# del rng_key -# return init(position, logdensity_fn, pseudofermion, init_1, init_2, rng_key) + return None -# def step_fn(rng_key: PRNGKey, state): -# next_state, info = kernel_1( -# rng_key, -# state, -# logdensity_fn(state.pseudofermion), -# # step_size, -# # inverse_mass_matrix, -# # max_num_doublings, -# # num_integration_steps=num_integration_steps, -# ) -# new_pseudofermion = kernel_2(next_state.position, state.pseudofermion) -# # new_fermion_matrix = get_fermion_matrix_fn(next_state.position) -# full_state = GibbsState( -# position=next_state.position, -# momentum=None, -# # momentum=next_state.momentum, -# logdensity=next_state.logdensity, -# logdensity_grad=next_state.logdensity_grad, -# pseudofermion=new_pseudofermion, -# ) -# # jax.debug.print("count {x}", x=full_state.count) -# return full_state, info - -# return SamplingAlgorithm(init_fn, step_fn) + return SamplingAlgorithm(init_fn, step_fn)