diff --git a/blackjax/__init__.py b/blackjax/__init__.py index ef5eabd79..1a078cc91 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -3,25 +3,32 @@ from blackjax._version import __version__ -from .adaptation.adjusted_mclmc_adaptation import adjusted_mclmc_find_L_and_step_size from .adaptation.chees_adaptation import chees_adaptation -from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size from .adaptation.meads_adaptation import meads_adaptation from .adaptation.pathfinder_adaptation import pathfinder_adaptation from .adaptation.window_adaptation import window_adaptation +from .adaptation.unadjusted_alba import unadjusted_alba +from .adaptation.unadjusted_step_size import robnik_step_size_tuning +from .adaptation.adjusted_alba import adjusted_alba +from .adaptation.las import las from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat -from .mcmc import adjusted_mclmc as _adjusted_mclmc from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc +from .mcmc import dynamic_malt as _dynamic_malt from .mcmc import elliptical_slice as _elliptical_slice from .mcmc import ghmc as _ghmc from .mcmc import hmc as _hmc +from .mcmc import uhmc as _uhmc +from .mcmc import malt as _malt from .mcmc import mala as _mala +from .mcmc import pseudofermion as _pseudofermion from .mcmc import marginal_latent_gaussian from .mcmc import mclmc as _mclmc +from .mcmc import mchmc as _mchmc +from .mcmc import underdamped_langevin as _langevin from .mcmc import nuts as _nuts from .mcmc import periodic_orbital, random_walk from .mcmc import rmhmc as _rmhmc @@ -96,12 +103,15 @@ def generate_top_level_api_from(module): # MCMC hmc = generate_top_level_api_from(_hmc) +uhmc = generate_top_level_api_from(_uhmc) +malt = generate_top_level_api_from(_malt) nuts = 
generate_top_level_api_from(_nuts) rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh) irmh = GenerateSamplingAPI( irmh_as_top_level_api, random_walk.init, random_walk.build_irmh ) dynamic_hmc = generate_top_level_api_from(_dynamic_hmc) +dynamic_malt = generate_top_level_api_from(_dynamic_malt) rmhmc = generate_top_level_api_from(_rmhmc) mala = generate_top_level_api_from(_mala) mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian) @@ -114,12 +124,14 @@ def generate_top_level_api_from(module): additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk) mclmc = generate_top_level_api_from(_mclmc) +mchmc = generate_top_level_api_from(_mchmc) +langevin = generate_top_level_api_from(_langevin) adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic) -adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) +# adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc) elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) - +pseudofermion = generate_top_level_api_from(_pseudofermion) hmc_family = [hmc, nuts] # SMC @@ -165,8 +177,10 @@ def generate_top_level_api_from(module): "meads_adaptation", "chees_adaptation", "pathfinder_adaptation", - "mclmc_find_L_and_step_size", # mclmc adaptation - "adjusted_mclmc_find_L_and_step_size", # adjusted mclmc adaptation "ess", # diagnostics "rhat", + "unadjusted_alba", + "robnik_step_size_tuning", + "adjusted_alba", + "las", ] diff --git a/blackjax/adaptation/__init__.py b/blackjax/adaptation/__init__.py index 53d5fe2b6..bdade7fff 100644 --- a/blackjax/adaptation/__init__.py +++ b/blackjax/adaptation/__init__.py @@ -4,6 +4,8 @@ meads_adaptation, pathfinder_adaptation, window_adaptation, + unadjusted_alba, + unadjusted_step_size, ) __all__ = [ @@ -12,4 +14,6 @@ "window_adaptation", "pathfinder_adaptation", 
# [diff residue, kept verbatim] "mclmc_adaptation", + "unadjusted_alba", + "robnik_step_size_tuning", ] diff --git a/blackjax/adaptation/adjusted_alba.py b/blackjax/adaptation/adjusted_alba.py new file mode 100644 index 000000000..a04da52a6 --- /dev/null +++ b/blackjax/adaptation/adjusted_alba.py @@ -0,0 +1,156 @@
"""Two-stage warmup: ``unadjusted_alba`` followed by dual-averaging step-size
tuning of an adjusted kernel."""
from blackjax.adaptation.step_size import (
    dual_averaging_adaptation,
)
from blackjax.mcmc.adjusted_mclmc_dynamic import rescale
from blackjax.base import AdaptationAlgorithm
from blackjax.types import ArrayLikeTree, PRNGKey
import jax
import jax.numpy as jnp
from typing import Callable
import blackjax
from blackjax.adaptation.unadjusted_alba import unadjusted_alba


def make_random_trajectory_length_fn(random_trajectory_length: bool):
    """Return a factory of "number of integration steps" samplers.

    The returned callable takes the average number of integration steps and
    returns a function ``key -> int32`` giving the number of steps to use for
    one transition.

    Parameters
    ----------
    random_trajectory_length
        If true, the number of steps is
        ``ceil(uniform(key) * rescale(avg_num_integration_steps))``, with a
        result of 0 replaced by 1. If false, the key is ignored and
        ``ceil(avg_num_integration_steps)`` is always returned.
    """
    if random_trajectory_length:

        def integration_steps_fn(avg_num_integration_steps):
            def sample_num_steps(key):
                # Evaluate the draw once; the same key always yields the same
                # value, so this matches evaluating the expression twice.
                num_steps = jnp.ceil(
                    jax.random.uniform(key) * rescale(avg_num_integration_steps)
                )
                # A draw that rounds to zero would give an empty trajectory.
                return jnp.where(num_steps == 0, 1, num_steps).astype("int32")

            return sample_num_steps

    else:

        def integration_steps_fn(avg_num_integration_steps):
            def fixed_num_steps(_):
                return jnp.ceil(avg_num_integration_steps).astype("int32")

            return fixed_num_steps

    return integration_steps_fn


def da_adaptation(
    algorithm,
    logdensity_fn: Callable,
    integration_steps_fn: Callable,
    inverse_mass_matrix,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    initial_L: float = 1.0,
    integrator=blackjax.mcmc.integrators.velocity_verlet,
    L_proposal_factor=jnp.inf,
):
    """Tune the step size of an adjusted kernel by dual averaging.

    The trajectory length ``L`` and the inverse mass matrix are held fixed;
    only the step size is adapted, driven by each transition's acceptance rate.

    Parameters
    ----------
    algorithm
        Object exposing ``build_kernel(integrator=..., L_proposal_factor=...)``
        and ``init(position=..., logdensity_fn=..., random_generator_arg=...)``.
    logdensity_fn
        Log-density of the target, forwarded to ``algorithm.init`` and to the
        kernel.
    integration_steps_fn
        Factory called with ``L / step_size``; its result is handed to the
        kernel as ``integration_steps_fn`` (see
        ``make_random_trajectory_length_fn``).
    inverse_mass_matrix
        Forwarded unchanged to the kernel and returned in the parameters.
    initial_step_size
        Starting point of the dual-averaging state.
    target_acceptance_rate
        Target passed to ``dual_averaging_adaptation``.
    initial_L
        Trajectory length carried, unchanged, through the adaptation.
    integrator, L_proposal_factor
        Forwarded to ``algorithm.build_kernel``.

    Returns
    -------
    An ``AdaptationAlgorithm`` whose ``run(rng_key, position, num_steps)``
    returns ``(kernel_state, parameters, info)``. ``parameters`` is a dict with
    keys ``"step_size"``, ``"inverse_mass_matrix"`` and ``"L"``. ``info`` is
    the stacked per-step scan output, which is ``None`` because each step
    emits ``None``.
    """
    da_init, da_update, da_final = dual_averaging_adaptation(target_acceptance_rate)
    kernel = algorithm.build_kernel(
        integrator=integrator, L_proposal_factor=L_proposal_factor
    )

    def step(state, key):
        (adaptation_state, kernel_state), L = state
        step_size = jnp.exp(adaptation_state.log_step_size)
        new_kernel_state, info = kernel(
            rng_key=key,
            state=kernel_state,
            logdensity_fn=logdensity_fn,
            step_size=step_size,
            inverse_mass_matrix=inverse_mass_matrix,
            integration_steps_fn=integration_steps_fn(L / step_size),
        )
        new_adaptation_state = da_update(adaptation_state, info.acceptance_rate)
        return ((new_adaptation_state, new_kernel_state), L), None

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
        init_key, rng_key = jax.random.split(rng_key)
        init_kernel_state = algorithm.init(
            position=position,
            logdensity_fn=logdensity_fn,
            random_generator_arg=init_key,
        )

        keys = jax.random.split(rng_key, num_steps)
        init_state = da_init(initial_step_size), init_kernel_state
        ((adaptation_state, kernel_state), L), info = jax.lax.scan(
            step,
            (init_state, initial_L),
            keys,
        )
        step_size = da_final(adaptation_state)
        return (
            kernel_state,
            {
                "step_size": step_size,
                "inverse_mass_matrix": inverse_mass_matrix,
                "L": L,
            },
            info,
        )

    return AdaptationAlgorithm(run)


def adjusted_alba(
    unadjusted_algorithm,
    logdensity_fn: Callable,
    target_eevpd,
    v,
    adjusted_algorithm,
    integrator,
    target_acceptance_rate: float = 0.80,
    num_alba_steps: int = 500,
    alba_factor: float = 0.4,
    preconditioning: bool = True,
    L_proposal_factor=jnp.inf,
    num_unadjusted_steps: int = 20000,
    **extra_parameters,
):
    """Warm up an adjusted sampler by chaining ``unadjusted_alba`` and
    ``da_adaptation``.

    The unadjusted warmup provides a state together with ``L``, ``step_size``
    and ``inverse_mass_matrix``. These seed a dual-averaging step-size
    adaptation of ``adjusted_algorithm`` with randomized trajectory lengths.

    Parameters
    ----------
    unadjusted_algorithm, target_eevpd, v, num_alba_steps, alba_factor, preconditioning
        Forwarded to ``unadjusted_alba``.
    logdensity_fn, integrator
        Forwarded to both warmup stages.
    adjusted_algorithm, target_acceptance_rate, L_proposal_factor
        Forwarded to ``da_adaptation``.
    num_unadjusted_steps
        Number of steps given to the unadjusted warmup (previously a
        hard-coded 20000).
    **extra_parameters
        Forwarded to both ``unadjusted_alba`` and ``da_adaptation``.

    Returns
    -------
    An ``AdaptationAlgorithm`` whose ``run(rng_key, position, num_steps)``
    returns the ``(state, parameters, info)`` triple of ``da_adaptation``.
    ``num_steps`` is the length of the adjusted stage only.
    """
    unadjusted_warmup = unadjusted_alba(
        algorithm=unadjusted_algorithm,
        logdensity_fn=logdensity_fn,
        target_eevpd=target_eevpd,
        v=v,
        integrator=integrator,
        num_alba_steps=num_alba_steps,
        alba_factor=alba_factor,
        preconditioning=preconditioning,
        **extra_parameters,
    )

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
        unadjusted_warmup_key, adjusted_warmup_key = jax.random.split(rng_key)

        (state, params), adaptation_info = unadjusted_warmup.run(
            unadjusted_warmup_key, position, num_unadjusted_steps
        )

        integration_steps_fn = make_random_trajectory_length_fn(
            random_trajectory_length=True
        )

        # NOTE(review): da_adaptation declares no **kwargs, so a non-empty
        # extra_parameters raises TypeError here. Confirm whether the extras
        # are meant for the unadjusted stage only.
        adjusted_warmup = da_adaptation(
            algorithm=adjusted_algorithm,
            logdensity_fn=logdensity_fn,
            integration_steps_fn=integration_steps_fn,
            initial_L=params["L"],
            initial_step_size=params["step_size"],
            target_acceptance_rate=target_acceptance_rate,
            inverse_mass_matrix=params["inverse_mass_matrix"],
            integrator=integrator,
            L_proposal_factor=L_proposal_factor,
            **extra_parameters,
        )

        state, params, adaptation_info = adjusted_warmup.run(
            adjusted_warmup_key, state.position, num_steps
        )
        return state, params, adaptation_info

    return AdaptationAlgorithm(run)


# [diff residue, kept verbatim] \ No newline at end of file diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 408c31383..474bfc5cc 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -28,6 +28,7 @@ def adjusted_mclmc_find_L_and_step_size( max="avg", num_windows=1, tuning_factor=1.3, + euclidean=False, ): """ Finds the optimal value of the parameters for the MH-MCHMC algorithm.
@@ -73,7 +74,14 @@ def adjusted_mclmc_find_L_and_step_size( dim = pytree_size(state.position) if params is None: - params = MCLMCAdaptationState( + if euclidean: + + params = MCLMCAdaptationState( + 1.0, 0.2, inverse_mass_matrix=jnp.ones((dim,)) + ) + + else: + params = MCLMCAdaptationState( jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,)) ) @@ -96,6 +104,7 @@ def adjusted_mclmc_find_L_and_step_size( diagonal_preconditioning=diagonal_preconditioning, max=max, tuning_factor=tuning_factor, + euclidean=euclidean, )( state, params, num_steps, window_key ) @@ -113,7 +122,7 @@ def adjusted_mclmc_find_L_and_step_size( ) = adjusted_mclmc_make_adaptation_L( mclmc_kernel, frac=frac_tune3, - Lfactor=0.5, + Lfactor=0.3, max=max, eigenvector=eigenvector, )( @@ -156,6 +165,7 @@ def adjusted_mclmc_make_L_step_size_adaptation( fix_L_first_da=False, max="avg", tuning_factor=1.0, + euclidean=False, ): """Adapts the stepsize and L of the MCLMC kernel. Designed for adjusted MCLMC""" @@ -207,6 +217,7 @@ def step(iteration_state, weight_and_key): step_size = jax.lax.clamp( 1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1 ) + # jax.debug.print("step size in adaptation {x}",x=step_size) adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size)) x = ravel_pytree(state.position)[0] @@ -256,7 +267,10 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): num_steps * frac_tune2 ) - check_key, rng_key = jax.random.split(rng_key, 2) + # jax.debug.print("num steps1 {x}",x=num_steps1) + # jax.debug.print("num steps 2 {x}",x=num_steps2) + + # check_key, rng_key = jax.random.split(rng_key, 2) rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2) L_step_size_adaptation_keys_pass1 = jax.random.split( @@ -293,24 +307,49 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): variances = x_squared_average - jnp.square(x_average) if max == "max": - contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor + if euclidean: 
+ contract = lambda x: (jnp.sqrt(jnp.max(x) * dim) * tuning_factor) / jnp.sqrt(dim) + + else: + contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor elif max == "avg": - contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor + print("avg") + if euclidean: + + contract = lambda x: (jnp.sqrt(jnp.sum(x)) * tuning_factor) / jnp.sqrt(dim) + else: + contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor else: raise ValueError("max should be either 'max' or 'avg'") + new_L = params.L + change = jax.lax.clamp( Lratio_lowerbound, - contract(variances) / params.L, + contract(variances) / new_L, Lratio_upperbound, ) + # if euclidean: + # # new_L /= jnp.sqrt(dim) + # change /= jnp.sqrt(dim) + + params = params._replace( L=params.L * change, step_size=params.step_size * change ) if diagonal_preconditioning: - params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) + if euclidean: + params = params._replace(inverse_mass_matrix=variances, L=1.) + else: + params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim)) + + # else: + # if euclidean: + # params = params._replace(L = params.L / jnp.sqrt(dim)) + + # jax.debug.print("params L {x}", x=(params.L, contract(variances), jnp.sum(variances), tuning_factor)) initial_da, update_da, final_da = dual_averaging_adaptation(target=target) ( @@ -330,6 +369,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): params = params._replace(step_size=final_da(dual_avg_state)) + return state, params, eigenvector, num_tuning_integrator_steps return L_step_size_adaptation diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py new file mode 100644 index 000000000..eb15cf908 --- /dev/null +++ b/blackjax/adaptation/ensemble_mclmc.py @@ -0,0 +1,319 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ensemble sampling in two phases: an unadjusted burn-in (``ensemble_umclmc``)
followed by an adjusted refinement with step-size bisection (``laps``)."""

from typing import Any, NamedTuple

import jax
import jax.numpy as jnp

import blackjax.adaptation.ensemble_umclmc as umclmc
from blackjax.adaptation.ensemble_umclmc import (
    equipartition_diagonal,
    equipartition_diagonal_loss,
)
from blackjax.adaptation.step_size import bisection_monotonic_fn
from blackjax.mcmc.adjusted_mclmc import build_kernel as build_kernel_malt
from blackjax.mcmc.hmc import HMCState
from blackjax.mcmc.integrators import (
    generate_isokinetic_integrator,
    mclachlan_coefficients,
    omelyan_coefficients,
)
from blackjax.util import run_eca


class AdaptationState(NamedTuple):
    """Hyperparameters carried between iterations of the adjusted phase."""

    steps_per_sample: float
    step_size: float
    stepsize_adaptation_state: (
        Any  # the state of the bisection algorithm to find a stepsize
    )
    iteration: int


def build_kernel(logdensity_fn, integrator, inverse_mass_matrix):
    """Build the adjusted kernel used in the second phase.

    Returns a function ``(key, state, adap) -> kernel output`` that reads
    ``adap.step_size`` and ``adap.steps_per_sample`` on every call. The
    underlying kernel is built once, with ``L_proposal_factor=1.25``.
    """
    adjusted_kernel = build_kernel_malt(
        integrator=integrator,
        L_proposal_factor=1.25,
    )

    def kernel(key, state, adap):
        return adjusted_kernel(
            rng_key=key,
            state=state,
            logdensity_fn=logdensity_fn,
            step_size=adap.step_size,
            # Fixed trajectory length: the key is ignored.
            integration_steps_fn=lambda k: adap.steps_per_sample,
            inverse_mass_matrix=inverse_mass_matrix,
        )

    return kernel


class Adaptation:
    """Step-size adaptation for the adjusted phase, driven by the
    ensemble-averaged acceptance probability."""

    def __init__(
        self,
        adaptation_state,
        num_adaptation_samples,  # amount of tuning in the adjusted phase before fixing params
        steps_per_sample=15,  # L/eps
        acc_prob_target=0.8,
        observables=lambda x: 0.0,  # just for diagnostics: some function of a given chain at given timestep
        observables_for_bias=lambda x: 0.0,  # just for diagnostics: the above, but averaged over all chains
        contract=lambda x: 0.0,  # just for diagnostics: observables for bias, contracted over dimensions
    ):
        self.num_adaptation_samples = num_adaptation_samples
        self.observables = observables
        self.observables_for_bias = observables_for_bias
        self.contract = contract

        # The step size found by the unadjusted phase is the starting point.
        step_size = adaptation_state.step_size

        # Initialize the bisection for finding the step size
        self.epsadap_update = bisection_monotonic_fn(acc_prob_target)
        stepsize_adaptation_state = (jnp.array([-jnp.inf, jnp.inf]), False)

        self.initial_state = AdaptationState(
            steps_per_sample, step_size, stepsize_adaptation_state, 0
        )

    def summary_statistics_fn(self, state, info, rng_key):
        """Per-chain quantities for the caller to aggregate; ``rng_key`` is unused."""
        return {
            "acceptance_probability": info.acceptance_rate,
            "equipartition_diagonal": equipartition_diagonal(state),
            "observables": self.observables(state.position),
            "observables_for_bias": self.observables_for_bias(state.position),
        }

    def update(self, adaptation_state, Etheta):
        """Advance the adaptation by one iteration.

        ``Etheta`` holds the aggregated values of the entries produced by
        ``summary_statistics_fn``. Returns the new ``AdaptationState`` and a
        dict of diagnostics describing the hyperparameters used in this
        iteration.
        """
        acc_prob = Etheta["acceptance_probability"]
        equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"])
        true_bias = self.contract(Etheta["observables_for_bias"])

        info_to_be_stored = {
            "L": adaptation_state.step_size * adaptation_state.steps_per_sample,
            "steps_per_sample": adaptation_state.steps_per_sample,
            "step_size": adaptation_state.step_size,
            "acc_prob": acc_prob,
            "equi_diag": equi_diag,
            "bias": true_bias,
            "observables": Etheta["observables"],
        }

        # Bisection to find step size
        # NOTE(review): self.num_adaptation_samples is stored but never
        # consulted here, so the step size keeps being updated on every
        # iteration. Confirm whether it should be frozen once
        # adaptation_state.iteration reaches it.
        stepsize_adaptation_state, step_size = self.epsadap_update(
            adaptation_state.stepsize_adaptation_state,
            adaptation_state.step_size,
            acc_prob,
        )

        return (
            AdaptationState(
                adaptation_state.steps_per_sample,
                step_size,
                stepsize_adaptation_state,
                adaptation_state.iteration + 1,
            ),
            info_to_be_stored,
        )


def bias(model):
    """Build ``(observables, contract)`` diagnostics from a benchmark model.

    ``model`` must provide ``transform``, ``E_x2`` and ``Var_x2``.
    ``contract`` returns the maximum and the average of the squared error of
    the second moments, normalized by ``model.Var_x2``.

    Should be transferred to benchmarks/.
    """

    def observables(position):
        return jnp.square(model.transform(position))

    def contract(sampler_E_x2):
        bsq = jnp.square(sampler_E_x2 - model.E_x2) / model.Var_x2
        return jnp.array([jnp.max(bsq), jnp.average(bsq)])

    return observables, contract


def while_steps_num(cond):
    """Return ``len(cond)`` if every entry is true, otherwise the index of the
    first false entry plus one.

    Uses a Python ``if`` on the array, so ``cond`` must be concrete (not a
    traced value).
    """
    if jnp.all(cond):
        return len(cond)
    return jnp.argmin(cond) + 1


def laps(
    logdensity_fn,
    sample_init,
    ndims,
    num_steps1,
    num_steps2,
    num_chains,
    mesh,
    rng_key,
    alpha=1.9,
    save_frac=0.2,
    C=0.1,
    early_stop=True,
    r_end=0.01,
    bias_type=3,
    diagonal_preconditioning=True,
    integrator_coefficients=None,
    steps_per_sample=15,
    acc_prob=None,
    observables_for_bias=lambda x: x,
    ensemble_observables=None,
    diagnostics=True,
    contract=lambda x: 0.0,
    superchain_size=1,
):
    """Run the two-phase ensemble sampler.

    Phase 1 burns in with the unadjusted kernel from ``ensemble_umclmc``.
    Phase 2 refines with an adjusted kernel whose step size is tuned by
    bisection.

    Parameters
    ----------
    logdensity_fn
        Log-density of the target.
    sample_init
        Function ``key -> position`` used to initialize each chain.
    ndims
        Dimensionality of the target.
    num_steps1
        Number of steps in the first phase.
    num_steps2
        Budget for the second phase. The number of samples is
        ``num_steps2 // (gradient_calls_per_step * steps_per_sample)``.
    num_chains
        Number of chains.
    mesh
        The mesh object, used for distributing the computation across cpus
        and nodes.
    rng_key
        The random key.
    alpha
        L = sqrt{d}*alpha*variances.
    save_frac
        The fraction of samples used to estimate the fluctuation in the first
        phase.
    C
        Constant in stage 1 that determines step size (eq (9) of EMAUS paper).
    early_stop
        Whether to stop the first phase early.
    r_end
        Threshold on the first-phase ``r_max`` diagnostic, forwarded to
        ``umclmc.Adaptation``.
    bias_type
        Forwarded to ``umclmc.Adaptation``, where it selects which bias
        estimate drives the step size.
    diagonal_preconditioning
        Whether to use diagonal preconditioning.
    integrator_coefficients
        The coefficients of the integrator. If ``None``, Omelyan coefficients
        are used when ``ndims > 200`` and McLachlan coefficients otherwise.
    steps_per_sample
        The number of steps per sample.
    acc_prob
        The target acceptance probability. If ``None`` it defaults to 0.9,
        except 0.7 when the integrator is auto-selected and ``ndims <= 200``.
    observables_for_bias, contract
        Diagnostics forwarded to both adaptation objects.
    ensemble_observables
        Observable calculated over the ensemble (for diagnostic use).
    diagnostics
        Whether to return diagnostics.
    superchain_size
        Forwarded to ``umclmc.initialize`` and ``run_eca``.

    Returns
    -------
    ``(info, gradient_calls_per_step, acc_prob, final_state)``. ``info`` is
    ``{"phase_1": ..., "phase_2": ...}`` when ``diagnostics`` is true and
    ``None`` otherwise. ``acc_prob`` is the target actually used.
    """
    key_init, key_umclmc, key_mclmc = jax.random.split(rng_key, 3)

    # initialize the chains
    initial_state = umclmc.initialize(
        key_init, logdensity_fn, sample_init, num_chains, mesh, superchain_size
    )

    # burn-in with the unadjusted method
    kernel = umclmc.build_kernel(logdensity_fn)
    save_num = (jnp.rint(save_frac * num_steps1)).astype(int)
    adap = umclmc.Adaptation(
        ndims,
        alpha=alpha,
        bias_type=bias_type,
        save_num=save_num,
        C=C,
        r_end=r_end,
        observables_for_bias=observables_for_bias,
        contract=contract,
    )

    final_state, final_adaptation_state, info1, steps_done_phase_1 = run_eca(
        key_umclmc,
        initial_state,
        kernel,
        adap,
        num_steps1,
        num_chains,
        mesh,
        superchain_size,
        ensemble_observables,
        early_stop=early_stop,
    )

    # refine the results with the adjusted method
    _acc_prob = acc_prob
    if integrator_coefficients is None:
        high_dims = ndims > 200
        _integrator_coefficients = (
            omelyan_coefficients if high_dims else mclachlan_coefficients
        )
        if acc_prob is None:
            _acc_prob = 0.9 if high_dims else 0.7
    else:
        _integrator_coefficients = integrator_coefficients
        if acc_prob is None:
            _acc_prob = 0.9

    integrator = generate_isokinetic_integrator(_integrator_coefficients)
    # scheme = BABAB..AB scheme has len(scheme)//2 + 1 Bs. The last doesn't
    # count because that gradient can be reused in the next step.
    gradient_calls_per_step = len(_integrator_coefficients) // 2

    if diagonal_preconditioning:
        inverse_mass_matrix = final_adaptation_state.inverse_mass_matrix

        # scale the stepsize so that it reflects the average scale change of
        # the preconditioning
        average_scale_change = jnp.sqrt(jnp.average(inverse_mass_matrix))
        final_adaptation_state = final_adaptation_state._replace(
            step_size=final_adaptation_state.step_size / average_scale_change
        )
    else:
        inverse_mass_matrix = 1.0

    kernel = build_kernel(
        logdensity_fn, integrator, inverse_mass_matrix=inverse_mass_matrix
    )

    initial_state = HMCState(
        final_state.position, final_state.logdensity, final_state.logdensity_grad
    )
    num_samples = num_steps2 // (gradient_calls_per_step * steps_per_sample)
    # number of samples after which the stepsize is fixed.
    num_adaptation_samples = num_samples // 2

    final_adaptation_state = final_adaptation_state._replace(
        step_size=final_adaptation_state.step_size.item()
    )

    adap = Adaptation(
        final_adaptation_state,
        num_adaptation_samples,
        steps_per_sample,
        _acc_prob,
        contract=contract,
        observables_for_bias=observables_for_bias,
    )

    final_state, final_adaptation_state, info2, _ = run_eca(
        key_mclmc,
        initial_state,
        kernel,
        adap,
        num_samples,
        num_chains,
        mesh,
        superchain_size,
        ensemble_observables,
    )

    if diagnostics:
        info = {"phase_1": info1, "phase_2": info2}
    else:
        info = None

    return info, gradient_calls_per_step, _acc_prob, final_state


# [diff residue, kept verbatim] diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py new file mode 100644 index 000000000..46b3c8346 --- /dev/null +++ b/blackjax/adaptation/ensemble_umclmc.py @@ -0,0 +1,328 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unadjusted ensemble burn-in: kernel with NaN rejection, chain
initialization and hyperparameter adaptation."""

from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

from blackjax.mcmc import mclmc
from blackjax.mcmc.integrators import (
    IntegratorState,
    _normalized_flatten_array,
    isokinetic_velocity_verlet,
)
from blackjax.types import Array
from blackjax.util import ensemble_execute_fn


def no_nans(a):
    """Return a scalar boolean: true iff every leaf value of ``a`` is finite."""
    flat_a, _ = ravel_pytree(a)
    return jnp.all(jnp.isfinite(flat_a))


def nan_reject(nonans, old, new):
    """Equivalent to
    return new if nonans else old"""

    return jax.lax.cond(nonans, lambda _: new, lambda _: old, operand=None)


def build_kernel(logdensity_fn):
    """MCLMC kernel (with nan rejection)

    The returned function ``(key, state, adap)`` reads ``adap.L`` and
    ``adap.step_size``. It always passes a vector of ones shaped like
    ``adap.inverse_mass_matrix`` as the inverse mass matrix.
    """
    # Build the underlying kernel once instead of on every step.
    kernel = mclmc.build_kernel(
        integrator=isokinetic_velocity_verlet,
    )

    def sequential_kernel(key, state, adap):
        new_state, info = kernel(
            key,
            state,
            logdensity_fn,
            adap.L,
            adap.step_size,
            jnp.ones(adap.inverse_mass_matrix.shape),
        )

        # reject the new state if there were nans
        nonans = no_nans(new_state)
        new_state = nan_reject(nonans, state, new_state)

        # Mask with `where`, not multiplication: nan * 0 is still nan, so a
        # multiplicative mask would leak non-finite values into the averages.
        return new_state, {
            "nans": 1 - nonans,
            "energy_change": jnp.where(nonans, info.energy_change, 0.0),
            "logdensity": jnp.where(nonans, info.logdensity, 0.0),
        }

    return sequential_kernel


def initialize(rng_key, logdensity_fn, sample_init, num_chains, mesh, superchain_size):
    """initialize the chains based on the equipartition of the initial condition.
    We initialize the velocity along grad log p if E_ii > 1 and along -grad log p if E_ii < 1.
    """

    def sequential_init(key, x, args):
        """initialize the position using sample_init and the velocity along the gradient"""
        position = sample_init(key)

        logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
        flat_g, unravel_fn = ravel_pytree(logdensity_grad)
        velocity = unravel_fn(
            _normalized_flatten_array(flat_g)[0]
        )  # = grad logp/ |grad logp|

        return IntegratorState(position, velocity, logdensity, logdensity_grad), None

    def summary_statistics_fn(state):
        """compute the diagonal elements of the equipartition matrix"""
        flat_pos, _ = ravel_pytree(state.position)
        flat_g, unravel_fn = ravel_pytree(state.logdensity_grad)
        return unravel_fn(-flat_pos * flat_g)

    def ensemble_init(key, state, signs):
        """flip the velocity, depending on the equipartition condition"""

        momentum, unflatten = ravel_pytree(state.momentum)

        velocity_flat = jax.tree_util.tree_map(
            lambda sign, u: sign * u, signs, momentum
        )

        velocity = unflatten(velocity_flat)

        return (
            IntegratorState(
                state.position, velocity, state.logdensity, state.logdensity_grad
            ),
            None,
        )

    key1, key2 = jax.random.split(rng_key)
    initial_state, equipartition = ensemble_execute_fn(
        sequential_init,
        key1,
        num_chains,
        mesh,
        summary_statistics_fn=summary_statistics_fn,
        superchain_size=superchain_size,
    )

    flat_equi, _ = ravel_pytree(equipartition)

    # -1 where the equipartition value is below 1, +1 otherwise.
    signs = -2.0 * (flat_equi < 1.0) + 1.0
    initial_state, _ = ensemble_execute_fn(
        ensemble_init,
        key2,
        num_chains,
        mesh,
        x=initial_state,
        args=signs,
        superchain_size=superchain_size,
    )

    return initial_state


def update_history(new_vals, history):
    """Push the flattened ``new_vals`` to the front of the 2-D ``history``,
    dropping its last row."""
    new_vals, _ = ravel_pytree(new_vals)
    return jnp.concatenate((new_vals[None, :], history[:-1, :]))


def update_history_scalar(new_val, history):
    """Push the scalar ``new_val`` to the front of the 1-D ``history``,
    dropping its last entry."""
    return jnp.concatenate((new_val * jnp.ones(1), history[:-1]))


def contract_history(theta, weights):
    """Relative fluctuation of ``theta`` along axis 0.

    Computes ``r = (<theta^2> - <theta>^2) / <theta>^2`` with weighted
    averages over axis 0 and returns ``[max(r), average(r)]``.
    """
    square_average = jnp.square(jnp.average(theta, weights=weights, axis=0))
    average_square = jnp.average(jnp.square(theta), weights=weights, axis=0)

    r = (average_square - square_average) / square_average

    return jnp.array([jnp.max(r), jnp.average(r)])


class History(NamedTuple):
    """Rolling buffers kept by the adaptation (newest entry first)."""

    observables: Array
    stopping: Array
    weights: Array


class AdaptationState(NamedTuple):
    """Hyperparameters and bookkeeping of the unadjusted phase."""

    L: float
    inverse_mass_matrix: Any
    step_size: float

    step_count: int
    EEVPD: float
    EEVPD_wanted: float
    history: Any


def equipartition_diagonal(state):
    """Ei = E_ensemble (- grad log p_i x_i ). Ei is 1 if we have converged.
    equipartition_loss = average over parameters (Ei)"""
    return jax.tree_util.tree_map(
        lambda x, g: -x * g, state.position, state.logdensity_grad
    )


def equipartition_fullrank(state, rng_key):
    """loss = Tr[(1 - E)^T (1 - E)] / d^2
    where Eij is the equipartition matrix.
    Loss is computed with the Hutchinson's trick."""

    x, _ = ravel_pytree(state.position)
    g, _ = ravel_pytree(state.logdensity_grad)
    d = len(x)

    def func(z):
        """z here has the same shape as position"""
        return z + jnp.dot(z, g) * x

    z = jax.random.rademacher(rng_key, (100, d))  # = delta_ij
    return jax.vmap(func)(z)


def equipartition_diagonal_loss(Eii):
    """Mean squared deviation of the flattened ``Eii`` from 1."""
    Eii_flat, _ = ravel_pytree(Eii)
    return jnp.average(jnp.square(1.0 - Eii_flat))


def equipartition_fullrank_loss(delta_z):
    """Mean square of ``delta_z`` divided by the size of its last axis."""
    d = delta_z.shape[-1]
    return jnp.average(jnp.square(delta_z)) / d


class Adaptation:
    """Hyperparameter adaptation for the unadjusted phase.

    On every iteration ``L`` and the inverse mass matrix are set from the
    ensemble variances. The step size is rescaled so that the measured energy
    variance per dimension (``EEVPD``) moves towards a target derived from a
    bias estimate.
    """

    def __init__(
        self,
        ndims,
        alpha=1.0,
        C=0.1,
        r_end=0.01,
        bias_type=0,
        save_num=10,
        observables=lambda x: 0.0,
        observables_for_bias=lambda x: x,
        contract=lambda x: 0.0,
    ):
        self.ndims = ndims
        self.alpha = alpha
        self.C = C
        self.r_end = r_end
        self.observables = observables
        self.observables_for_bias = observables_for_bias
        self.contract = contract
        self.bias_type = bias_type
        self.save_num = save_num

        history = History(
            observables=jnp.zeros((save_num, ndims)),
            stopping=jnp.full((save_num,), jnp.nan),
            weights=jnp.zeros(save_num),
        )

        self.initial_state = AdaptationState(
            L=jnp.inf,  # do not add noise for the first step
            inverse_mass_matrix=jnp.ones(ndims),
            step_size=0.01 * jnp.sqrt(ndims),
            step_count=0,
            EEVPD=1e-3,
            EEVPD_wanted=1e-3,
            history=history,
        )

    # info 1
    def summary_statistics_fn(self, state, info, rng_key):
        """Per-chain quantities for the caller to aggregate.

        ``info`` is the dict produced by the kernel from ``build_kernel``.
        """
        position_flat, _ = ravel_pytree(state.position)

        # Gradient component orthogonal to the momentum.
        # NOTE(review): written for array-valued states and ndims > 1.
        grad_orthogonal = (
            state.logdensity_grad
            - jnp.dot(state.logdensity_grad, state.momentum) * state.momentum
        )

        return {
            "equipartition_diagonal": equipartition_diagonal(state),
            "equipartition_fullrank": equipartition_fullrank(state, rng_key),
            "x": position_flat,
            "xsq": jnp.square(position_flat),
            "E": info["energy_change"],
            "Esq": jnp.square(info["energy_change"]),
            "rejection_rate_nans": info["nans"],
            "observables_for_bias": self.observables_for_bias(state.position),
            "observables": self.observables(state.position),
            "entropy": -info["logdensity"],
            "uturn": jnp.sqrt(jnp.sum(jnp.square(grad_orthogonal)))
            / (self.ndims - 1),
        }

    def update(self, adaptation_state, Etheta):
        """Advance the adaptation by one iteration.

        ``Etheta`` holds the aggregated values of the entries produced by
        ``summary_statistics_fn``. Returns the new ``AdaptationState`` and a
        dict of diagnostics.
        """
        # combine the expectation values to get useful scalars
        equi_diag = equipartition_diagonal_loss(Etheta["equipartition_diagonal"])
        equi_full = equipartition_fullrank_loss(Etheta["equipartition_fullrank"])

        history_observables = update_history(
            Etheta["observables_for_bias"], adaptation_state.history.observables
        )

        history_weights = update_history_scalar(1.0, adaptation_state.history.weights)
        fluctuations = contract_history(history_observables, history_weights)
        # Record the fluctuation only once more steps than the buffer length
        # have been taken; until then store nan.
        history_stopping = update_history_scalar(
            jax.lax.cond(
                adaptation_state.step_count > len(history_weights),
                lambda _: fluctuations[0],
                lambda _: jnp.nan,
                operand=None,
            ),
            adaptation_state.history.stopping,
        )
        history = History(history_observables, history_stopping, history_weights)

        # average over the ensemble, sum over parameters (to get sqrt(d))
        variances = Etheta["xsq"] - jnp.square(Etheta["x"])
        L = self.alpha * jnp.sqrt(jnp.sum(variances))
        inverse_mass_matrix = variances
        EEVPD = (Etheta["Esq"] - jnp.square(Etheta["E"])) / self.ndims
        true_bias = self.contract(Etheta["observables_for_bias"])
        nans = Etheta["rejection_rate_nans"] > 0.0

        # hyperparameter adaptation
        # estimate bias: r_max, r_avg, equi_full, equi_diag
        bias = jnp.array([fluctuations[0], fluctuations[1], equi_full, equi_diag])[
            self.bias_type
        ]
        EEVPD_wanted = self.C * jnp.power(bias, 3.0 / 8.0)

        eps_factor = jnp.power(EEVPD_wanted / EEVPD, 1.0 / 6.0)
        eps_factor = jnp.clip(eps_factor, 0.3, 3.0)

        # reduce the stepsize if there were nans
        eps_factor = nan_reject(1 - nans, 0.5, eps_factor)

        info_to_be_stored = {
            "L": adaptation_state.L,
            "step_size": adaptation_state.step_size,
            "EEVPD_wanted": EEVPD_wanted,
            "EEVPD": EEVPD,
            "equi_diag": equi_diag,
            "equi_full": equi_full,
            "bias": true_bias,
            "r_max": fluctuations[0],
            "r_avg": fluctuations[1],
            "entropy": Etheta["entropy"],
            "observables": Etheta["observables"],
        }

        adaptation_state_new = AdaptationState(
            L,
            inverse_mass_matrix,
            adaptation_state.step_size * eps_factor,
            adaptation_state.step_count + 1,
            EEVPD,
            EEVPD_wanted,
            history,
        )

        return adaptation_state_new, info_to_be_stored

    def while_cond(self, info):
        """determine if we want to switch to adjustment"""
        return info['r_max'] > self.r_end


# [diff residue, kept verbatim; las.py continues past this block] \ No newline at end of file diff --git a/blackjax/adaptation/las.py b/blackjax/adaptation/las.py new file mode 100644 index 000000000..f4c975411 --- /dev/null +++ b/blackjax/adaptation/las.py @@ -0,0 +1,152 @@ +import jax +import jax.numpy as jnp +import blackjax +from blackjax.util import run_inference_algorithm +import blackjax + +from blackjax.adaptation.unadjusted_alba import unadjusted_alba +from blackjax.adaptation.unadjusted_step_size import robnik_step_size_tuning +from blackjax.adaptation.unadjusted_alba import unadjusted_alba +import math +from blackjax.mcmc.adjusted_mclmc_dynamic import make_random_trajectory_length_fn +from functools import partial +from blackjax.adaptation.step_size import bisection_monotonic_fn + +# unbelievable that this is not in the standard library +def compose(f, g): + return lambda x: f(g(x)) + + +def las(logdensity_fn, num_chains, key, ndims, num_adjusted_steps, diagonal_preconditioning=True, target_acceptance_rate=0.8): + + init_key, tune_key, unadjusted_key, adjusted_key =
jax.random.split(key, 4) + initial_position = jax.random.normal(init_key, (ndims,)) + + ### Phase 1: unadjusted ### + + integrator = blackjax.mcmc.integrators.isokinetic_mclachlan + + # burn-in and adaptation + num_alba_steps = 10000 + warmup = unadjusted_alba( + algorithm=blackjax.mclmc, + logdensity_fn=logdensity_fn, integrator=integrator, + target_eevpd=5e-4, + # target_acceptance_rate=target_acceptance_rate, + v=jnp.sqrt(ndims), + num_alba_steps=num_alba_steps, + preconditioning=diagonal_preconditioning, + alba_factor=0.4, + ) + + (blackjax_state_after_tuning, blackjax_mclmc_sampler_params), adaptation_info = warmup.run(tune_key, initial_position, 20000) + + # sampling + ess_per_sample = blackjax_mclmc_sampler_params['ESS'] + + num_steps = math.ceil(num_chains // ess_per_sample) + + alg = blackjax.mclmc( + logdensity_fn=logdensity_fn, + L=blackjax_mclmc_sampler_params['L'], + step_size=blackjax_mclmc_sampler_params['step_size'], + inverse_mass_matrix=blackjax_mclmc_sampler_params['inverse_mass_matrix'], + integrator=integrator, + ) + + final_output, history = run_inference_algorithm( + rng_key=unadjusted_key, + initial_state=blackjax_state_after_tuning, + inference_algorithm=alg, + num_steps=num_steps, + transform=(lambda a, b: a), + progress_bar=False, + ) + samples = history.position + + + ### Phase 2: adjusted ### + + subsamples = samples[::math.ceil(1/ess_per_sample)] + + integration_steps_fn = make_random_trajectory_length_fn(True) + + + # initial_states = jax.lax.map(lambda x: blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), xs=subsamples) + # initial_states = jax.lax.map(lambda x: blackjax.adjusted_mclmc_dynamic.init(x, logdensity_fn, jax.random.key(0)), xs=subsamples) + + def make_mams_step(key): + def mams_step(inp): + # init_key, run_key = jax.random.split(key, 2) + + step_size, positions, info, step_size_adaptation_state = inp + keys = jax.random.split(key, positions.shape[0]) + # num_steps_per_traj = 
blackjax_mclmc_sampler_params['L'] / step_size + num_steps_per_traj = 1 + + alg = blackjax.adjusted_mclmc_dynamic( + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=integration_steps_fn(num_steps_per_traj), + integrator=blackjax.mcmc.integrators.isokinetic_velocity_verlet, + inverse_mass_matrix=blackjax_mclmc_sampler_params['inverse_mass_matrix'], + L_proposal_factor=jnp.inf, + ) + + # run_keys = jax.random.split(run_key, positions.shape[0]) + + def step_fn(pos_key): + pos, key = pos_key + init_key, run_key = jax.random.split(key, 2) + return alg.step( + rng_key=run_key, + state=blackjax.adjusted_mclmc_dynamic.init(pos, logdensity_fn, init_key), + ) + + new_states, infos = jax.lax.map(step_fn, xs=(positions,keys)) + # jax.debug.print("infos adaptation step: {infos}", infos=jnp.sum(infos.is_accepted)) + return (step_size, new_states.position, infos, step_size_adaptation_state) + + return mams_step + + epsadap_update = bisection_monotonic_fn(target_acceptance_rate) + step_size_adaptation_state_initial = (jnp.array([-jnp.inf, jnp.inf]), False) + + def tuning_step(inp): + + old_step_size, old_positions, old_infos, step_size_adaptation_state = inp + acc_rate = old_infos.acceptance_rate.mean() + + + step_size_adaptation_state, new_step_size = epsadap_update( + step_size_adaptation_state, + old_step_size, + acc_rate, + ) + + return (new_step_size, old_positions, old_infos, step_size_adaptation_state) + # return (10.0, old_positions, old_infos, step_size_adaptation_state) + + step = lambda key: compose(tuning_step, make_mams_step(key)) + + initial_adjusted_key, adjusted_key = jax.random.split(adjusted_key, 2) + _, _, infos, _ = make_mams_step(initial_adjusted_key)((blackjax_mclmc_sampler_params['step_size'], subsamples, None, step_size_adaptation_state_initial)) + + + positions = subsamples + step_size = blackjax_mclmc_sampler_params['step_size'] + + (step_size, position, infos, step_size_adaptation_state), (step_sizes, positions, infos, 
step_size_adaptation_state) = jax.lax.scan(lambda state, key: (step(key)(state), step(key)(state)), (step_size, subsamples, infos, step_size_adaptation_state_initial), jax.random.split(adjusted_key, num_adjusted_steps)) + + return samples, positions, infos, num_steps, step_size_adaptation_state + +# type: forall a, b: (a -> b) -> (b -> a) -> Int -> (a -> b) +# e.g.: a ~ (stepsize, position), b ~ (state) +def feedback(f,g, n, state_a): + for i in range(n): + state_b = f(state_a) + # print(state_b, "state_b") + state_a = g(state_b) + # print(state_a, "state_a") + return state_a + diff --git a/blackjax/adaptation/mass_matrix.py b/blackjax/adaptation/mass_matrix.py index dc0730161..bb74fa2ee 100644 --- a/blackjax/adaptation/mass_matrix.py +++ b/blackjax/adaptation/mass_matrix.py @@ -132,7 +132,7 @@ def final(mm_state: MassMatrixAdaptationState) -> MassMatrixAdaptationState: """Final iteration of the mass matrix adaptation. In this step we compute the mass matrix from the covariance matrix computed - by the Welford algorithm, and re-initialize the later. + by the Welford algorithm, and re-initialize the latter. """ _, wc_state = mm_state diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index fa644898a..e5c1c900f 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -21,6 +21,7 @@ from blackjax.diagnostics import effective_sample_size from blackjax.util import generate_unit_vector, incremental_value_update, pytree_size +import blackjax.mcmc.metrics as metrics class MCLMCAdaptationState(NamedTuple): @@ -50,8 +51,10 @@ def mclmc_find_L_and_step_size( desired_energy_var=5e-4, trust_in_estimate=1.5, num_effective_samples=150, - diagonal_preconditioning=True, params=None, + diagonal_preconditioning=True, + num_windows=1, + euclidean=False ): """ Finds the optimal value of the parameters for the MCLMC algorithm. 
@@ -72,7 +75,7 @@ def mclmc_find_L_and_step_size( The fraction of tuning for the second step of the adaptation. frac_tune3 The fraction of tuning for the third step of the adaptation. - desired_energy_va + desired_energy_var The desired energy variance for the MCMC algorithm. trust_in_estimate The trust in the estimate of optimal stepsize. @@ -82,6 +85,7 @@ def mclmc_find_L_and_step_size( Whether to do diagonal preconditioning (i.e. a mass matrix) params Initial params to start tuning from (optional) + euclidean: if this tuning is used for HMC or underdamped LMC, there are sqrt{d} factors that need to be taken into account (because L is parametrized differently) Returns ------- @@ -109,9 +113,18 @@ def mclmc_find_L_and_step_size( """ dim = pytree_size(state.position) if params is None: - params = MCLMCAdaptationState( - jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) - ) + if euclidean: + params = MCLMCAdaptationState( + 1, 0.25, inverse_mass_matrix=jnp.ones((dim,)) + ) + else: + params = MCLMCAdaptationState( + jnp.sqrt(dim), jnp.sqrt(dim) * 0.25, inverse_mass_matrix=jnp.ones((dim,)) + ) + + + # jax.debug.print("params {x}", x=(params, euclidean)) + part1_key, part2_key = jax.random.split(rng_key, 2) total_num_tuning_integrator_steps = 0 @@ -122,16 +135,21 @@ def mclmc_find_L_and_step_size( num_steps2 += diagonal_preconditioning * (num_steps2 // 3) num_steps3 = round(num_steps * frac_tune3) - state, params = make_L_step_size_adaptation( - kernel=mclmc_kernel, - dim=dim, - frac_tune1=frac_tune1, - frac_tune2=frac_tune2, - desired_energy_var=desired_energy_var, - trust_in_estimate=trust_in_estimate, - num_effective_samples=num_effective_samples, - diagonal_preconditioning=diagonal_preconditioning, - )(state, params, num_steps, part1_key) + for i in range(num_windows): + window_key = jax.random.fold_in(part1_key, i) + + state, params = make_L_step_size_adaptation( + kernel=mclmc_kernel, + dim=dim, + frac_tune1=frac_tune1/num_windows, + 
frac_tune2=frac_tune2/num_windows, + desired_energy_var=desired_energy_var, + trust_in_estimate=trust_in_estimate, + num_effective_samples=num_effective_samples, + diagonal_preconditioning=diagonal_preconditioning, + euclidean=euclidean + )(state, params, num_steps, window_key) + total_num_tuning_integrator_steps += num_steps1 + num_steps2 if num_steps3 >= 2: # at least 2 samples for ESS estimation @@ -152,6 +170,7 @@ def make_L_step_size_adaptation( desired_energy_var=1e-3, trust_in_estimate=1.5, num_effective_samples=150, + euclidean=False ): """Adapts the stepsize and L of the MCLMC kernel. Designed for unadjusted MCLMC""" @@ -163,7 +182,7 @@ def predictor(previous_state, params, adaptive_state, rng_key): time, x_average, step_size_max = adaptive_state - rng_key, nan_key = jax.random.split(rng_key) + rng_key, nan_key, energy_key = jax.random.split(rng_key, 3) # dynamics next_state, info = kernel(params.inverse_mass_matrix)( @@ -173,6 +192,26 @@ def predictor(previous_state, params, adaptive_state, rng_key): step_size=params.step_size, ) + ndims = pytree_size(previous_state.position) + cutoff = jnp.sqrt(ndims * desired_energy_var * 10**4) + energy_change_, next_state = handle_high_energy( + next_state=next_state, + previous_state=previous_state, + energy_change=info.energy_change, + key=energy_key, + inverse_mass_matrix=params.inverse_mass_matrix, + euclidean=euclidean, + cutoff=cutoff + ) + + info = info._replace( + energy_change=energy_change_, + ) + + # jax.debug.print("info.energy_change {x}", x=info.energy_change) + + + # step updating success, state, step_size_max, energy_change = handle_nans( previous_state, @@ -181,8 +220,15 @@ def predictor(previous_state, params, adaptive_state, rng_key): step_size_max, info.energy_change, nan_key, + euclidean=euclidean, ) + # energy_change = jnp.clip( + # energy_change, + # -25, + # 25 + # ) + # Warning: var = 0 if there were nans, but we will give it a very small weight xi = ( jnp.square(energy_change) / (dim * 
desired_energy_var) @@ -190,10 +236,15 @@ def predictor(previous_state, params, adaptive_state, rng_key): weight = jnp.exp( -0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate)) ) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. + # jax.debug.print("energy change {x}", x=energy_change) + # jax.debug.print("step size {x}", x=params.step_size) + x_average = decay_rate * x_average + weight * ( xi / jnp.power(params.step_size, 6.0) ) + # jax.debug.print("xi {x}", x=(xi, energy_change, x_average)) + time = decay_rate * time + weight step_size = jnp.power( x_average / time, -1.0 / 6.0 @@ -201,6 +252,12 @@ def predictor(previous_state, params, adaptive_state, rng_key): step_size = (step_size < step_size_max) * step_size + ( step_size > step_size_max ) * step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences + + # jax.debug.print("step size {x}", x=(step_size, step_size_max)) + + + + # params_new = params._replace(step_size=(energy_change==0)*params.step_size/2 + (energy_change != 0)*step_size) params_new = params._replace(step_size=step_size) adaptive_state = (time, x_average, step_size_max) @@ -255,22 +312,30 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2))) # run the steps - state, params, _, (_, average) = run_steps( + state, params, (_,x_average, _), (_, average) = run_steps( xs=(mask, L_step_size_adaptation_keys), state=state, params=params ) + # jax.debug.print("step size {x}", x=(params.step_size, x_average)) + L = params.L # determine L inverse_mass_matrix = params.inverse_mass_matrix if num_steps2 > 1: x_average, x_squared_average = average[0], average[1] variances = x_squared_average - jnp.square(x_average) - L = jnp.sqrt(jnp.sum(variances)) + L = jnp.sqrt(jnp.sum(variances)) # lmc: should be jnp.mean + if euclidean: + L /= jnp.sqrt(dim) + # jax.debug.print("foo {x}", x=(L, 
euclidean, variances)) if diagonal_preconditioning: + # jax.debug.print("bar") inverse_mass_matrix = variances params = params._replace(inverse_mass_matrix=inverse_mass_matrix) - L = jnp.sqrt(dim) + L = jnp.sqrt(dim) # lmc: 1 + if euclidean: + L /= jnp.sqrt(dim) # readjust the stepsize steps = round(num_steps2 / 3) # we do some small number of steps @@ -279,6 +344,8 @@ def L_step_size_adaptation(state, params, num_steps, rng_key): xs=(jnp.ones(steps), keys), state=state, params=params ) + # jax.debug.print("params {x}", x=(MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix), euclidean)) + return state, MCLMCAdaptationState(L, params.step_size, inverse_mass_matrix) return L_step_size_adaptation @@ -318,7 +385,7 @@ def step(state, key): def handle_nans( - previous_state, next_state, step_size, step_size_max, kinetic_change, key + previous_state, next_state, step_size, step_size_max, kinetic_change, key, euclidean=False ): """if there are nans, let's reduce the stepsize, and not update the state. 
The function returns the old state in this case.""" @@ -333,12 +400,42 @@ def handle_nans( (previous_state, step_size * reduced_step_size, 0.0), ) + # new_momentum = euclidean*0.0 + (1-euclidean)*generate_unit_vector(key, previous_state.position) + new_momentum = generate_unit_vector(key, previous_state.position) + + # new_momentum = euclidean*metric.sample_momentum(key, previous_state.position) + (1-euclidean)*generate_unit_vector(key, previous_state.position) + state = jax.lax.cond( jnp.isnan(next_state.logdensity), lambda: state._replace( - momentum=generate_unit_vector(key, previous_state.position) + # momentum=generate_unit_vector(key, previous_state.position) + momentum=new_momentum ), lambda: state, ) return nonans, state, step_size, kinetic_change + + +def handle_high_energy( + previous_state, next_state, energy_change, key, inverse_mass_matrix, cutoff, euclidean=False +): + + + metric = metrics.default_metric(inverse_mass_matrix) + + + new_momentum = euclidean*metric.sample_momentum(key, previous_state.position) + (1-euclidean)*generate_unit_vector(key, previous_state.position) + # new_momentum = generate_unit_vector(key, previous_state.position) + + state = jax.lax.cond( + jnp.abs(energy_change) > cutoff, + lambda: previous_state._replace( + # momentum=generate_unit_vector(key, next_state.position) + momentum=new_momentum + ), + lambda: next_state, + ) + energy_change = jnp.clip(energy_change, -cutoff, cutoff) + + return energy_change, state diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 2b06172c0..da036c266 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -257,3 +257,47 @@ def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState: rss_state = jax.lax.while_loop(do_continue, update, rss_state) return rss_state.step_size + + +def bisection_monotonic_fn(acc_prob_wanted, reduce_shift=jnp.log(2.0), tolerance=0.03): + """Bisection of a monotonically decreassing 
function, that doesn't require an initially bracketing interval.""" + + def update(state, exp_x, acc_rate_new): + bounds, terminated = state + + # update the bounds + acc_high = acc_rate_new > acc_prob_wanted + x = jnp.log(exp_x) + + def on_true(bounds): + lower, upper = bounds + lower = jnp.max(jnp.array([lower, x])) + return jnp.array([lower, upper]), lower + reduce_shift + + def on_false(bounds): + lower, upper = bounds + upper = jnp.min(jnp.array([upper, x])) + return jnp.array([lower, upper]), upper - reduce_shift + + bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds) + + # if we have already found a bracketing interval, do bisection, otherwise further reduce or increase the bounds + bracketing = jnp.all(jnp.isfinite(bounds_new)) + + def reduce(bounds): + return x_new + + def bisect(bounds): + return jnp.average(bounds) + + x_new = jax.lax.cond(bracketing, bisect, reduce, bounds_new) + + stepsize = terminated * exp_x + (1 - terminated) * jnp.exp(x_new) + + terminated_new = ( + jnp.abs(acc_rate_new - acc_prob_wanted) < tolerance + ) | terminated + + return (bounds_new, terminated_new), stepsize + + return update diff --git a/blackjax/adaptation/unadjusted_alba.py b/blackjax/adaptation/unadjusted_alba.py new file mode 100644 index 000000000..9164ee562 --- /dev/null +++ b/blackjax/adaptation/unadjusted_alba.py @@ -0,0 +1,346 @@ +import jax +import jax.numpy as jnp + +from typing import Callable, NamedTuple + +import blackjax.mcmc as mcmc +from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info +from blackjax.adaptation.mass_matrix import ( + MassMatrixAdaptationState, + mass_matrix_adaptation, +) +from blackjax.base import AdaptationAlgorithm +from blackjax.progress_bar import gen_scan_fn +from blackjax.types import Array, ArrayLikeTree, PRNGKey +from blackjax.util import pytree_size +from blackjax.adaptation.window_adaptation import build_schedule +from jax.flatten_util import ravel_pytree +from blackjax.diagnostics 
import effective_sample_size +from blackjax.adaptation.unadjusted_step_size import robnik_step_size_tuning, RobnikStepSizeTuningState +import math + +class AlbaAdaptationState(NamedTuple): + ss_state: RobnikStepSizeTuningState # step size + imm_state: MassMatrixAdaptationState # inverse mass matrix + step_size: float + inverse_mass_matrix: Array + L : float + +def base( + is_mass_matrix_diagonal: bool, + v, + target_eevpd, + preconditioning: bool = True, +) -> tuple[Callable, Callable, Callable]: + + mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal) + if not preconditioning: + + mm_update = lambda x, y: x + + mm_final = lambda x: x + + # step_size_init, step_size_update, step_size_final = dual_averaging_adaptation(target_eevpd) + step_size_init, step_size_update, step_size_final = robnik_step_size_tuning(desired_energy_var=target_eevpd) + + def init( + position: ArrayLikeTree, + ) -> AlbaAdaptationState: + + num_dimensions = pytree_size(position) + imm_state = mm_init(num_dimensions) + + ss_state = step_size_init(initial_step_size=jnp.sqrt(num_dimensions)/5, num_dimensions=num_dimensions) + + return AlbaAdaptationState( + ss_state, + imm_state, + ss_state.step_size, + imm_state.inverse_mass_matrix, + L = jnp.sqrt(num_dimensions)/v + ) + + def fast_update( + position: ArrayLikeTree, + info, + warmup_state: AlbaAdaptationState, + ) -> AlbaAdaptationState: + """Update the adaptation state when in a "fast" window. + + Only the step size is adapted in fast windows. 
"Fast" refers to the fact + that the optimization algorithms are relatively fast to converge + compared to the covariance estimation with Welford's algorithm + + """ + + del position + + + new_ss_state = step_size_update(warmup_state.ss_state, info) + new_step_size = new_ss_state.step_size # jnp.exp(new_ss_state.log_step_size) + + new_inverse_mass_matrix = jax.lax.cond( + preconditioning, + lambda: warmup_state.inverse_mass_matrix, + lambda: jnp.ones_like(warmup_state.inverse_mass_matrix), + ) + + return AlbaAdaptationState( + new_ss_state, + warmup_state.imm_state, + new_step_size, + new_inverse_mass_matrix, + L = warmup_state.L + ) + + def slow_update( + position: ArrayLikeTree, + info, + warmup_state: AlbaAdaptationState, + ) -> AlbaAdaptationState: + + # raise Exception + + new_imm_state = mm_update(warmup_state.imm_state, position) + # jax.debug.print("imm state {x}", x=new_imm_state.inverse_mass_matrix[:3]) + # jax.debug.print("warmup_state.ss_state: {x}", x=(warmup_state.ss_state.step_size)) + new_ss_state = step_size_update(warmup_state.ss_state, info) + # new_ss_state = warmup_state.ss_state + # jax.debug.print("old then new: {new_ss_state}", new_ss_state=(warmup_state.ss_state.step_size, new_ss_state.step_size)) + new_step_size = new_ss_state.step_size # jnp.exp(new_ss_state.log_step_size) + + new_inverse_mass_matrix = jax.lax.cond( + preconditioning, + lambda: warmup_state.inverse_mass_matrix, + lambda: jnp.ones_like(warmup_state.inverse_mass_matrix), + ) + + return AlbaAdaptationState( + new_ss_state, new_imm_state, new_step_size, new_inverse_mass_matrix, L = warmup_state.L + ) + + def slow_final(warmup_state: AlbaAdaptationState) -> AlbaAdaptationState: + + new_imm_state = mm_final(warmup_state.imm_state) + new_ss_state = warmup_state.ss_state + # ._replace(step_size=step_size_final(warmup_state.ss_state)) + # step_size_init(step_size_final(warmup_state.ss_state), warmup_state.ss_state.num_dimensions) + new_step_size = new_ss_state.step_size # 
jnp.exp(new_ss_state.log_step_size) + # jax.debug.print("new_ss_state: {new_ss_state}", new_ss_state=(new_ss_state.step_size)) + + new_L = jax.lax.cond( + preconditioning, + lambda: jnp.sqrt(warmup_state.ss_state.num_dimensions)/v, + lambda: jnp.sqrt(jnp.sum(new_imm_state.inverse_mass_matrix)), + ) + + # jax.debug.print("new_L: {x}", x=(warmup_state.L, new_L)) + + return AlbaAdaptationState( + new_ss_state, + new_imm_state, + new_step_size, + new_imm_state.inverse_mass_matrix, + L = new_L + ) + + def update( + adaptation_state: AlbaAdaptationState, + adaptation_stage: tuple, + position: ArrayLikeTree, + info, + ) -> AlbaAdaptationState: + """Update the adaptation state and parameter values. + + Parameters + ---------- + adaptation_state + Current adptation state. + adaptation_stage + The current stage of the warmup: whether this is a slow window, + a fast window and if we are at the last step of a slow window. + position + Current value of the model parameters. + value + Value of the acceptance rate for the last mcmc step. + + Returns + ------- + The updated adaptation state. 
+ + """ + stage, is_middle_window_end = adaptation_stage + + warmup_state = jax.lax.switch( + stage, + (fast_update, slow_update), + position, + info, + adaptation_state, + ) + + warmup_state = jax.lax.cond( + is_middle_window_end, + slow_final, + lambda x: x, + warmup_state, + ) + + return warmup_state + + def final(warmup_state: AlbaAdaptationState) -> tuple[float, Array]: + """Return the final values for the step size and mass matrix.""" + step_size = step_size_final(warmup_state.ss_state) + # step_size = jnp.exp(warmup_state.ss_state.log_step_size_avg) + inverse_mass_matrix = warmup_state.imm_state.inverse_mass_matrix + L = warmup_state.L + return step_size, L, inverse_mass_matrix + + return init, update, final + +def unadjusted_alba( + algorithm, + logdensity_fn: Callable, + target_eevpd, + v, + preconditioning: bool = True, + is_mass_matrix_diagonal: bool = True, + progress_bar: bool = False, + adaptation_info_fn: Callable = lambda x, y, z : None, + integrator=mcmc.integrators.velocity_verlet, + num_alba_steps: int = 500, + alba_factor: float = 0.4, + **extra_parameters, +) -> AdaptationAlgorithm: + + + mcmc_kernel = algorithm.build_kernel(integrator) + + adapt_init, adapt_step, adapt_final = base( + is_mass_matrix_diagonal=is_mass_matrix_diagonal, + target_eevpd=target_eevpd, + v=v, + preconditioning=preconditioning + ) + + def one_step(carry, xs): + _, rng_key, adaptation_stage = xs + state, adaptation_state = carry + + new_state, info = mcmc_kernel( + rng_key=rng_key, + state=state, + logdensity_fn=logdensity_fn, + step_size=adaptation_state.step_size, + inverse_mass_matrix=adaptation_state.inverse_mass_matrix, + L=adaptation_state.L, + **extra_parameters, + ) + new_adaptation_state = adapt_step( + adaptation_state, + adaptation_stage, + new_state.position, + info, + ) + # jax.debug.print("info: {x}", x=(new_adaptation_state.step_size, info.energy_change)) + # jax.debug.print("step sizes: {x}", x=(adaptation_state.step_size, 
new_adaptation_state.step_size)) + + return ( + (new_state, new_adaptation_state), + adaptation_info_fn(new_state, info, new_adaptation_state), + ) + + def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000): + init_key, rng_key, alba_key = jax.random.split(rng_key, 3) + init_state = algorithm.init(position=position, logdensity_fn=logdensity_fn, random_generator_arg=init_key) + init_adaptation_state = adapt_init(position) + + if progress_bar: + print("Running window adaptation") + scan_fn = gen_scan_fn(num_steps-num_alba_steps, progress_bar=progress_bar) + start_state = (init_state, init_adaptation_state) + keys = jax.random.split(rng_key, num_steps-num_alba_steps) + schedule = build_schedule(num_steps-num_alba_steps) + last_state, info = scan_fn( + one_step, + start_state, + (jnp.arange(num_steps-num_alba_steps), keys, schedule), + ) + + last_chain_state, last_warmup_state, *_ = last_state + step_size, L, inverse_mass_matrix = adapt_final(last_warmup_state) + + # jax.debug.print("unadjusted L before alba: {params}", params=(L, step_size)) + + ### + ### ALBA TUNING + ### + + jax.debug.print("num_alba_steps: {x}", x=num_alba_steps) + + max_num_steps = 200 + thinning_rate = math.ceil(num_alba_steps / max_num_steps) + jax.debug.print("thinning_rate: {x}", x=thinning_rate) + new_num_alba_steps = math.ceil(num_alba_steps / thinning_rate) + jax.debug.print("new_num_alba_steps: {x}", x=new_num_alba_steps) + + keys = jax.random.split(alba_key, new_num_alba_steps) + mcmc_kernel = algorithm.build_kernel(integrator) + + def step(state, key): + next_state=state + for i in range(thinning_rate): + key = jax.random.fold_in(key, i) + next_state, _ = mcmc_kernel( + rng_key=key, + state=state, + logdensity_fn=logdensity_fn, + L=L, + step_size=step_size, + inverse_mass_matrix=inverse_mass_matrix, + ) + + return next_state, next_state.position + + if new_num_alba_steps > 0: + jax.debug.print("params before alba tuning {x}", x=(L, step_size)) + _, samples = 
jax.lax.scan(step, last_chain_state, keys) + flat_samples = jax.vmap(lambda x: ravel_pytree(x)[0])(samples) + ess = effective_sample_size(flat_samples[None, ...]) + jax.debug.print("ess after alba {x}", x=(L, step_size, ess)) + # print(num_alba_steps/jnp.mean(ess), "ess (blackjax internal)\n") + # print( effective_sample_size(flat_samples[None, ...]), "ess (blackjax internal)\n") + + # print("L etc", L, step_size, jnp.mean(ess), num_alba_steps, jnp.mean(num_alba_steps / ess)) + L=alba_factor * step_size * jnp.mean(new_num_alba_steps / ess) + # print("new L", L) + # raise Exception("stop") + + # else: + # ess = 0.0 + + max_num_steps = 500 + + # jax.debug.print("L: {x}", x=step_size*50.) + parameters = { + "step_size": step_size, + "inverse_mass_matrix": inverse_mass_matrix, + "L": jnp.clip(L, max=step_size*max_num_steps), + "ESS": jnp.mean(ess)/num_alba_steps, + **extra_parameters, + } + + # jax.debug.print("parameters {x}", x=parameters) + + return ( + AdaptationResults( + last_chain_state, + parameters, + ), + info, + ) + + return AdaptationAlgorithm(run) + + + diff --git a/blackjax/adaptation/unadjusted_step_size.py b/blackjax/adaptation/unadjusted_step_size.py new file mode 100644 index 000000000..c0753345b --- /dev/null +++ b/blackjax/adaptation/unadjusted_step_size.py @@ -0,0 +1,74 @@ +import jax.numpy as jnp +from typing import NamedTuple +import jax + +class RobnikStepSizeTuningState(NamedTuple): + time : jnp.ndarray + step_size: float + x_average: float + step_size_max: float + num_dimensions: int + +def robnik_step_size_tuning(desired_energy_var, trust_in_estimate=1.5, num_effective_samples=150, step_size_max=jnp.inf, step_size_reduction_factor=0.8): + + decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0) + + def init(initial_step_size, num_dimensions): + return RobnikStepSizeTuningState(time=0.0, x_average=0.0, step_size=initial_step_size, step_size_max=step_size_max, num_dimensions=num_dimensions) + + def update(robnik_state, 
info): + + # jax.debug.print("robnik state: {x}", x=robnik_state) + # jax.debug.print("info: {x}", x=(info.energy_change, info.nonans)) + # raise Exception("Stop here") + + + + energy_change = info.energy_change + + + xi = ( + jnp.square(energy_change) / (robnik_state.num_dimensions * desired_energy_var) + ) + 1e-8 # 1e-8 is added to avoid divergences in log xi + weight = jnp.exp( + -0.5 * jnp.square(jnp.log(xi) / (6.0 * trust_in_estimate)) + ) # the weight reduces the impact of stepsizes which are much larger on much smaller than the desired one. + + x_average = decay_rate * robnik_state.x_average + weight * ( + xi / jnp.power(robnik_state.step_size, 6.0) + ) + + time = decay_rate * robnik_state.time + weight + step_size = jnp.power( + x_average / time, -1.0 / 6.0 + ) # We use the Var[E] = O(eps^6) relation here. + step_size = (step_size < robnik_state.step_size_max) * step_size + ( + step_size > robnik_state.step_size_max + ) * robnik_state.step_size_max # if the proposed stepsize is above the stepsize where we have seen divergences + # jax.debug.print("new step_size: {x}", x=(step_size)) + + # old_robnik_state = robnik_state + + # old_robnik_step_size = robnik_state.step_size + # jax.debug.print("step_size: {x}", x=(old_robnik_step_size, old_robnik_step_size * step_size_reduction_factor, step_size)) + # jax.debug.print("stuff {x}", x=(x_average, time, robnik_state.time, time)) + old_robnik_state = robnik_state + + + + robnik_state = jax.lax.cond( + info.nonans, + lambda: RobnikStepSizeTuningState(time=time, x_average=x_average, step_size=step_size, step_size_max=step_size_max, num_dimensions=robnik_state.num_dimensions), + lambda: robnik_state._replace(step_size=robnik_state.step_size * step_size_reduction_factor), + ) + # jax.debug.print("robnik_state: {robnik_state}", robnik_state=(robnik_state.step_size, info.nonans, robnik_state.step_size * step_size_reduction_factor)) + # jax.debug.print("robnik_state: {x}", x=(old_robnik_state.step_size, 
robnik_state.step_size)) + return robnik_state + + # return RobnikStepSizeTuningState(time=time, x_average=x_average, step_size=step_size, step_size_max=step_size_max, num_dimensions=robnik_state.num_dimensions) + + + def final(robnik_state): + return robnik_state.step_size + + return init, update, final diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py index 69a098325..18dd76cbb 100644 --- a/blackjax/adaptation/window_adaptation.py +++ b/blackjax/adaptation/window_adaptation.py @@ -45,6 +45,7 @@ class WindowAdaptationState(NamedTuple): def base( is_mass_matrix_diagonal: bool, target_acceptance_rate: float = 0.80, + preconditioning = True, ) -> tuple[Callable, Callable, Callable]: """Warmup scheme for sampling procedures based on euclidean manifold HMC. The schedule and algorithms used match Stan's :cite:p:`stan_hmc_param` as closely as possible. @@ -100,6 +101,9 @@ def base( """ mm_init, mm_update, mm_final = mass_matrix_adaptation(is_mass_matrix_diagonal) + if not preconditioning: + mm_update = lambda x, y: x + mm_final = lambda x: x da_init, da_update, da_final = dual_averaging_adaptation(target_acceptance_rate) def init( @@ -164,6 +168,7 @@ def slow_update( """ new_imm_state = mm_update(warmup_state.imm_state, position) + # new_imm_state = warmup_state.imm_state new_ss_state = da_update(warmup_state.ss_state, acceptance_rate) new_step_size = jnp.exp(new_ss_state.log_step_size) @@ -251,6 +256,7 @@ def window_adaptation( progress_bar: bool = False, adaptation_info_fn: Callable = return_all_adapt_info, integrator=mcmc.integrators.velocity_verlet, + preconditioning = True, **extra_parameters, ) -> AdaptationAlgorithm: """Adapt the value of the inverse mass matrix and step size parameters of @@ -301,6 +307,7 @@ def window_adaptation( adapt_init, adapt_step, adapt_final = base( is_mass_matrix_diagonal, target_acceptance_rate=target_acceptance_rate, + preconditioning=preconditioning, ) def one_step(carry, xs): diff --git 
a/blackjax/diagnostics.py b/blackjax/diagnostics.py index 257ce759c..fd184d048 100644 --- a/blackjax/diagnostics.py +++ b/blackjax/diagnostics.py @@ -207,3 +207,19 @@ def monotone_sequence_body_fn(rho_hat_sum_tm1, rho_hat_sum_t): ess = ess_raw / tau_hat return ess.squeeze() + + +def splitR(position, num_chains, superchain_size, func_for_splitR = jnp.square): + + # combine the chains in super-chains to compute expectation values + func_mk = jax.vmap(func_for_splitR)(position) # shape = (# chains, # func) + func_mk = func_mk.reshape(num_chains // superchain_size, superchain_size, func_mk.shape[-1]) #shape = (# superchains, # chains in superchain, # func) + func_k = jnp.average(func_mk, axis = 1) #shape = (# superchains, # func) + func_sq_k = jnp.average(jnp.square(func_mk), axis = 1) #shape = (# superchains, # func) + W_k = (func_sq_k - jnp.square(func_k)) * superchain_size / (superchain_size - 1) # variance withing k-th superchain + W = jnp.average(W_k, axis = 0) # average within superchain variance + B = jnp.var(func_k, axis = 0, ddof= 1) # between superchain variance + + R = jnp.sqrt(1. + (B/W)) # splitR, shape = (# func) + + return R \ No newline at end of file diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 8acb28274..47bc0837f 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -1,17 +1,22 @@ from . 
import ( adjusted_mclmc, adjusted_mclmc_dynamic, + malt, + dynamic_malt, barker, elliptical_slice, ghmc, hmc, + uhmc, mala, marginal_latent_gaussian, mclmc, + mchmc, nuts, periodic_orbital, random_walk, rmhmc, + underdamped_langevin ) __all__ = [ @@ -26,6 +31,11 @@ "marginal_latent_gaussian", "random_walk", "mclmc", + "mchmc", + "underdamped_langevin", + "uhmc", "adjusted_mclmc_dynamic", "adjusted_mclmc", + "dynamic_malt", + "malt", ] diff --git a/blackjax/mcmc/adjusted_mclmc.py b/blackjax/mcmc/adjusted_mclmc.py index f390402f2..747bedc4c 100644 --- a/blackjax/mcmc/adjusted_mclmc.py +++ b/blackjax/mcmc/adjusted_mclmc.py @@ -11,11 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. Hence "Hamiltonian" and not "Langevin". - -NOTE: For best performance, we recommend using adjusted_mclmc_dynamic instead of this module, which is primarily intended for use in parallelized versions of the algorithm. - -""" +"""Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. 
Hence "Hamiltonian" and not "Langevin".""" from typing import Callable, Union import jax @@ -23,9 +19,10 @@ import blackjax.mcmc.integrators as integrators from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.hmc import HMCInfo, HMCState +from blackjax.mcmc.hmc import HMCState +from blackjax.mcmc.hmc import HMCInfo from blackjax.mcmc.proposal import static_binomial_sampling -from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_unit_vector __all__ = ["init", "build_kernel", "as_top_level_api"] @@ -37,12 +34,12 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable): def build_kernel( - logdensity_fn: Callable, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, - inverse_mass_matrix=1.0, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + L_proposal_factor: float = jnp.inf, ): - """Build an MHMCHMC kernel where the number of integration steps is chosen randomly. + """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. 
Parameters ---------- @@ -66,12 +63,17 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: HMCState, + logdensity_fn: Callable, step_size: float, - num_integration_steps: int, - L_proposal_factor: float = jnp.inf, + integration_steps_fn, + inverse_mass_matrix=1.0, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the MHMCHMC kernel.""" + + # num_integration_steps = integration_steps_fn(state.random_generator_arg) + num_integration_steps = integration_steps_fn(None) + key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( @@ -96,6 +98,7 @@ def kernel( proposal.position, proposal.logdensity, proposal.logdensity_grad, + # next_random_arg_fn(state.random_generator_arg), ), info, ) @@ -111,9 +114,10 @@ def as_top_level_api( *, divergence_threshold: int = 1000, integrator: Callable = integrators.isokinetic_mclachlan, - num_integration_steps, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), ) -> SamplingAlgorithm: - """Implements the (basic) user interface for the MHMCHMC kernel. + """Implements the (basic) user interface for the dynamic MHMCHMC kernel. 
Parameters ---------- @@ -140,23 +144,23 @@ def as_top_level_api( """ kernel = build_kernel( - logdensity_fn=logdensity_fn, integrator=integrator, - inverse_mass_matrix=inverse_mass_matrix, + next_random_arg_fn=next_random_arg_fn, divergence_threshold=divergence_threshold, + L_proposal_factor=L_proposal_factor, ) - def init_fn(position: ArrayLikeTree, rng_key=None): - del rng_key - return init(position, logdensity_fn) + def init_fn(position: ArrayLikeTree, rng_key: Array): + return init(position, logdensity_fn, rng_key) def update_fn(rng_key: PRNGKey, state): return kernel( rng_key=rng_key, state=state, + logdensity_fn=logdensity_fn, step_size=step_size, - num_integration_steps=num_integration_steps, - L_proposal_factor=L_proposal_factor, + integration_steps_fn=integration_steps_fn, + inverse_mass_matrix=inverse_mass_matrix, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] @@ -240,3 +244,29 @@ def generate( return sampled_state, info, other_proposal_info return generate + + +def rescale(mu): + """returns s, such that + round(U(0, 1) * s + 0.5) + has expected value mu. 
+ """ + k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def trajectory_length(t, mu): + s = rescale(mu) + return jnp.rint(0.5 + halton_sequence(t) * s) + +def make_random_trajectory_length_fn(random_trajectory_length : bool): + if random_trajectory_length: + integration_steps_fn = lambda avg_num_integration_steps: lambda k: (jnp.clip(jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), min=1)).astype('int32') + else: + integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.clip(jnp.ceil( + avg_num_integration_steps + ), min=1).astype('int32') + return integration_steps_fn \ No newline at end of file diff --git a/blackjax/mcmc/adjusted_mclmc_dynamic.py b/blackjax/mcmc/adjusted_mclmc_dynamic.py index 1a69e1a28..b51a4d551 100644 --- a/blackjax/mcmc/adjusted_mclmc_dynamic.py +++ b/blackjax/mcmc/adjusted_mclmc_dynamic.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Public API for the Metropolis Hastings Microcanonical Hamiltonian Monte Carlo (MHMCHMC) Kernel. This is closely related to the Microcanonical Langevin Monte Carlo (MCLMC) Kernel, which is an unadjusted method. This kernel adds a Metropolis-Hastings correction to the MCLMC kernel. It also only refreshes the momentum variable after each MH step, rather than during the integration of the trajectory. 
Hence "Hamiltonian" and not "Langevin".""" -from typing import Callable, Union +from typing import Callable, NamedTuple, Union import jax import jax.numpy as jnp @@ -20,13 +20,22 @@ import blackjax.mcmc.integrators as integrators from blackjax.base import SamplingAlgorithm from blackjax.mcmc.dynamic_hmc import DynamicHMCState, halton_sequence -from blackjax.mcmc.hmc import HMCInfo +# from blackjax.mcmc.hmc import HMCInfo from blackjax.mcmc.proposal import static_binomial_sampling from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey from blackjax.util import generate_unit_vector __all__ = ["init", "build_kernel", "as_top_level_api"] +class MAMSInfo(NamedTuple): + + # momentum: ArrayTree + acceptance_rate: float + is_accepted: bool + is_divergent: bool + energy: float + # proposal: integrators.IntegratorState + num_integration_steps: int def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) @@ -34,11 +43,10 @@ def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: def build_kernel( - integration_steps_fn, integrator: Callable = integrators.isokinetic_mclachlan, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - inverse_mass_matrix=1.0, + L_proposal_factor: float = jnp.inf, ): """Build a Dynamic MHMCHMC kernel where the number of integration steps is chosen randomly. 
@@ -66,12 +74,24 @@ def kernel( state: DynamicHMCState, logdensity_fn: Callable, step_size: float, - L_proposal_factor: float = jnp.inf, - ) -> tuple[DynamicHMCState, HMCInfo]: + integration_steps_fn, + inverse_mass_matrix=1.0, + ) -> tuple[DynamicHMCState, MAMSInfo]: """Generate a new sample with the MHMCHMC kernel.""" + # return state, HMCInfo( + # state.position, + # 0.0, + # False, + # False, + # state.logdensity, + # None, + # 0, + # ) + num_integration_steps = integration_steps_fn(state.random_generator_arg) + key_momentum, key_integrator = jax.random.split(rng_key, 2) momentum = generate_unit_vector(key_momentum, state.position) proposal, info, _ = adjusted_mclmc_proposal( @@ -142,11 +162,10 @@ def as_top_level_api( """ kernel = build_kernel( - integration_steps_fn=integration_steps_fn, integrator=integrator, next_random_arg_fn=next_random_arg_fn, - inverse_mass_matrix=inverse_mass_matrix, divergence_threshold=divergence_threshold, + L_proposal_factor=L_proposal_factor, ) def init_fn(position: ArrayLikeTree, rng_key: Array): @@ -154,11 +173,12 @@ def init_fn(position: ArrayLikeTree, rng_key: Array): def update_fn(rng_key: PRNGKey, state): return kernel( - rng_key, - state, - logdensity_fn, - step_size, - L_proposal_factor, + rng_key=rng_key, + state=state, + logdensity_fn=logdensity_fn, + step_size=step_size, + integration_steps_fn=integration_steps_fn, + inverse_mass_matrix=inverse_mass_matrix, ) return SamplingAlgorithm(init_fn, update_fn) # type: ignore[arg-type] @@ -216,7 +236,7 @@ def build_trajectory(state, num_integration_steps, rng_key): def generate( rng_key, state: integrators.IntegratorState - ) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]: + ) -> tuple[integrators.IntegratorState, MAMSInfo, ArrayTree]: """Generate a new chain state.""" end_state, kinetic_energy, rng_key = build_trajectory( state, num_integration_steps, rng_key @@ -229,13 +249,13 @@ def generate( sampled_state, info = sample_proposal(rng_key, delta_energy, state, 
end_state) do_accept, p_accept, other_proposal_info = info - info = HMCInfo( - state.momentum, + info = MAMSInfo( + # state.momentum, p_accept, do_accept, is_diverging, new_energy, - end_state, + # end_state, num_integration_steps, ) @@ -257,3 +277,14 @@ def rescale(mu): def trajectory_length(t, mu): s = rescale(mu) return jnp.rint(0.5 + halton_sequence(t) * s) + +def make_random_trajectory_length_fn(random_trajectory_length : bool): + if random_trajectory_length: + integration_steps_fn = lambda avg_num_integration_steps: lambda k: (jnp.clip(jnp.ceil( + jax.random.uniform(k) * rescale(avg_num_integration_steps) + ), min=1)).astype('int32') + else: + integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.clip(jnp.ceil( + avg_num_integration_steps + ), min=1).astype('int32') + return integration_steps_fn \ No newline at end of file diff --git a/blackjax/mcmc/dynamic_hmc.py b/blackjax/mcmc/dynamic_hmc.py index de77be825..2916d0262 100644 --- a/blackjax/mcmc/dynamic_hmc.py +++ b/blackjax/mcmc/dynamic_hmc.py @@ -1,3 +1,4 @@ + # Copyright 2020- The Blackjax Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -55,7 +56,6 @@ def build_kernel( integrator: Callable = integrators.velocity_verlet, divergence_threshold: float = 1000, next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], - integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), ): """Build a Dynamic HMC kernel where the number of integration steps is chosen randomly. @@ -86,6 +86,7 @@ def kernel( logdensity_fn: Callable, step_size: float, inverse_mass_matrix: Array, + integration_steps_fn: Callable, **integration_steps_kwargs, ) -> tuple[DynamicHMCState, HMCInfo]: """Generate a new sample with the HMC kernel.""" @@ -154,7 +155,7 @@ def as_top_level_api( A ``SamplingAlgorithm``. 
""" kernel = build_kernel( - integrator, divergence_threshold, next_random_arg_fn, integration_steps_fn + integrator, divergence_threshold, next_random_arg_fn, ) def init_fn(position: ArrayLikeTree, rng_key: Array): @@ -170,6 +171,7 @@ def step_fn(rng_key: PRNGKey, state): logdensity_fn, step_size, inverse_mass_matrix, + integration_steps_fn ) return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/dynamic_malt.py b/blackjax/mcmc/dynamic_malt.py new file mode 100644 index 000000000..91c1e592f --- /dev/null +++ b/blackjax/mcmc/dynamic_malt.py @@ -0,0 +1,194 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Public API for the Dynamic HMC Kernel""" +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +import blackjax.mcmc.integrators as integrators +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.hmc import HMCInfo, HMCState +from blackjax.mcmc.malt import build_kernel as build_static_hmc_kernel +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.mcmc.dynamic_hmc import DynamicHMCState + +__all__ = [ + "DynamicHMCState", + "init", + "build_kernel", + "halton_sequence", + "as_top_level_api", +] + + + + +def init(position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg) + + +def build_kernel( + integrator: Callable = integrators.velocity_verlet, + divergence_threshold: float = 1000, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + L_proposal_factor: float = jnp.inf, +): + """Build a Dynamic HMC kernel where the number of integration steps is chosen randomly. + + Parameters + ---------- + integrator + The symplectic integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is divergent. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. Needs to return an `int`. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. 
+ + """ + + hmc_base = build_static_hmc_kernel(integrator, divergence_threshold, L_proposal_factor) + + def kernel( + rng_key: PRNGKey, + state: DynamicHMCState, + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), + **integration_steps_kwargs, + ) -> tuple[DynamicHMCState, HMCInfo]: + """Generate a new sample with the HMC kernel.""" + + + num_integration_steps = integration_steps_fn( + state.random_generator_arg, **integration_steps_kwargs + ).astype(int) + + # jax.debug.print("num_integration_steps {x}", x=(num_integration_steps, step_size, inverse_mass_matrix)) + + hmc_state = HMCState(state.position, state.logdensity, state.logdensity_grad) + hmc_proposal, info = hmc_base( + rng_key, + hmc_state, + logdensity_fn, + step_size, + inverse_mass_matrix, + num_integration_steps, + ) + + next_random_arg = next_random_arg_fn(state.random_generator_arg) + return ( + DynamicHMCState( + hmc_proposal.position, + hmc_proposal.logdensity, + hmc_proposal.logdensity_grad, + next_random_arg, + ), + info, + ) + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: Array, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, + next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1], + integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10), + L_proposal_factor: float = jnp.inf, +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the dynamic HMC kernel. + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + inverse_mass_matrix + The value to use for the inverse mass matrix when drawing a value for + the momentum and computing the kinetic energy. 
+ divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. + integrator + (algorithm parameter) The symplectic integrator to use to integrate the trajectory. + next_random_arg_fn + Function that generates the next `random_generator_arg` from its previous value. + integration_steps_fn + Function that generates the next pseudo or quasi-random number of integration steps in the + sequence, given the current `random_generator_arg`. + + + Returns + ------- + A ``SamplingAlgorithm``. + """ + kernel = build_kernel( + integrator, divergence_threshold, next_random_arg_fn, L_proposal_factor + ) + + def init_fn(position: ArrayLikeTree, rng_key: Array): + # Note that rng_key here is not necessarily a PRNGKey; it could be an Array used + # to generate a sequence of pseudo or quasi-random numbers (previously + # named `random_generator_arg`) + return init(position, logdensity_fn, rng_key) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + integration_steps_fn, + ) + + return SamplingAlgorithm(init_fn, step_fn) + + +def halton_sequence(i: Array, max_bits: int = 10) -> float: + bit_masks = 2 ** jnp.arange(max_bits, dtype=i.dtype) + return jnp.einsum("i,i->", jnp.mod((i + 1) // bit_masks, 2), 0.5 / bit_masks) + + +def rescale(mu): + # Returns s, such that `round(U(0, 1) * s + 0.5)` has expected value mu. 
+ k = jnp.floor(2 * mu - 1) + x = k * (mu - 0.5 * (k + 1)) / (k + 1 - mu) + return k + x + + +def halton_trajectory_length( + i: Array, trajectory_length_adjustment: float, max_bits: int = 10 +) -> int: + """Generate a quasi-random number of integration steps.""" + s = rescale(trajectory_length_adjustment) + return jnp.asarray(jnp.rint(0.5 + halton_sequence(i, max_bits) * s), dtype=int) diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 452b94e44..384c96b28 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -121,6 +121,7 @@ def kernel( num_integration_steps: int, ) -> tuple[HMCState, HMCInfo]: """Generate a new sample with the HMC kernel.""" + metric = metrics.default_metric(inverse_mass_matrix) symplectic_integrator = integrator(logdensity_fn, metric.kinetic_energy) @@ -134,7 +135,10 @@ def kernel( key_momentum, key_integrator = jax.random.split(rng_key, 2) - position, logdensity, logdensity_grad = state + # position, logdensity, logdensity_grad = state + position = state.position + logdensity = state.logdensity + logdensity_grad = state.logdensity_grad momentum = metric.sample_momentum(key_momentum, position) integrator_state = integrators.IntegratorState( @@ -302,6 +306,7 @@ def generate( is_diverging = -delta_energy > divergence_threshold sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) do_accept, p_accept, other_proposal_info = info + info = HMCInfo( state.momentum, @@ -335,4 +340,4 @@ def flip_momentum( flipped_momentum, state.logdensity, state.logdensity_grad, - ) + ) \ No newline at end of file diff --git a/blackjax/mcmc/integrators.py b/blackjax/mcmc/integrators.py index 0effa204e..13987b98f 100644 --- a/blackjax/mcmc/integrators.py +++ b/blackjax/mcmc/integrators.py @@ -21,6 +21,7 @@ from blackjax.mcmc.metrics import KineticEnergy from blackjax.types import ArrayTree +from blackjax.mcmc.metrics import default_metric __all__ = [ "mclachlan", @@ -28,6 +29,7 @@ "velocity_verlet", "yoshida", 
"with_isokinetic_maruyama", + "with_maruyama", "isokinetic_velocity_verlet", "isokinetic_mclachlan", "isokinetic_omelyan", @@ -98,7 +100,12 @@ def generalized_two_stage_integrator( """ def one_step(state: IntegratorState, step_size: float): - position, momentum, _, logdensity_grad = state + # position, momentum, _, logdensity_grad = state + position = state.position + momentum = state.momentum + logdensity_grad = state.logdensity_grad + # logdensity = state.logdensity + # jax.debug.print("initial state {x}", x=jnp.any(jnp.isnan(momentum))) # auxiliary infomation generated during integration for diagnostics. It is # updated by the operator1 and operator2 at each call. momentum_update_info = None @@ -113,6 +120,7 @@ def one_step(state: IntegratorState, step_size: float): momentum_update_info, is_last_call=False, ) + # jax.debug.print("momentum inside {x}", x=momentum) else: ( position, @@ -163,12 +171,17 @@ def update( coef: float, auxiliary_info=None, ): + + + # jax.debug.print("nan? {x}", x=jnp.any(jnp.isnan(kinetic_grad))) + # jax.debug.print("position {x}", x=position) del auxiliary_info new_position = jax.tree_util.tree_map( lambda x, grad: x + step_size * coef * grad, position, kinetic_grad, ) + # jax.debug.print("new position {x}", x=new_position) logdensity, logdensity_grad = logdensity_and_grad_fn(new_position) return new_position, logdensity, logdensity_grad, None @@ -187,11 +200,14 @@ def update( is_last_call=False, ): del auxiliary_info + # jax.debug.print("momentum {x}", x=momentum) + # jax.debug.print("force {x}", x=logdensity_grad[0]) new_momentum = jax.tree_util.tree_map( lambda x, grad: x + step_size * coef * grad, momentum, logdensity_grad, ) + # jax.debug.print("new momentum {x}", x=new_momentum) if is_last_call: return new_momentum, None, None kinetic_grad = kinetic_energy_grad_fn(new_momentum) @@ -305,14 +321,14 @@ def euclidean_integrator( omelyan = generate_euclidean_integrator(omelyan_coefficients) -# Intergrators with non Euclidean updates +# 
Integrators with non Euclidean updates def _normalized_flatten_array(x, tol=1e-13): norm = jnp.linalg.norm(x) return jnp.where(norm > tol, x / norm, x), norm def esh_dynamics_momentum_update_one_step(inverse_mass_matrix=1.0): - sqrt_inverse_mass_matrix = jnp.sqrt(inverse_mass_matrix) + sqrt_inverse_mass_matrix = jax.tree_util.tree_map(jnp.sqrt, inverse_mass_matrix) def update( momentum: ArrayTree, @@ -330,7 +346,7 @@ def update( """ del is_last_call - logdensity_grad = logdensity_grad + # logdensity_grad = logdensity_grad flatten_grads, unravel_fn = ravel_pytree(logdensity_grad) flatten_grads = flatten_grads * sqrt_inverse_mass_matrix flatten_momentum, _ = ravel_pytree(momentum) @@ -353,6 +369,7 @@ def update( ) * (dims - 1) if previous_kinetic_energy_change is not None: kinetic_energy_change += previous_kinetic_energy_change + return next_momentum, gr, kinetic_energy_change return update @@ -398,7 +415,7 @@ def isokinetic_integrator( isokinetic_omelyan = generate_isokinetic_integrator(omelyan_coefficients) -def partially_refresh_momentum(momentum, rng_key, step_size, L): +def partially_refresh_momentum_isokinetic(momentum, rng_key, step_size, L): """Adds a small noise to momentum and normalizes. Parameters @@ -422,7 +439,46 @@ def partially_refresh_momentum(momentum, rng_key, step_size, L): nu = jnp.sqrt((jnp.exp(2 * step_size / L) - 1.0) / dim) z = nu * normal(rng_key, shape=m.shape, dtype=m.dtype) new_momentum = unravel_fn((m + z) / jnp.linalg.norm(m + z)) - # return new_momentum + + return jax.lax.cond( + jnp.isinf(L), + lambda _: momentum, + lambda _: new_momentum, + operand=None, + ) + + +def partially_refresh_momentum(momentum, rng_key, step_size, L, inverse_mass_matrix): + """Partially refreshes the momentum with freshly sampled noise (no normalization). + + Parameters + ---------- + rng_key + The pseudo-random number generator key used to generate random numbers. + momentum + Current momentum PyTree; the output has the same structure. 
+ step_size + Step size + L + controls rate of momentum change + + Returns + ------- + momentum with random change in angle + """ + + # TODO + m, unravel_fn = ravel_pytree(momentum) + # m = jax.tree_util.tree_map(lambda x: x * jnp.sqrt(inverse_mass_matrix), m) + # dim = m.shape[0] + c1 = jnp.exp(-step_size/L) + c2 = jnp.sqrt((1-c1**2)) + # z = normal(rng_key, shape=m.shape, dtype=m.dtype) + metric = default_metric(inverse_mass_matrix) + z = metric.sample_momentum(rng_key, m) + # normal(rng_key, shape=m.shape, dtype=m.dtype) + new_momentum = unravel_fn(c1*m + c2*z) + return jax.lax.cond( jnp.isinf(L), lambda _: momentum, @@ -436,7 +492,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): key1, key2 = jax.random.split(rng_key) # partial refreshment state = init_state._replace( - momentum=partially_refresh_momentum( + momentum=partially_refresh_momentum_isokinetic( momentum=init_state.momentum, rng_key=key1, L=L_proposal, @@ -448,7 +504,7 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): # partial refreshment state = state._replace( - momentum=partially_refresh_momentum( + momentum=partially_refresh_momentum_isokinetic( momentum=state.momentum, rng_key=key2, L=L_proposal, @@ -460,6 +516,49 @@ def stochastic_integrator(init_state, step_size, L_proposal, rng_key): return stochastic_integrator +def with_maruyama(integrator, kinetic_energy,inverse_mass_matrix): + def stochastic_integrator(init_state, step_size, L_proposal, rng_key): + key1, key2 = jax.random.split(rng_key) + # partial refreshment + # jax.debug.print("state 1 {x}",x=init_state) + state = init_state._replace( + momentum=partially_refresh_momentum( + momentum=init_state.momentum, + rng_key=key1, + L=L_proposal, + step_size=step_size * 0.5, + inverse_mass_matrix=inverse_mass_matrix, + ) + ) + # jax.debug.print("state 1.5 {x}",x=state) + # one step of the deterministic dynamics + new_state = integrator(state, step_size) + # jax.debug.print("state 2 {x}",x=state) + 
+ kinetic_change = - kinetic_energy(state.momentum) + kinetic_energy( + new_state.momentum + ) + energy_change = kinetic_change - new_state.logdensity + state.logdensity + + # partial refreshment + state = new_state._replace( + momentum=partially_refresh_momentum( + momentum=new_state.momentum, + rng_key=key2, + L=L_proposal, + step_size=step_size * 0.5, + inverse_mass_matrix=inverse_mass_matrix, + ) + ) + + + + + return state, (kinetic_change, energy_change) + + return stochastic_integrator + + FixedPointSolver = Callable[ [Callable[[ArrayTree], Tuple[ArrayTree, ArrayTree]], ArrayTree], Tuple[ArrayTree, ArrayTree, Any], diff --git a/blackjax/mcmc/malt.py b/blackjax/mcmc/malt.py new file mode 100644 index 000000000..395655bce --- /dev/null +++ b/blackjax/mcmc/malt.py @@ -0,0 +1,315 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Public API for the HMC Kernel""" +from typing import Callable, NamedTuple, Union + +import jax +import jax.numpy as jnp +import blackjax.mcmc.integrators as integrators +import blackjax.mcmc.metrics as metrics +import blackjax.mcmc.trajectory as trajectory +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.proposal import safe_energy_diff, static_binomial_sampling +from blackjax.mcmc.trajectory import hmc_energy +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.mcmc.hmc import HMCState + +__all__ = [ + "HMCState", + "init", + "build_kernel", + "as_top_level_api", +] + +class MALTInfo(NamedTuple): + + # momentum: ArrayTree + acceptance_rate: float + is_accepted: bool + is_divergent: bool + energy: float + # proposal: integrators.IntegratorState + num_integration_steps: int + + + + +def init(position: ArrayLikeTree, logdensity_fn: Callable): + logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position) + return HMCState(position, logdensity, logdensity_grad) + + +def build_kernel( + integrator: Callable = integrators.velocity_verlet, + divergence_threshold: float = 1000, + L_proposal_factor = jnp.inf, +): + """Build a HMC kernel. + + Parameters + ---------- + integrator + The symplectic integrator to use to integrate the Hamiltonian dynamics. + divergence_threshold + Value of the difference in energy above which we consider that the transition is + divergent. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. 
+ + """ + + def kernel( + rng_key: PRNGKey, + state: HMCState, + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: metrics.MetricTypes, + num_integration_steps: int, + + ) -> tuple[HMCState, MALTInfo]: + """Generate a new sample with the HMC kernel.""" + + L = num_integration_steps * step_size + L_proposal = L_proposal_factor * L + + # jax.debug.print("L_proposal {L_proposal}",L_proposal=L_proposal) + + key_trajectory, key_momentum, key_integrator = jax.random.split(rng_key, 3) + metric = metrics.default_metric(inverse_mass_matrix) + symplectic_integrator = lambda state, step_size, rng_key: integrators.with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy, inverse_mass_matrix)(state, step_size, L_proposal=L_proposal, rng_key=rng_key) + proposal_generator = hmc_proposal( + symplectic_integrator, + metric.kinetic_energy, + step_size, + num_integration_steps, + divergence_threshold, + rng_key=key_trajectory, + ) + + + position, logdensity, logdensity_grad = state + momentum = metric.sample_momentum(key_momentum, position) + + integrator_state = integrators.IntegratorState( + position, momentum, logdensity, logdensity_grad + ) + proposal, info, _ = proposal_generator(key_integrator, integrator_state) + proposal = HMCState( + proposal.position, proposal.logdensity, proposal.logdensity_grad + ) + + return proposal, info + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + step_size: float, + inverse_mass_matrix: metrics.MetricTypes, + num_integration_steps: int, + *, + divergence_threshold: int = 1000, + integrator: Callable = integrators.velocity_verlet, + L_proposal_factor = jnp.inf, +) -> SamplingAlgorithm: + """Implements the (basic) user interface for the HMC kernel. + + The general hmc kernel builder (:meth:`blackjax.mcmc.hmc.build_kernel`, alias + `blackjax.hmc.build_kernel`) can be cumbersome to manipulate. 
Since most users only + need to specify the kernel parameters at initialization time, we provide a helper + function that specializes the general kernel. + + We also add the general kernel and state generator as an attribute to this class so + users only need to pass `blackjax.hmc` to SMC, adaptation, etc. algorithms. + + Examples + -------- + + A new HMC kernel can be initialized and used with the following code: + + .. code:: + + hmc = blackjax.hmc( + logdensity_fn, step_size, inverse_mass_matrix, num_integration_steps + ) + state = hmc.init(position) + new_state, info = hmc.step(rng_key, state) + + Kernels are not jit-compiled by default so you will need to do it manually: + + .. code:: + + step = jax.jit(hmc.step) + new_state, info = step(rng_key, state) + + Should you need to you can always use the base kernel directly: + + .. code:: + + import blackjax.mcmc.integrators as integrators + + kernel = blackjax.hmc.build_kernel(integrators.mclachlan) + state = blackjax.hmc.init(position, logdensity_fn) + state, info = kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + num_integration_steps, + ) + + Parameters + ---------- + logdensity_fn + The log-density function we wish to draw samples from. + step_size + The value to use for the step size in the symplectic integrator. + inverse_mass_matrix + The value to use for the inverse mass matrix when drawing a value for + the momentum and computing the kinetic energy. This argument will be + passed to the ``metrics.default_metric`` function so it supports the + full interface presented there. + num_integration_steps + The number of steps we take with the symplectic integrator at each + sample step before returning a sample. + divergence_threshold + The absolute value of the difference in energy between two states above + which we say that the transition is divergent. The default value is + commonly found in other libraries, and yet is arbitrary. 
+ integrator + (algorithm parameter) The symplectic integrator to use to integrate the + trajectory. + + Returns + ------- + A ``SamplingAlgorithm``. + """ + + kernel = build_kernel(integrator, divergence_threshold) + + def init_fn(position: ArrayLikeTree, rng_key=None): + del rng_key + return init(position, logdensity_fn) + + def step_fn(rng_key: PRNGKey, state): + return kernel( + rng_key, + state, + logdensity_fn, + step_size, + inverse_mass_matrix, + num_integration_steps, + L_proposal_factor, + ) + + return SamplingAlgorithm(init_fn, step_fn) + + +def hmc_proposal( + integrator: Callable, + kinetic_energy: metrics.KineticEnergy, + step_size: Union[float, ArrayLikeTree], + num_integration_steps: int = 1, + divergence_threshold: float = 1000, + *, + sample_proposal: Callable = static_binomial_sampling, + rng_key: PRNGKey, +) -> Callable: + """Vanilla HMC algorithm. + + The algorithm integrates the trajectory applying a symplectic integrator + `num_integration_steps` times in one direction to get a proposal and uses a + Metropolis-Hastings acceptance step to either reject or accept this + proposal. This is what people usually refer to when they talk about "the + HMC algorithm". + + Parameters + ---------- + integrator + Symplectic integrator used to build the trajectory step by step. + kinetic_energy + Function that computes the kinetic energy. + step_size + Size of the integration step. + num_integration_steps + Number of times we run the symplectic integrator to build the trajectory + divergence_threshold + Threshold above which we say that there is a divergence. + + Returns + ------- + A kernel that generates a new chain state and information about the transition. 
+ + """ + build_trajectory = trajectory.langevin_integration(integrator, rng_key) + hmc_energy_fn = hmc_energy(kinetic_energy) + + def generate( + rng_key, state: integrators.IntegratorState + ) -> tuple[integrators.IntegratorState, MALTInfo, ArrayTree]: + """Generate a new chain state.""" + end_state, delta_energy = build_trajectory(state, step_size, num_integration_steps) + end_state = flip_momentum(end_state) + # proposal_energy = hmc_energy_fn(state) + new_energy = hmc_energy_fn(end_state) + # delta_energy = safe_energy_diff(proposal_energy, new_energy) + + delta_energy = jnp.where(jnp.isnan(delta_energy), -jnp.inf, delta_energy) + # jax.debug.print("delta_energy_trajectory {x}",x=(delta_energy_trajectory, delta_energy)) + is_diverging = -delta_energy > divergence_threshold + # jax.debug.print("is_diverging {x}",x=is_diverging) + sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state) + do_accept, p_accept, other_proposal_info = info + # jax.debug.print("delta_energy {x}",x=delta_energy) + # jax.debug.print("p_accept {p_accept}", p_accept=(p_accept, delta_energy)) + + info = MALTInfo( + # state.momentum, + p_accept, + do_accept, + is_diverging, + new_energy, + # end_state, + num_integration_steps, + ) + + return sampled_state, info, other_proposal_info + + return generate + + +def flip_momentum( + state: integrators.IntegratorState, +) -> integrators.IntegratorState: + """Flip the momentum at the end of the trajectory. + + To guarantee time-reversibility (hence detailed balance) we + need to flip the last state's momentum. If we run the hamiltonian + dynamics starting from the last state with flipped momentum we + should indeed retrieve the initial state (with flipped momentum). 
+ + """ + flipped_momentum = jax.tree_util.tree_map(lambda m: -1.0 * m, state.momentum) + return integrators.IntegratorState( + state.position, + flipped_momentum, + state.logdensity, + state.logdensity_grad, + ) diff --git a/blackjax/mcmc/mchmc.py b/blackjax/mcmc/mchmc.py new file mode 100644 index 000000000..06bda7097 --- /dev/null +++ b/blackjax/mcmc/mchmc.py @@ -0,0 +1,134 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the MCLMC Kernel""" +from typing import Callable, NamedTuple + +import jax +import jax.numpy as jnp + +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.integrators import ( + IntegratorState, + isokinetic_mclachlan, +) +from blackjax.types import ArrayLike, PRNGKey +from blackjax.util import generate_unit_vector, pytree_size +from blackjax.mcmc.mclmc import handle_high_energy, handle_nans, MCLMCInfo +from blackjax.mcmc.adjusted_mclmc_dynamic import make_random_trajectory_length_fn +__all__ = ["MCLMCInfo", "init", "build_kernel", "as_top_level_api"] + +class MCHMCState(NamedTuple): + position: ArrayLike + momentum: ArrayLike + logdensity: float + logdensity_grad: ArrayLike + steps_until_refresh: int + + + +def init(position: ArrayLike, logdensity_fn, random_generator_arg): + if pytree_size(position) < 2: + raise ValueError( + "The target distribution must have more than 1 dimension for MCLMC." 
+ ) + l, g = jax.value_and_grad(logdensity_fn)(position) + + return MCHMCState( + position=position, + momentum=generate_unit_vector(random_generator_arg, position), + logdensity=l, + logdensity_grad=g, + steps_until_refresh=0, + ) + +def integrator_state(state: MCHMCState) -> IntegratorState: + return IntegratorState( + position=state.position, + momentum=state.momentum, + logdensity=state.logdensity, + logdensity_grad=state.logdensity_grad, + ) + + +def build_kernel( + # integration_steps_fn, + integrator, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4, +): + """ + """ + + + + def kernel( + rng_key: PRNGKey, state: MCHMCState, logdensity_fn, L: float, step_size: float, inverse_mass_matrix: ArrayLike + ) -> tuple[MCHMCState, MCLMCInfo]: + step = integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) + (position, momentum, logdensity, logdensitygrad), kinetic_change = step( + integrator_state(state), step_size + ) + + randomization_key, refresh_key, energy_cutoff_key, nan_key = jax.random.split(rng_key, 4) + + num_steps_per_traj = make_random_trajectory_length_fn(True)(L/step_size)(randomization_key).astype(jnp.int64) + + + + energy_change = kinetic_change - logdensity + state.logdensity + + eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var + ndims = pytree_size(position) + + momentum=(state.steps_until_refresh==0) * generate_unit_vector(refresh_key, state.position) + (state.steps_until_refresh>0) * momentum + + steps_until_refresh = (state.steps_until_refresh==0) * num_steps_per_traj + (state.steps_until_refresh>0) * (state.steps_until_refresh - 1) + + new_state, info = handle_high_energy(state, MCHMCState(position, momentum, logdensity, logdensitygrad, steps_until_refresh), MCLMCInfo( + logdensity=logdensity, + energy_change=energy_change, + kinetic_change=kinetic_change, + nonans=True + ), energy_cutoff_key, cutoff = jnp.sqrt(ndims * eev_max_per_dim)) + + new_state, info = handle_nans(state, 
new_state, info, nan_key) + + return new_state, info + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + L, + step_size, + integrator=isokinetic_mclachlan, + inverse_mass_matrix=1.0, + desired_energy_var_max_ratio=jnp.inf, +) -> SamplingAlgorithm: + """ + """ + + kernel = build_kernel( + integrator, + desired_energy_var_max_ratio=desired_energy_var_max_ratio, + ) + + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key, state): + return kernel(rng_key, state, logdensity_fn, L, step_size, inverse_mass_matrix) + + return SamplingAlgorithm(init_fn, update_fn) diff --git a/blackjax/mcmc/mclmc.py b/blackjax/mcmc/mclmc.py index ff9638a1f..0b4730786 100644 --- a/blackjax/mcmc/mclmc.py +++ b/blackjax/mcmc/mclmc.py @@ -15,6 +15,8 @@ from typing import Callable, NamedTuple import jax +import jax.numpy as jnp +import time from blackjax.base import SamplingAlgorithm from blackjax.mcmc.integrators import ( @@ -24,6 +26,7 @@ ) from blackjax.types import ArrayLike, PRNGKey from blackjax.util import generate_unit_vector, pytree_size +from jax.flatten_util import ravel_pytree __all__ = ["MCLMCInfo", "init", "build_kernel", "as_top_level_api"] @@ -43,9 +46,10 @@ class MCLMCInfo(NamedTuple): logdensity: float kinetic_change: float energy_change: float + nonans: bool -def init(position: ArrayLike, logdensity_fn, rng_key): +def init(position: ArrayLike, logdensity_fn, random_generator_arg): if pytree_size(position) < 2: raise ValueError( "The target distribution must have more than 1 dimension for MCLMC." 
@@ -54,13 +58,18 @@ def init(position: ArrayLike, logdensity_fn, rng_key): return IntegratorState( position=position, - momentum=generate_unit_vector(rng_key, position), + momentum=generate_unit_vector(random_generator_arg, position), logdensity=l, logdensity_grad=g, ) -def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): + +def build_kernel( + integrator, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4, +): """Build a HMC kernel. Parameters @@ -80,24 +89,57 @@ def build_kernel(logdensity_fn, inverse_mass_matrix, integrator): """ - step = with_isokinetic_maruyama( - integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) - ) def kernel( - rng_key: PRNGKey, state: IntegratorState, L: float, step_size: float + rng_key: PRNGKey, state: IntegratorState, logdensity_fn, L: float, step_size: float, + inverse_mass_matrix, ) -> tuple[IntegratorState, MCLMCInfo]: + # tic = time.time() + step = with_isokinetic_maruyama( + integrator(logdensity_fn=logdensity_fn, inverse_mass_matrix=inverse_mass_matrix) + ) + + # return 0, None + + # state, MCLMCInfo( + + # logdensity=state.logdensity, + # energy_change=0.0, + # kinetic_change=0.0, + # nonans=True + # ) + # raise Exception + + + # jax.debug.print("L {x}", x=L) + # jax.debug.print("step_size {x}", x=step_size) + # jax.debug.print("inverse_mass_matrix {x}", x=inverse_mass_matrix) + + kernel_key, energy_cutoff_key, nan_key = jax.random.split(rng_key, 3) + + (position, momentum, logdensity, logdensitygrad), kinetic_change = step( - state, step_size, L, rng_key + state, step_size, L, kernel_key ) - return IntegratorState( - position, momentum, logdensity, logdensitygrad - ), MCLMCInfo( + energy_change = kinetic_change - logdensity + state.logdensity + eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var + ndims = pytree_size(position) + + # jax.debug.print("kinetic_change: {x}", x=kinetic_change) + # jax.debug.print("potential energy_change: {x}", 
x=state.logdensity - logdensity) + + new_state, info = handle_high_energy(state, IntegratorState(position, momentum, logdensity, logdensitygrad), MCLMCInfo( logdensity=logdensity, - energy_change=kinetic_change - logdensity + state.logdensity, + energy_change=energy_change, kinetic_change=kinetic_change, - ) + nonans=True + ), energy_cutoff_key, cutoff = jnp.sqrt(ndims * eev_max_per_dim)) + + new_state, info = handle_nans(state, new_state, info, nan_key) + + # jax.debug.print("Time taken in chain: {x}", x=time.time() - tic) + return new_state, info return kernel @@ -108,6 +150,7 @@ def as_top_level_api( step_size, integrator=isokinetic_mclachlan, inverse_mass_matrix=1.0, + desired_energy_var_max_ratio=jnp.inf, ) -> SamplingAlgorithm: """The general mclmc kernel builder (:meth:`blackjax.mcmc.mclmc.build_kernel`, alias `blackjax.mclmc.build_kernel`) can be cumbersome to manipulate. Since most users only need to specify the kernel @@ -155,12 +198,60 @@ def as_top_level_api( A ``SamplingAlgorithm``. 
""" - kernel = build_kernel(logdensity_fn, inverse_mass_matrix, integrator) + kernel = build_kernel( + + integrator, + desired_energy_var_max_ratio=desired_energy_var_max_ratio, + ) def init_fn(position: ArrayLike, rng_key: PRNGKey): return init(position, logdensity_fn, rng_key) def update_fn(rng_key, state): - return kernel(rng_key, state, L, step_size) + return kernel(rng_key, state, logdensity_fn, L, step_size, inverse_mass_matrix) return SamplingAlgorithm(init_fn, update_fn) + + +def handle_nans( + previous_state, next_state, info, key +): + + new_momentum = generate_unit_vector(key, previous_state.position) + + nonans = jnp.logical_and(jnp.all(jnp.isfinite(next_state.position)), jnp.all(jnp.isfinite(next_state.momentum))) + + state, info = jax.lax.cond( + nonans, + lambda: (next_state, info), + lambda: (previous_state._replace( + momentum=new_momentum, + ), MCLMCInfo( + logdensity=previous_state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + nonans=nonans + )), + ) + + return state, info +def handle_high_energy( + previous_state, next_state, info, key, cutoff +): + + new_momentum = generate_unit_vector(key, previous_state.position) + + state, info = jax.lax.cond( + jnp.abs(info.energy_change) > cutoff, + lambda: (previous_state._replace( + momentum=new_momentum, + ), MCLMCInfo( + logdensity=previous_state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + nonans=info.nonans + )), + lambda: (next_state, info), + ) + + return state, info \ No newline at end of file diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index c75ecdec6..75399a1e7 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -120,6 +120,7 @@ def kernel( ) -> tuple[hmc.HMCState, NUTSInfo]: """Generate a new sample with the NUTS kernel.""" + metric = metrics.default_metric(inverse_mass_matrix) symplectic_integrator = integrator(logdensity_fn, metric.kinetic_energy) proposal_generator = iterative_nuts_proposal( @@ -132,7 +133,10 @@ def kernel( key_momentum, 
key_integrator = jax.random.split(rng_key, 2) - position, logdensity, logdensity_grad = state + # position, logdensity, logdensity_grad = state + position = state.position + logdensity = state.logdensity + logdensity_grad = state.logdensity_grad momentum = metric.sample_momentum(key_momentum, position) integrator_state = integrators.IntegratorState( @@ -142,6 +146,9 @@ def kernel( proposal = hmc.HMCState( proposal.position, proposal.logdensity, proposal.logdensity_grad ) + # jax.debug.print("step {x}",x=info.acceptance_rate) + # jax.debug.print("num steps {x}",x=info.num_integration_steps) + # jax.debug.print("step size {x}",x=step_size) return proposal, info return kernel diff --git a/blackjax/mcmc/pseudofermion.py b/blackjax/mcmc/pseudofermion.py new file mode 100644 index 000000000..2bae3f6f4 --- /dev/null +++ b/blackjax/mcmc/pseudofermion.py @@ -0,0 +1,55 @@ +from functools import partial +from typing import Any, Callable, NamedTuple + +import jax +import jax.numpy as jnp +import numpy as np + +import blackjax.mcmc.hmc as hmc +import blackjax.mcmc.integrators as integrators +import blackjax.mcmc.metrics as metrics +import blackjax.mcmc.proposal as proposal +import blackjax.mcmc.termination as termination +import blackjax.mcmc.trajectory as trajectory +from blackjax.base import SamplingAlgorithm +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey +import blackjax + + +def build_kernel(kernel_1, kernel_2, logdensity_fn): + def kernel(state, rng_key): + next_x_state, info = kernel_1( + rng_key=rng_key, + state=state['x'], + step_size=1e-3, + inverse_mass_matrix=jnp.ones(state['x'].position.shape[0]), + num_integration_steps=1, + logdensity_fn=partial(logdensity_fn, pseudofermion=state['y']), + ) + next_y_state = kernel_2(state['y']) + return {'x': next_x_state, 'y': next_y_state}, info + return kernel + +def init(position, logdensity_fn, pseudofermion, init_1, init_2, rng_key ): + key_1, key_2 = jax.random.split(rng_key, 2) + state_b = init_1(position, 
partial(logdensity_fn, pseudofermion=pseudofermion) ) + state_pf = init_2(pseudofermion) + + return {'x': state_b, 'y': state_pf} + + +def as_top_level_api( + kernel_1, + kernel_2, + init_1, + init_2, + logdensity_fn: Callable, +) -> SamplingAlgorithm: + + # kernel = build_kernel(integrator, divergence_threshold) + + return None + + return SamplingAlgorithm(init_fn, step_fn) + + diff --git a/blackjax/mcmc/termination.py b/blackjax/mcmc/termination.py index eb1276da3..9fa7cee6c 100644 --- a/blackjax/mcmc/termination.py +++ b/blackjax/mcmc/termination.py @@ -33,10 +33,10 @@ def iterative_uturn_numpyro(is_turning: CheckTurning): def new_state(chain_state, max_num_doublings) -> IterativeUTurnState: flat, _ = jax.flatten_util.ravel_pytree(chain_state.position) - num_dims = jnp.shape(flat)[0] + ndims = jnp.shape(flat)[0] return IterativeUTurnState( - jnp.zeros((max_num_doublings, num_dims)), - jnp.zeros((max_num_doublings, num_dims)), + jnp.zeros((max_num_doublings, ndims)), + jnp.zeros((max_num_doublings, ndims)), 0, 0, ) diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 7bb1b35a5..25260e777 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -127,6 +127,34 @@ def one_step(_, state): return integrate +def langevin_integration( + integrator: Callable, + rng_key: PRNGKey, + direction: int = 1, +) -> Callable: + """Generate a trajectory by integrating several times in one direction.""" + + def integrate( + initial_state: IntegratorState, step_size, num_integration_steps + ) -> IntegratorState: + directed_step_size = jax.tree_util.tree_map( + lambda step_size: direction * step_size, step_size + ) + + # jax.debug.print("num_integration_steps {x}", x=num_integration_steps) + + def one_step(_, accum): + state, delta_energy, rng_key = accum + rng_key, key2 = jax.random.split(rng_key) + new_state, (_, delta_energy_new) = integrator(state, directed_step_size, rng_key) + # jax.debug.print("delta_energy 0 {x}", 
x=delta_energy_new) + return (new_state, delta_energy+delta_energy_new, key2) + + state, delta_energy, _ = jax.lax.fori_loop(0, num_integration_steps, one_step, (initial_state, 0.0, rng_key)) + return state, -delta_energy + + return integrate + class DynamicIntegrationState(NamedTuple): step: int diff --git a/blackjax/mcmc/uhmc.py b/blackjax/mcmc/uhmc.py new file mode 100644 index 000000000..4e6f6daea --- /dev/null +++ b/blackjax/mcmc/uhmc.py @@ -0,0 +1,161 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Public API for the Underdamped Langevin Kernel""" +from typing import Callable, NamedTuple + +import jax +from blackjax.mcmc.adjusted_mclmc_dynamic import make_random_trajectory_length_fn +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.integrators import ( + IntegratorState, + velocity_verlet, +) +from blackjax.types import ArrayLike, PRNGKey +import blackjax.mcmc.metrics as metrics +import jax.numpy as jnp +from blackjax.util import pytree_size, generate_unit_vector +from blackjax.mcmc.underdamped_langevin import handle_high_energy, handle_nans, LangevinInfo +__all__ = ["LangevinInfo", "init", "build_kernel", "as_top_level_api"] + +class UHMCState(NamedTuple): + position: ArrayLike + momentum: ArrayLike + logdensity: float + logdensity_grad: ArrayLike + steps_until_refresh: int + +def integrator_state(state: UHMCState) -> IntegratorState: + return IntegratorState( + position=state.position, + momentum=state.momentum, + logdensity=state.logdensity, + logdensity_grad=state.logdensity_grad, + ) + + +def init(position: ArrayLike, logdensity_fn, random_generator_arg): + + l, g = jax.value_and_grad(logdensity_fn)(position) + + metric = metrics.default_metric(jnp.ones_like(position)) + + return UHMCState( + position=position, + momentum = metric.sample_momentum(random_generator_arg, position), + logdensity=l, + logdensity_grad=g, + steps_until_refresh=0, + ) + + +def build_kernel( + integrator, + desired_energy_var_max_ratio=1e3, + desired_energy_var=5e-4,): + """Build a HMC kernel. + + Parameters + ---------- + integrator + The symplectic integrator to use to integrate the Langevin dynamics. + L + the momentum decoherence rate. + step_size + step size of the integrator. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. 
+ + """ + + + def kernel( + rng_key: PRNGKey, state: UHMCState, logdensity_fn, L: float, step_size: float, inverse_mass_matrix, + ) -> tuple[UHMCState, LangevinInfo]: + metric = metrics.default_metric(inverse_mass_matrix) + step = integrator(logdensity_fn, metric.kinetic_energy) + + refresh_key, energy_cutoff_key, nan_key, randomization_key = jax.random.split(rng_key, 4) + + (position, momentum, logdensity, logdensitygrad) = step( + integrator_state(state), step_size + ) + + kinetic_change = - metric.kinetic_energy(state.momentum) + metric.kinetic_energy(momentum) + energy_change = kinetic_change - logdensity + state.logdensity + + num_steps_per_traj = make_random_trajectory_length_fn(True)(L/step_size)(randomization_key).astype(jnp.int64) + + momentum = (state.steps_until_refresh==0) * metric.sample_momentum(refresh_key, position) + (state.steps_until_refresh>0) * momentum + + steps_until_refresh = (state.steps_until_refresh==0) * num_steps_per_traj + (state.steps_until_refresh>0) * (state.steps_until_refresh - 1) + + eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var + ndims = pytree_size(position) + + + + new_state, info = handle_high_energy(state, UHMCState(position, momentum, logdensity, logdensitygrad, steps_until_refresh), LangevinInfo( + logdensity=logdensity, + energy_change=energy_change, + kinetic_change=kinetic_change, + nonans=True + ), energy_cutoff_key, cutoff = jnp.sqrt(ndims * eev_max_per_dim), inverse_mass_matrix=inverse_mass_matrix) + + new_state, info = handle_nans(state, new_state, info, nan_key, inverse_mass_matrix) + return new_state, info + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + L, + step_size, + integrator=velocity_verlet, + inverse_mass_matrix=1.0, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4, +) -> SamplingAlgorithm: + """The general Langevin kernel builder (:meth:`blackjax.mcmc.langevin.build_kernel`, alias `blackjax.langevin.build_kernel`) can be + cumbersome to 
manipulate. Since most users only need to specify the kernel + parameters at initialization time, we provide a helper function that + specializes the general kernel. + + We also add the general kernel and state generator as an attribute to this class so + users only need to pass `blackjax.langevin` to SMC, adaptation, etc. algorithms. + + + + """ + + kernel = build_kernel( + integrator, + desired_energy_var_max_ratio=desired_energy_var_max_ratio, + desired_energy_var=desired_energy_var, + ) + # metric = metrics.default_metric(inverse_mass_matrix) + + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key, state): + return kernel(rng_key, state, logdensity_fn, L, step_size, inverse_mass_matrix) + + return SamplingAlgorithm(init_fn, update_fn) + + diff --git a/blackjax/mcmc/underdamped_langevin.py b/blackjax/mcmc/underdamped_langevin.py new file mode 100644 index 000000000..6bf01591a --- /dev/null +++ b/blackjax/mcmc/underdamped_langevin.py @@ -0,0 +1,203 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Public API for the Underdamped Langevin Kernel""" +from typing import Callable, NamedTuple + +import jax + +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.integrators import ( + IntegratorState, + with_maruyama, + velocity_verlet, +) +from blackjax.types import ArrayLike, PRNGKey +import blackjax.mcmc.metrics as metrics +import jax.numpy as jnp +from blackjax.util import pytree_size +from blackjax.adaptation.mclmc_adaptation import handle_high_energy +__all__ = ["LangevinInfo", "init", "build_kernel", "as_top_level_api"] + + +class LangevinInfo(NamedTuple): + """ + Additional information on the Langevin transition. + + logdensity + The log-density of the distribution at the current step of the Langevin chain. + kinetic_change + The difference in kinetic energy between the current and previous step. + energy_change + The difference in energy between the current and previous step. + """ + + logdensity: float + kinetic_change: float + energy_change: float + nonans : bool + +def init(position: ArrayLike, logdensity_fn, random_generator_arg): + + l, g = jax.value_and_grad(logdensity_fn)(position) + + metric = metrics.default_metric(jnp.ones_like(position)) + + return IntegratorState( + position=position, + momentum = metric.sample_momentum(random_generator_arg, position), + logdensity=l, + logdensity_grad=g, + ) + + +def build_kernel( + integrator, + desired_energy_var_max_ratio=1e3, + desired_energy_var=5e-4,): + """Build a HMC kernel. + + Parameters + ---------- + integrator + The symplectic integrator to use to integrate the Langevin dynamics. + L + the momentum decoherence rate. + step_size + step size of the integrator. + + Returns + ------- + A kernel that takes a rng_key and a Pytree that contains the current state + of the chain and that returns a new state of the chain along with + information about the transition. 
+ + """ + + + def kernel( + rng_key: PRNGKey, state: IntegratorState, logdensity_fn, L: float, step_size: float, inverse_mass_matrix, + ) -> tuple[IntegratorState, LangevinInfo]: + metric = metrics.default_metric(inverse_mass_matrix) + step = with_maruyama(integrator(logdensity_fn, metric.kinetic_energy), metric.kinetic_energy,inverse_mass_matrix) + + (position, momentum, logdensity, logdensitygrad), (kinetic_change, energy_error) = step( + state, step_size, L, rng_key + ) + + + eev_max_per_dim = desired_energy_var_max_ratio * desired_energy_var + ndims = pytree_size(position) + + energy_key, nan_key = jax.random.split(rng_key) + + new_state, info = handle_high_energy( + previous_state=state, + next_state=IntegratorState(position, momentum, logdensity, logdensitygrad), + info=LangevinInfo( + logdensity=logdensity, + energy_change=energy_error, + kinetic_change=kinetic_change, + nonans=True + ), + key=energy_key, + inverse_mass_matrix=inverse_mass_matrix, + cutoff=jnp.sqrt(ndims * eev_max_per_dim), + ) + + new_state, info = handle_nans(state, new_state, info, nan_key, inverse_mass_matrix) + return new_state, info + + + return kernel + + +def as_top_level_api( + logdensity_fn: Callable, + L, + step_size, + integrator=velocity_verlet, + inverse_mass_matrix=1.0, + desired_energy_var_max_ratio=jnp.inf, + desired_energy_var=5e-4, +) -> SamplingAlgorithm: + """The general Langevin kernel builder (:meth:`blackjax.mcmc.langevin.build_kernel`, alias `blackjax.langevin.build_kernel`) can be + cumbersome to manipulate. Since most users only need to specify the kernel + parameters at initialization time, we provide a helper function that + specializes the general kernel. + + We also add the general kernel and state generator as an attribute to this class so + users only need to pass `blackjax.langevin` to SMC, adaptation, etc. algorithms. 
+ """ + + kernel = build_kernel( + integrator, + desired_energy_var_max_ratio=desired_energy_var_max_ratio, + desired_energy_var=desired_energy_var, + ) + + def init_fn(position: ArrayLike, rng_key: PRNGKey): + return init(position, logdensity_fn, rng_key) + + def update_fn(rng_key, state): + return kernel( + rng_key=rng_key, state=state, logdensity_fn=logdensity_fn, L=L, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix) + + return SamplingAlgorithm(init_fn, update_fn) + + +def handle_nans( + previous_state, next_state, info, key, inverse_mass_matrix +): + + metric = metrics.default_metric(inverse_mass_matrix) + new_momentum = metric.sample_momentum(key, previous_state.position) + + nonans = jnp.logical_and(jnp.all(jnp.isfinite(next_state.position)), jnp.all(jnp.isfinite(next_state.momentum))) + + state, info = jax.lax.cond( + nonans, + lambda: (next_state, info), + lambda: (previous_state._replace( + momentum=new_momentum, + ), LangevinInfo( + logdensity=previous_state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + nonans=nonans + )), + ) + + return state, info + +def handle_high_energy( + previous_state, next_state, info, key, cutoff, inverse_mass_matrix +): + + metric = metrics.default_metric(inverse_mass_matrix) + new_momentum = metric.sample_momentum(key, previous_state.position) + + state, info = jax.lax.cond( + jnp.abs(info.energy_change) > cutoff, + lambda: (previous_state._replace( + momentum=new_momentum, + ), LangevinInfo( + logdensity=previous_state.logdensity, + energy_change=0.0, + kinetic_change=0.0, + nonans=False + )), + lambda: (next_state, info), + ) + + return state, info \ No newline at end of file diff --git a/blackjax/smc/tuning/from_kernel_info.py b/blackjax/smc/tuning/from_kernel_info.py index fa2c7054c..13db64a37 100644 --- a/blackjax/smc/tuning/from_kernel_info.py +++ b/blackjax/smc/tuning/from_kernel_info.py @@ -3,6 +3,7 @@ strategies to tune the parameters of mcmc kernels used within smc, based on MCMC states """ + 
import jax import jax.numpy as jnp diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index 505e7f3a1..74c4e67ef 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -2,6 +2,7 @@ static (all particles get the same value) strategies to tune the parameters of mcmc kernels used within SMC, based on particles. """ + import jax import jax.numpy as jnp from jax._src.flatten_util import ravel_pytree diff --git a/blackjax/util.py b/blackjax/util.py index 8cdcd45ee..b8a00f9d6 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -4,15 +4,19 @@ from typing import Callable, Union import jax.numpy as jnp -from jax import jit, lax +from jax import device_put, jit, lax, vmap +from jax.experimental.shard_map import shard_map from jax.flatten_util import ravel_pytree from jax.random import normal, split +from jax.sharding import NamedSharding, PartitionSpec from jax.tree_util import tree_leaves, tree_map from blackjax.base import SamplingAlgorithm, VIAlgorithm from blackjax.progress_bar import gen_scan_fn from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey - +from blackjax.diagnostics import splitR +import time +import jax @partial(jit, static_argnames=("precision",), inline=True) def linear_map(diag_or_dense_a, b, *, precision="highest"): @@ -196,6 +200,7 @@ def run_inference_algorithm( keys = split(rng_key, num_steps) + def one_step(state, xs): _, rng_key = xs state, info = inference_algorithm.step(rng_key, state) @@ -204,7 +209,9 @@ def one_step(state, xs): scan_fn = gen_scan_fn(num_steps, progress_bar) xs = jnp.arange(num_steps), keys + # tic = time.time() final_state, history = scan_fn(one_step, initial_state, xs) + # toc = time.time() return final_state, history @@ -314,3 +321,312 @@ def incremental_value_update( ) total += weight return total, average + + +def eca_step( + kernel, summary_statistics_fn, adaptation_update, num_chains, superchain_size = None, 
ensemble_info=None +): + """ + Construct a single step of ensemble chain adaptation (eca) to be performed in parallel on multiple devices. + """ + + def step(state_all, xs): + """This function operates on a single device.""" + ( + state, + adaptation_state, + ) = state_all # state is an array of states, one for each chain on this device. adaptation_state is the same for all chains, so it is not an array. + ( + _, + keys_sampling, + key_adaptation, + ) = xs # keys_sampling.shape = (chains_per_device, ) + + # update the state of all chains on this device + state, info = vmap(kernel, (0, 0, None))(keys_sampling, state, adaptation_state) + + # combine all the chains to compute expectation values + theta = vmap(summary_statistics_fn, (0, 0, None))(state, info, key_adaptation) + Etheta = tree_map( + lambda theta: lax.psum(jnp.sum(theta, axis=0), axis_name="chains") + / num_chains, + theta, + ) + + # use these to adapt the hyperparameters of the dynamics + adaptation_state, info_to_be_stored = adaptation_update( + adaptation_state, Etheta + ) + + return (state, adaptation_state), info_to_be_stored + + + return add_ensemble_info(add_splitR(step, num_chains, superchain_size), ensemble_info) + + +def add_splitR(step, num_chains, superchain_size): + + + def _step_with_R(state_all, xs): + + state_all, info_to_be_stored = step(state_all, xs) + + state, adaptation_state = state_all + + R = splitR(state.position, num_chains, superchain_size) + split_bavg = jnp.average(jnp.square(R) - 1) + split_bmax = jnp.max(jnp.square(R) - 1) + + info_to_be_stored['R_avg'] = split_bavg + info_to_be_stored['R_max'] = split_bmax + + return (state, adaptation_state), info_to_be_stored + + def _step_with_R_1(state_all, xs): + + state_all, info_to_be_stored = step(state_all, xs) + + info_to_be_stored['R_avg'] = 0. + info_to_be_stored['R_max'] = 0. 
+ + return state_all, info_to_be_stored + + if superchain_size == None: + return step + + if superchain_size == 1: + return _step_with_R_1 + + else: + return _step_with_R + + +def add_ensemble_info(step, ensemble_info): + + def _step(state_all, xs): + (state, adaptation_state), info_to_be_stored = step(state_all, xs) + return (state, adaptation_state), (info_to_be_stored, vmap(ensemble_info)(state.position)) + + return _step if ensemble_info is not None else step + + + +def run_eca( + rng_key, + initial_state, + kernel, + adaptation, + num_steps, + num_chains, + mesh, + superchain_size= None, + ensemble_info=None, + early_stop=False, +): + """ + Run ensemble chain adaptation (eca) in parallel on multiple devices. + ----------------------------------------------------- + Args: + rng_key: random key + initial_state: initial state of the system + kernel: kernel for the dynamics + adaptation: adaptation object + num_steps: number of steps to run + num_chains: number of chains + mesh: mesh for parallelization + ensemble_info: function that takes the state of the system and returns some information about the ensemble + early_stop: whether to stop early + Returns: + final_state: final state of the system + final_adaptation_state: final adaptation state + info_history: history of the information that was stored at each step (a dict of per-step arrays when early_stop is True, otherwise the stacked per-step output of lax.scan); the number of steps performed is returned as a fourth value + """ + + step = eca_step( + kernel, + adaptation.summary_statistics_fn, + adaptation.update, + num_chains, + superchain_size= superchain_size, + ensemble_info = ensemble_info, + ) + + def all_steps(initial_state, keys_sampling, keys_adaptation): + """This function operates on a single device. key is a random key for this device.""" + + initial_state_all = (initial_state, adaptation.initial_state) + + # run sampling + xs = ( + jnp.arange(num_steps), + keys_sampling.T, + keys_adaptation, + ) # keys for all steps that will be performed.
keys_sampling.shape = (num_steps, chains_per_device), keys_adaptation.shape = (num_steps, ) + + EEVPD = jnp.zeros((num_steps,)) + EEVPD_wanted = jnp.zeros((num_steps,)) + L = jnp.zeros((num_steps,)) + entropy = jnp.zeros((num_steps,)) + equi_diag = jnp.zeros((num_steps,)) + equi_full = jnp.zeros((num_steps,)) + bias0 = jnp.zeros((num_steps,)) + bias1 = jnp.zeros((num_steps,)) + observables = jnp.zeros((num_steps,)) + r_avg = jnp.zeros((num_steps,)) + r_max = jnp.zeros((num_steps,)) + R_avg = jnp.zeros((num_steps,)) + R_max = jnp.zeros((num_steps,)) + step_size = jnp.zeros((num_steps,)) + + def step_while(a): + x, i, _, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, R_avg, R_max, step_size = a + + auxilliary_input = (xs[0][i], xs[1][i], xs[2][i]) + + output, (info, pos) = step(x, auxilliary_input) + new_EEVPD = EEVPD.at[i].set(info.get("EEVPD")) + new_EEVPD_wanted = EEVPD_wanted.at[i].set(info.get("EEVPD_wanted")) + new_L = L.at[i].set(info.get("L")) + new_entropy = entropy.at[i].set(info.get("entropy")) + new_equi_diag = equi_diag.at[i].set(info.get("equi_diag")) + new_equi_full = equi_full.at[i].set(info.get("equi_full")) + new_bias0 = bias0.at[i].set(info.get("bias")[0]) + new_bias1 = bias1.at[i].set(info.get("bias")[1]) + new_observables = observables.at[i].set(info.get("observables")) + new_r_avg = r_avg.at[i].set(info.get("r_avg")) + new_r_max = r_max.at[i].set(info.get("r_max")) + new_R_avg = R_avg.at[i].set(info.get("R_avg")) + new_R_max = R_max.at[i].set(info.get("R_max")) + new_step_size = step_size.at[i].set(info.get("step_size")) + + return (output, i + 1, + (info.get("r_max") > adaptation.r_end) | (i < adaptation.save_num), # while is run while this is True + new_EEVPD, new_EEVPD_wanted, new_L, new_entropy, new_equi_diag, new_equi_full, new_bias0, new_bias1, new_observables, new_r_avg, new_r_max, new_R_avg, new_R_max, new_step_size) + + if early_stop: + final_state_all, i, _, EEVPD, EEVPD_wanted, L, 
entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, R_avg, R_max, step_size = lax.while_loop( + lambda a: ((a[1] < num_steps) & a[2]), + step_while, + (initial_state_all, 0, True, EEVPD, EEVPD_wanted, L, entropy, equi_diag, equi_full, bias0, bias1, observables, r_avg, r_max, R_avg, R_max, step_size), + ) + steps_done = i + info_history = { + "EEVPD": EEVPD, + "EEVPD_wanted": EEVPD_wanted, + "L": L, + "entropy": entropy, + "equi_diag": equi_diag, + "equi_full": equi_full, + "bias0": bias0, + "bias1": bias1, + "observables": observables, + "r_avg": r_avg, + "r_max": r_max, + "R_avg": R_avg, + "R_max": R_max, + "step_size": step_size, + "steps_done": steps_done, + } + + else: + final_state_all, info_history = lax.scan(step, initial_state_all, xs) + steps_done = num_steps + + final_state, final_adaptation_state = final_state_all + return ( + final_state, + final_adaptation_state, + info_history, + steps_done, + ) # info history is composed of averages over all chains, so it is a couple of scalars + + p, pscalar = PartitionSpec("chains"), PartitionSpec() + parallel_execute = shard_map( + all_steps, + mesh=mesh, + in_specs=(p, p, pscalar), + out_specs=(p, pscalar, pscalar, pscalar), + check_rep=False, + ) + + # produce all random keys that will be needed + + key_sampling, key_adaptation = split(rng_key) + num_steps = jnp.array(num_steps).item() + keys_adaptation = split(key_adaptation, num_steps) + distribute_keys = lambda key, shape: device_put( + split(key, shape), NamedSharding(mesh, p) + ) # random keys, distributed across devices + keys_sampling = distribute_keys(key_sampling, (num_chains, num_steps)) + + # run sampling in parallel + final_state, final_adaptation_state, info_history, steps_done = parallel_execute( + initial_state, keys_sampling, keys_adaptation + ) + + return final_state, final_adaptation_state, info_history, steps_done + + +def ensemble_execute_fn( + func, + rng_key, + num_chains, + mesh, + x=None, + args=None, + 
summary_statistics_fn=lambda y: 0.0, + superchain_size = None +): + """Given a sequential function + func(rng_key, x, args) = y, + evaluate it with an ensemble and also compute some summary statistics E[theta(y)], where expectation is taken over the ensemble. + Args: + x: array distributed over all devices + args: additional arguments for func, not distributed. + summary_statistics_fn: operates on a single member of ensemble and returns some summary statistics. + rng_key: a single random key, which will then be split, such that each member of an ensemble will get a different random key. + + Returns: + y: array distributed over all devices. Need not be of the same shape as x. + Etheta: expected values of the summary statistics + """ + p, pscalar = PartitionSpec("chains"), PartitionSpec() + + if x is None: + X = device_put(jnp.zeros(num_chains), NamedSharding(mesh, p)) + else: + X = x + + adaptation_update = lambda _, Etheta: (Etheta, None) + + _F = eca_step( + func, + lambda y, info, key: summary_statistics_fn(y), + adaptation_update, + num_chains, + ) + + def F(x, keys): + """This function operates on a single device.
key is a random key for this device.""" + y, summary_statistics = _F((x, args), (None, keys, None))[0] + return y, summary_statistics + + parallel_execute = shard_map( + F, mesh=mesh, in_specs=(p, p), out_specs=(p, pscalar), check_rep=False + ) + + if superchain_size == 1: + _keys = split(rng_key, num_chains) + + else: + _keys = jnp.repeat(split(rng_key, num_chains // superchain_size), superchain_size) + + + keys = device_put(_keys, NamedSharding(mesh, p)) # random keys, distributed across devices + + # apply F in parallel + return parallel_execute(X, keys) + + + + diff --git a/tests/adaptation/test_mass_matrix.py b/tests/adaptation/test_mass_matrix.py index 622b2111c..97d6ea882 100644 --- a/tests/adaptation/test_mass_matrix.py +++ b/tests/adaptation/test_mass_matrix.py @@ -1,4 +1,5 @@ """Test the welford adaptation algorithm.""" + import itertools import chex diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 4d8a9fa61..d2cbd1501 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -5,6 +5,8 @@ import chex import jax + +# jax.config.update("jax_traceback_filtering", "off") import jax.numpy as jnp import jax.scipy.stats as stats import numpy as np @@ -15,6 +17,7 @@ import blackjax.diagnostics as diagnostics import blackjax.mcmc.random_walk from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info +from blackjax.adaptation.ensemble_mclmc import emaus from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.mcmc.integrators import isokinetic_mclachlan from blackjax.util import run_inference_algorithm @@ -289,6 +292,43 @@ def run_adjusted_mclmc_static( return out + def run_emaus( + self, + sample_init, + logdensity_fn, + ndims, + key, + diagonal_preconditioning, + ): + mesh = jax.sharding.Mesh(devices=jax.devices(), axis_names="chains") + + from blackjax.mcmc.integrators import mclachlan_coefficients + + integrator_coefficients = mclachlan_coefficients + + info, grads_per_step, 
_acc_prob, final_state = emaus( + logdensity_fn=logdensity_fn, + sample_init=sample_init, + ndims=ndims, + num_steps1=100, + num_steps2=300, + num_chains=800, + mesh=mesh, + rng_key=key, + alpha=1.9, + C=0.1, + early_stop=1, + r_end=1e-2, + diagonal_preconditioning=diagonal_preconditioning, + integrator_coefficients=integrator_coefficients, + steps_per_sample=15, + acc_prob=None, + ensemble_observables=lambda x: x, + # ensemble_observables = lambda x: vec @ x + ) # run the algorithm + + return final_state.position + @parameterized.parameters( itertools.product( regression_test_cases, [True, False], window_adaptation_filters @@ -535,6 +575,43 @@ def get_inverse_mass_matrix(): < 0.1 ) + def test_emaus( + self, + ): + """Test the MCLMC kernel.""" + + init_key0, init_key1, inference_key = jax.random.split(self.key, 3) + + x_data = jax.random.normal(init_key0, shape=(1000, 1)) + y_data = 3 * x_data + jax.random.normal(init_key1, shape=x_data.shape) + + logposterior_fn_ = functools.partial( + self.regression_logprob, x=x_data, preds=y_data + ) + logdensity_fn = lambda x: logposterior_fn_( + coefs=x["coefs"][0], log_scale=x["log_scale"][0] + ) + + def sample_init(key): + key1, key2 = jax.random.split(key) + coefs = jax.random.uniform(key1, shape=(1,), minval=1, maxval=2) + log_scale = jax.random.uniform(key2, shape=(1,), minval=1, maxval=2) + return {"coefs": coefs, "log_scale": log_scale} + + samples = self.run_emaus( + sample_init=sample_init, + logdensity_fn=logdensity_fn, + ndims=2, + key=inference_key, + diagonal_preconditioning=True, + ) + + coefs_samples = samples["coefs"] + scale_samples = np.exp(samples["log_scale"]) + + np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1) + np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1) + @parameterized.parameters(regression_test_cases) def test_pathfinder_adaptation( self, diff --git a/tests/mcmc/test_trajectory.py b/tests/mcmc/test_trajectory.py index e93280400..bc8490b19 100644 --- 
a/tests/mcmc/test_trajectory.py +++ b/tests/mcmc/test_trajectory.py @@ -1,4 +1,5 @@ """Test the trajectory integration""" + import chex import jax import jax.numpy as jnp diff --git a/tests/mcmc/test_uturn.py b/tests/mcmc/test_uturn.py index 7f9f597d6..ff1f261f6 100644 --- a/tests/mcmc/test_uturn.py +++ b/tests/mcmc/test_uturn.py @@ -1,4 +1,5 @@ """Test the iterative u-turn criterion.""" + import chex import jax.numpy as jnp from absl.testing import absltest, parameterized diff --git a/tests/optimizers/test_optimizers.py b/tests/optimizers/test_optimizers.py index a7549842f..47f437af2 100644 --- a/tests/optimizers/test_optimizers.py +++ b/tests/optimizers/test_optimizers.py @@ -1,4 +1,5 @@ """Test optimizers.""" + import functools import chex diff --git a/tests/optimizers/test_pathfinder.py b/tests/optimizers/test_pathfinder.py index b9b9c69be..f40e79410 100644 --- a/tests/optimizers/test_pathfinder.py +++ b/tests/optimizers/test_pathfinder.py @@ -1,4 +1,5 @@ """Test the pathfinder algorithm.""" + import functools import chex diff --git a/tests/smc/test_resampling.py b/tests/smc/test_resampling.py index 20cb0d813..e6570f8f6 100644 --- a/tests/smc/test_resampling.py +++ b/tests/smc/test_resampling.py @@ -1,4 +1,5 @@ """Test the resampling functions for SMC.""" + import itertools import chex diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 769078c8d..7ab6350ec 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -1,4 +1,5 @@ """Test the generic SMC sampler""" + import functools import chex diff --git a/tests/smc/test_smc_ess.py b/tests/smc/test_smc_ess.py index 570d392d9..1f02b8d61 100644 --- a/tests/smc/test_smc_ess.py +++ b/tests/smc/test_smc_ess.py @@ -1,4 +1,5 @@ """Test the ess function""" + import functools import chex diff --git a/tests/smc/test_solver.py b/tests/smc/test_solver.py index 49db84129..8bcdd6a07 100644 --- a/tests/smc/test_solver.py +++ b/tests/smc/test_solver.py @@ -1,4 +1,5 @@ """Test the solving functions""" + 
import itertools import chex diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index 527457d62..ef8b8cb08 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -1,4 +1,5 @@ """Test the tempered SMC steps and routine""" + import functools import chex @@ -79,9 +80,11 @@ def logprior_fn(x): base_params, jax.tree.map(lambda x: jnp.repeat(x, num_particles, axis=0), base_params), jax.tree_util.tree_map_with_path( - lambda path, x: jnp.repeat(x, num_particles, axis=0) - if path[0].key == "step_size" - else x, + lambda path, x: ( + jnp.repeat(x, num_particles, axis=0) + if path[0].key == "step_size" + else x + ), base_params, ), ] diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 2d108a48d..ea9c6aa66 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -5,6 +5,7 @@ obviously more models. It should also be run in CI. """ + import functools import jax diff --git a/tests/test_compilation.py b/tests/test_compilation.py index 7179b71ba..2d3b03a44 100644 --- a/tests/test_compilation.py +++ b/tests/test_compilation.py @@ -5,6 +5,7 @@ internal changes do not trigger more compilations than is necessary. """ + import chex import jax import jax.numpy as jnp diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py index b583c8645..1d7c74846 100644 --- a/tests/test_diagnostics.py +++ b/tests/test_diagnostics.py @@ -1,4 +1,5 @@ """Test MCMC diagnostics.""" + import functools import itertools diff --git a/velocity_verlet_gaussian.ipynb b/velocity_verlet_gaussian.ipynb new file mode 100644 index 000000000..cff0442a8 --- /dev/null +++ b/velocity_verlet_gaussian.ipynb @@ -0,0 +1,429 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Velocity Verlet Integrator for a Gaussian Target\n", + "\n", + "This notebook demonstrates how to use the velocity verlet integrator from Blackjax to simulate Hamiltonian dynamics for a Gaussian target distribution." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:2025-04-22 11:29:00,948:jax._src.xla_bridge:909: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n" + ] + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.scipy.stats as stats\n", + "\n", + "import blackjax\n", + "from blackjax.mcmc import integrators\n", + "from blackjax.mcmc import metrics\n", + "from blackjax.types import ArrayTree\n", + "\n", + "# Set random seed for reproducibility\n", + "rng_key = jax.random.key(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define a Gaussian Target Distribution\n", + "\n", + "We'll use a 2D Gaussian distribution as our target. The log-density function is defined as:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Define the parameters of the Gaussian\n", + "mean = jnp.array([0.0, 0.0])\n", + "cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])\n", + "\n", + "# Define the log-density function\n", + "def logdensity_fn(position):\n", + " return stats.multivariate_normal.logpdf(position, mean, cov)\n", + "\n", + "# Visualize the target distribution\n", + "def plot_gaussian():\n", + " x = np.linspace(-4, 4, 100)\n", + " y = np.linspace(-4, 4, 100)\n", + " X, Y = np.meshgrid(x, y)\n", + " Z = np.zeros((100, 100))\n", + " for i in range(100):\n", + " for j in range(100):\n", + " Z[i, j] = np.exp(logdensity_fn(np.array([X[i, j], Y[i, j]])))\n", + " \n", + " plt.figure(figsize=(10, 8))\n", + " plt.contour(X, Y, Z, levels=20)\n", + " plt.colorbar(label='Density')\n", + " plt.xlabel('x')\n", + " plt.ylabel('y')\n", + " plt.title('2D Gaussian Target Distribution')\n", + " plt.axis('equal')\n", + " plt.show()\n", + "\n", + "plot_gaussian()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set Up the Velocity Verlet Integrator\n", + "\n", + "Now, let's set up the velocity verlet integrator. We need to:\n", + "1. Define a kinetic energy function\n", + "2. Create an integrator state\n", + "3. 
Run the integrator for a few steps" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the inverse mass matrix (for the kinetic energy)\n", + "inverse_mass_matrix = jnp.eye(2) # Identity matrix for simplicity\n", + "\n", + "# Create a metric object\n", + "metric = metrics.default_metric(inverse_mass_matrix)\n", + "\n", + "# Get the kinetic energy function from the metric\n", + "kinetic_energy_fn = metric.kinetic_energy\n", + "\n", + "# Create the velocity verlet integrator\n", + "integrator = integrators.velocity_verlet(logdensity_fn, kinetic_energy_fn)\n", + "\n", + "# Set the initial position and momentum\n", + "initial_position = jnp.array([2.0, 2.0])\n", + "initial_momentum = jnp.array([0.5, -0.3])\n", + "\n", + "# Create the initial integrator state\n", + "initial_state = integrators.new_integrator_state(logdensity_fn, initial_position, initial_momentum)\n", + "\n", + "# Set the step size and number of steps\n", + "step_size = 0.1\n", + "num_steps = 50" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "IntegratorState(position=Array([2.0414286, 1.9671428], dtype=float32), momentum=Array([ 0.3257347 , -0.35561225], dtype=float32), logdensity=Array(-4.457322, dtype=float32), logdensity_grad=Array([-1.7710204 , -0.54081637], dtype=float32))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "integrator(initial_state, step_size)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the Integrator and Visualize the Trajectory\n", + "\n", + "Now, let's run the integrator for a few steps and visualize the trajectory:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": 
"AttributeError", + "evalue": "'jaxlib.xla_extension.ArrayImpl' object has no attribute 'position'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 14\u001b[0m\n\u001b[1;32m 11\u001b[0m states \u001b[38;5;241m=\u001b[39m run_integrator(initial_state, integrator, step_size, num_steps)\n\u001b[1;32m 13\u001b[0m \u001b[38;5;66;03m# Extract positions and momenta\u001b[39;00m\n\u001b[0;32m---> 14\u001b[0m positions \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray(\u001b[43m[\u001b[49m\u001b[43mstate\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mposition\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mstates\u001b[49m\u001b[43m]\u001b[49m)\n\u001b[1;32m 15\u001b[0m momenta \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray([state\u001b[38;5;241m.\u001b[39mmomentum \u001b[38;5;28;01mfor\u001b[39;00m state \u001b[38;5;129;01min\u001b[39;00m states])\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# Visualize the trajectory\u001b[39;00m\n", + "Cell \u001b[0;32mIn[5], line 14\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 11\u001b[0m states \u001b[38;5;241m=\u001b[39m run_integrator(initial_state, integrator, step_size, num_steps)\n\u001b[1;32m 13\u001b[0m \u001b[38;5;66;03m# Extract positions and momenta\u001b[39;00m\n\u001b[0;32m---> 14\u001b[0m positions \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray([\u001b[43mstate\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mposition\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m state \u001b[38;5;129;01min\u001b[39;00m states])\n\u001b[1;32m 15\u001b[0m momenta \u001b[38;5;241m=\u001b[39m 
jnp\u001b[38;5;241m.\u001b[39marray([state\u001b[38;5;241m.\u001b[39mmomentum \u001b[38;5;28;01mfor\u001b[39;00m state \u001b[38;5;129;01min\u001b[39;00m states])\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# Visualize the trajectory\u001b[39;00m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'position'" + ] + } + ], + "source": [ + "# Function to run the integrator for multiple steps\n", + "def run_integrator(initial_state, integrator, step_size, num_steps):\n", + " def one_step(state, _):\n", + " new_state = integrator(state, step_size)\n", + " return new_state, new_state\n", + " \n", + " _, states = jax.lax.scan(one_step, initial_state, None, length=num_steps)\n", + " return states\n", + "\n", + "# Run the integrator\n", + "states = run_integrator(initial_state, integrator, step_size, num_steps)\n", + "\n", + "# Extract positions and momenta\n", + "positions = jnp.array([state.position for state in states])\n", + "momenta = jnp.array([state.momentum for state in states])\n", + "\n", + "# Visualize the trajectory\n", + "def plot_trajectory(positions, momenta):\n", + " x = np.linspace(-4, 4, 100)\n", + " y = np.linspace(-4, 4, 100)\n", + " X, Y = np.meshgrid(x, y)\n", + " Z = np.zeros((100, 100))\n", + " for i in range(100):\n", + " for j in range(100):\n", + " Z[i, j] = np.exp(logdensity_fn(np.array([X[i, j], Y[i, j]])))\n", + " \n", + " plt.figure(figsize=(12, 10))\n", + " \n", + " # Plot the target distribution\n", + " plt.subplot(2, 2, 1)\n", + " plt.contour(X, Y, Z, levels=20)\n", + " plt.colorbar(label='Density')\n", + " plt.plot(positions[:, 0], positions[:, 1], 'r-', label='Trajectory')\n", + " plt.plot(positions[0, 0], positions[0, 1], 'go', label='Start')\n", + " plt.plot(positions[-1, 0], positions[-1, 1], 'bo', label='End')\n", + " plt.xlabel('x')\n", + " plt.ylabel('y')\n", + " plt.title('Position Trajectory')\n", + " plt.legend()\n", + " plt.axis('equal')\n", + " \n", + " # Plot the momentum 
trajectory\n", + " plt.subplot(2, 2, 2)\n", + " plt.plot(momenta[:, 0], momenta[:, 1], 'b-', label='Trajectory')\n", + " plt.plot(momenta[0, 0], momenta[0, 1], 'go', label='Start')\n", + " plt.plot(momenta[-1, 0], momenta[-1, 1], 'ro', label='End')\n", + " plt.xlabel('p_x')\n", + " plt.ylabel('p_y')\n", + " plt.title('Momentum Trajectory')\n", + " plt.legend()\n", + " plt.axis('equal')\n", + " \n", + " # Plot position vs time\n", + " plt.subplot(2, 2, 3)\n", + " time = np.arange(num_steps) * step_size\n", + " plt.plot(time, positions[:, 0], 'r-', label='x')\n", + " plt.plot(time, positions[:, 1], 'b-', label='y')\n", + " plt.xlabel('Time')\n", + " plt.ylabel('Position')\n", + " plt.title('Position vs Time')\n", + " plt.legend()\n", + " \n", + " # Plot momentum vs time\n", + " plt.subplot(2, 2, 4)\n", + " plt.plot(time, momenta[:, 0], 'r-', label='p_x')\n", + " plt.plot(time, momenta[:, 1], 'b-', label='p_y')\n", + " plt.xlabel('Time')\n", + " plt.ylabel('Momentum')\n", + " plt.title('Momentum vs Time')\n", + " plt.legend()\n", + " \n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "plot_trajectory(positions, momenta)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Energy Conservation\n", + "\n", + "One of the key properties of Hamiltonian dynamics is energy conservation. 
Let's check if our integrator preserves energy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate the total energy (potential + kinetic) at each step\n", + "def calculate_energy(states):\n", + " # Potential energy is the negative log-density\n", + " potential_energy = -jnp.array([state.logdensity for state in states])\n", + " \n", + " # Kinetic energy\n", + " kinetic_energy = jnp.array([kinetic_energy_fn(state.momentum) for state in states])\n", + " \n", + " # Total energy\n", + " total_energy = potential_energy + kinetic_energy\n", + " \n", + " return potential_energy, kinetic_energy, total_energy\n", + "\n", + "potential_energy, kinetic_energy, total_energy = calculate_energy(states)\n", + "\n", + "# Plot the energy components\n", + "plt.figure(figsize=(12, 6))\n", + "time = np.arange(num_steps) * step_size\n", + "plt.plot(time, potential_energy, 'r-', label='Potential Energy')\n", + "plt.plot(time, kinetic_energy, 'b-', label='Kinetic Energy')\n", + "plt.plot(time, total_energy, 'g-', label='Total Energy')\n", + "plt.xlabel('Time')\n", + "plt.ylabel('Energy')\n", + "plt.title('Energy Conservation')\n", + "plt.legend()\n", + "plt.grid(True)\n", + "plt.show()\n", + "\n", + "# Calculate the relative energy error\n", + "initial_total_energy = total_energy[0]\n", + "relative_energy_error = jnp.abs(total_energy - initial_total_energy) / jnp.abs(initial_total_energy)\n", + "max_relative_error = jnp.max(relative_energy_error)\n", + "print(f\"Maximum relative energy error: {max_relative_error:.6f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Time Reversibility\n", + "\n", + "Another important property of Hamiltonian dynamics is time reversibility. 
Let's check if our integrator is time-reversible by running it forward and then backward:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Function to run the integrator backward (by negating the step size)\n", + "def run_integrator_backward(final_state, integrator, step_size, num_steps):\n", + " def one_step(state, _):\n", + " new_state = integrator(state, -step_size) # Negative step size for backward integration\n", + " return new_state, new_state\n", + " \n", + " _, states = jax.lax.scan(one_step, final_state, None, length=num_steps)\n", + " return states\n", + "\n", + "# Run the integrator forward\n", + "forward_states = run_integrator(initial_state, integrator, step_size, num_steps)\n", + "final_state = forward_states[-1]\n", + "\n", + "# Run the integrator backward\n", + "backward_states = run_integrator_backward(final_state, integrator, step_size, num_steps)\n", + "reversed_state = backward_states[-1]\n", + "\n", + "# Check if we've returned to the initial state\n", + "position_error = jnp.linalg.norm(initial_state.position - reversed_state.position)\n", + "momentum_error = jnp.linalg.norm(initial_state.momentum - reversed_state.momentum)\n", + "\n", + "print(f\"Position error: {position_error:.10f}\")\n", + "print(f\"Momentum error: {momentum_error:.10f}\")\n", + "\n", + "# Visualize the forward and backward trajectories\n", + "def plot_reversibility(forward_states, backward_states):\n", + " forward_positions = jnp.array([state.position for state in forward_states])\n", + " backward_positions = jnp.array([state.position for state in backward_states])\n", + " \n", + " x = np.linspace(-4, 4, 100)\n", + " y = np.linspace(-4, 4, 100)\n", + " X, Y = np.meshgrid(x, y)\n", + " Z = np.zeros((100, 100))\n", + " for i in range(100):\n", + " for j in range(100):\n", + " Z[i, j] = np.exp(logdensity_fn(np.array([X[i, j], Y[i, j]])))\n", + " \n", + " plt.figure(figsize=(10, 8))\n", + " plt.contour(X, Y, 
Z, levels=20)\n", + " plt.colorbar(label='Density')\n", + " plt.plot(forward_positions[:, 0], forward_positions[:, 1], 'r-', label='Forward')\n", + " plt.plot(backward_positions[:, 0], backward_positions[:, 1], 'b--', label='Backward')\n", + " plt.plot(forward_positions[0, 0], forward_positions[0, 1], 'go', label='Start')\n", + " plt.plot(forward_positions[-1, 0], forward_positions[-1, 1], 'bo', label='End')\n", + " plt.plot(backward_positions[-1, 0], backward_positions[-1, 1], 'ro', label='Reversed')\n", + " plt.xlabel('x')\n", + " plt.ylabel('y')\n", + " plt.title('Time Reversibility')\n", + " plt.legend()\n", + " plt.axis('equal')\n", + " plt.show()\n", + "\n", + "plot_reversibility(forward_states, backward_states)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "In this notebook, we've demonstrated how to use the velocity verlet integrator from Blackjax to simulate Hamiltonian dynamics for a Gaussian target distribution. We've shown that:\n", + "\n", + "1. The integrator can be used to simulate the trajectory of a particle in the potential energy landscape defined by the negative log-density of the target distribution.\n", + "2. The integrator approximately conserves energy, with small errors due to the numerical approximation.\n", + "3. The integrator is time-reversible, meaning that running it forward and then backward returns to the initial state (up to numerical errors).\n", + "\n", + "These properties make the velocity verlet integrator a good choice for Hamiltonian Monte Carlo, where we want to simulate Hamiltonian dynamics to propose new states in the Markov chain." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/velocity_verlet_gaussian_with_equivalence.ipynb b/velocity_verlet_gaussian_with_equivalence.ipynb new file mode 100644 index 000000000..d36c21f1c --- /dev/null +++ b/velocity_verlet_gaussian_with_equivalence.ipynb @@ -0,0 +1,558 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Velocity Verlet Integrator for a Gaussian Target\n", + "\n", + "This notebook demonstrates how to use the velocity verlet integrator from Blackjax to simulate Hamiltonian dynamics for a Gaussian target distribution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import jax.scipy.stats as stats\n", + "\n", + "import blackjax\n", + "from blackjax.mcmc import integrators\n", + "from blackjax.mcmc import metrics\n", + "from blackjax.types import ArrayTree\n", + "\n", + "# Set random seed for reproducibility\n", + "rng_key = jax.random.key(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define a Gaussian Target Distribution\n", + "\n", + "We'll use a 2D Gaussian distribution as our target. 
The log-density function is defined as:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the parameters of the Gaussian\n", + "mean = jnp.array([0.0, 0.0])\n", + "cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])\n", + "\n", + "# Define the log-density function\n", + "def logdensity_fn(position):\n", + " return stats.multivariate_normal.logpdf(position, mean, cov)\n", + "\n", + "# Visualize the target distribution\n", + "def plot_gaussian():\n", + " x = np.linspace(-4, 4, 100)\n", + " y = np.linspace(-4, 4, 100)\n", + " X, Y = np.meshgrid(x, y)\n", + " Z = np.zeros((100, 100))\n", + " for i in range(100):\n", + " for j in range(100):\n", + " Z[i, j] = np.exp(logdensity_fn(np.array([X[i, j], Y[i, j]])))\n", + " \n", + " plt.figure(figsize=(10, 8))\n", + " plt.contour(X, Y, Z, levels=20)\n", + " plt.colorbar(label='Density')\n", + " plt.xlabel('x')\n", + " plt.ylabel('y')\n", + " plt.title('2D Gaussian Target Distribution')\n", + " plt.axis('equal')\n", + " plt.show()\n", + "\n", + "plot_gaussian()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set Up the Velocity Verlet Integrator\n", + "\n", + "Now, let's set up the velocity verlet integrator. We need to:\n", + "1. Define a kinetic energy function\n", + "2. Create an integrator state\n", + "3. 
Run the integrator for a few steps" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define the inverse mass matrix (for the kinetic energy)\n", + "inverse_mass_matrix = jnp.eye(2) # Identity matrix for simplicity\n", + "\n", + "# Create a metric object\n", + "metric = metrics.default_metric(inverse_mass_matrix)\n", + "\n", + "# Get the kinetic energy function from the metric\n", + "kinetic_energy_fn = metric.kinetic_energy\n", + "\n", + "# Create the velocity verlet integrator\n", + "integrator = integrators.velocity_verlet(logdensity_fn, kinetic_energy_fn)\n", + "\n", + "# Set the initial position and momentum\n", + "initial_position = jnp.array([2.0, 2.0])\n", + "initial_momentum = jnp.array([0.5, -0.3])\n", + "\n", + "# Create the initial integrator state\n", + "initial_state = integrators.new_integrator_state(logdensity_fn, initial_position, initial_momentum)\n", + "\n", + "# Set the step size and number of steps\n", + "step_size = 0.1\n", + "num_steps = 50" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Run the Integrator and Visualize the Trajectory\n", + "\n", + "Now, let's run the integrator for a few steps and visualize the trajectory:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Function to run the integrator for multiple steps\n", + "def run_integrator(initial_state, integrator, step_size, num_steps):\n", + " def one_step(state, _):\n", + " new_state = integrator(state, step_size)\n", + " return new_state, new_state\n", + " \n", + " _, states = jax.lax.scan(one_step, initial_state, None, length=num_steps)\n", + " return states\n", + "\n", + "# Run the integrator\n", + "states = run_integrator(initial_state, integrator, step_size, num_steps)\n", + "\n", + "# Extract positions and momenta\n", + "positions = jnp.array([state.position for state in states])\n", + "momenta = 
jnp.array([state.momentum for state in states])\n", + "\n", + "# Visualize the trajectory\n", + "def plot_trajectory(positions, momenta):\n", + " x = np.linspace(-4, 4, 100)\n", + " y = np.linspace(-4, 4, 100)\n", + " X, Y = np.meshgrid(x, y)\n", + " Z = np.zeros((100, 100))\n", + " for i in range(100):\n", + " for j in range(100):\n", + " Z[i, j] = np.exp(logdensity_fn(np.array([X[i, j], Y[i, j]])))\n", + " \n", + " plt.figure(figsize=(12, 10))\n", + " \n", + " # Plot the target distribution\n", + " plt.subplot(2, 2, 1)\n", + " plt.contour(X, Y, Z, levels=20)\n", + " plt.colorbar(label='Density')\n", + " plt.plot(positions[:, 0], positions[:, 1], 'r-', label='Trajectory')\n", + " plt.plot(positions[0, 0], positions[0, 1], 'go', label='Start')\n", + " plt.plot(positions[-1, 0], positions[-1, 1], 'bo', label='End')\n", + " plt.xlabel('x')\n", + " plt.ylabel('y')\n", + " plt.title('Position Trajectory')\n", + " plt.legend()\n", + " plt.axis('equal')\n", + " \n", + " # Plot the momentum trajectory\n", + " plt.subplot(2, 2, 2)\n", + " plt.plot(momenta[:, 0], momenta[:, 1], 'b-', label='Trajectory')\n", + " plt.plot(momenta[0, 0], momenta[0, 1], 'go', label='Start')\n", + " plt.plot(momenta[-1, 0], momenta[-1, 1], 'ro', label='End')\n", + " plt.xlabel('p_x')\n", + " plt.ylabel('p_y')\n", + " plt.title('Momentum Trajectory')\n", + " plt.legend()\n", + " plt.axis('equal')\n", + " \n", + " # Plot position vs time\n", + " plt.subplot(2, 2, 3)\n", + " time = np.arange(num_steps) * step_size\n", + " plt.plot(time, positions[:, 0], 'r-', label='x')\n", + " plt.plot(time, positions[:, 1], 'b-', label='y')\n", + " plt.xlabel('Time')\n", + " plt.ylabel('Position')\n", + " plt.title('Position vs Time')\n", + " plt.legend()\n", + " \n", + " # Plot momentum vs time\n", + " plt.subplot(2, 2, 4)\n", + " plt.plot(time, momenta[:, 0], 'r-', label='p_x')\n", + " plt.plot(time, momenta[:, 1], 'b-', label='p_y')\n", + " plt.xlabel('Time')\n", + " plt.ylabel('Momentum')\n", + " 
plt.title('Momentum vs Time')\n", + " plt.legend()\n", + " \n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "plot_trajectory(positions, momenta)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Energy Conservation\n", + "\n", + "One of the key properties of Hamiltonian dynamics is energy conservation. Let's check if our integrator preserves energy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Calculate the total energy (potential + kinetic) at each step\n", + "def calculate_energy(states):\n", + " # Potential energy is the negative log-density\n", + " potential_energy = -jnp.array([state.logdensity for state in states])\n", + " \n", + " # Kinetic energy\n", + " kinetic_energy = jnp.array([kinetic_energy_fn(state.momentum) for state in states])\n", + " \n", + " # Total energy\n", + " total_energy = potential_energy + kinetic_energy\n", + " \n", + " return potential_energy, kinetic_energy, total_energy\n", + "\n", + "potential_energy, kinetic_energy, total_energy = calculate_energy(states)\n", + "\n", + "# Plot the energy components\n", + "plt.figure(figsize=(12, 6))\n", + "time = np.arange(num_steps) * step_size\n", + "plt.plot(time, potential_energy, 'r-', label='Potential Energy')\n", + "plt.plot(time, kinetic_energy, 'b-', label='Kinetic Energy')\n", + "plt.plot(time, total_energy, 'g-', label='Total Energy')\n", + "plt.xlabel('Time')\n", + "plt.ylabel('Energy')\n", + "plt.title('Energy Conservation')\n", + "plt.legend()\n", + "plt.grid(True)\n", + "plt.show()\n", + "\n", + "# Calculate the relative energy error\n", + "initial_total_energy = total_energy[0]\n", + "relative_energy_error = jnp.abs(total_energy - initial_total_energy) / jnp.abs(initial_total_energy)\n", + "max_relative_error = jnp.max(relative_energy_error)\n", + "print(f\"Maximum relative energy error: {max_relative_error:.6f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": 
{}, + "source": [ + "## Time Reversibility\n", + "\n", + "Another important property of Hamiltonian dynamics is time reversibility. Let's check if our integrator is time-reversible by running it forward and then backward:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Function to run the integrator backward (by negating the step size)\n", + "def run_integrator_backward(final_state, integrator, step_size, num_steps):\n", + " def one_step(state, _):\n", + " new_state = integrator(state, -step_size) # Negative step size for backward integration\n", + " return new_state, new_state\n", + " \n", + " _, states = jax.lax.scan(one_step, final_state, None, length=num_steps)\n", + " return states\n", + "\n", + "# Run the integrator forward\n", + "forward_states = run_integrator(initial_state, integrator, step_size, num_steps)\n", + "final_state = forward_states[-1]\n", + "\n", + "# Run the integrator backward\n", + "backward_states = run_integrator_backward(final_state, integrator, step_size, num_steps)\n", + "reversed_state = backward_states[-1]\n", + "\n", + "# Check if we've returned to the initial state\n", + "position_error = jnp.linalg.norm(initial_state.position - reversed_state.position)\n", + "momentum_error = jnp.linalg.norm(initial_state.momentum - reversed_state.momentum)\n", + "\n", + "print(f\"Position error: {position_error:.10f}\")\n", + "print(f\"Momentum error: {momentum_error:.10f}\")\n", + "\n", + "# Visualize the forward and backward trajectories\n", + "def plot_reversibility(forward_states, backward_states):\n", + " forward_positions = jnp.array([state.position for state in forward_states])\n", + " backward_positions = jnp.array([state.position for state in backward_states])\n", + " \n", + " x = np.linspace(-4, 4, 100)\n", + " y = np.linspace(-4, 4, 100)\n", + " X, Y = np.meshgrid(x, y)\n", + " Z = np.zeros((100, 100))\n", + " for i in range(100):\n", + " for j in range(100):\n", + " 
Z[i, j] = np.exp(logdensity_fn(np.array([X[i, j], Y[i, j]])))\n", + " \n", + " plt.figure(figsize=(10, 8))\n", + " plt.contour(X, Y, Z, levels=20)\n", + " plt.colorbar(label='Density')\n", + " plt.plot(forward_positions[:, 0], forward_positions[:, 1], 'r-', label='Forward')\n", + " plt.plot(backward_positions[:, 0], backward_positions[:, 1], 'b--', label='Backward')\n", + " plt.plot(forward_positions[0, 0], forward_positions[0, 1], 'go', label='Start')\n", + " plt.plot(forward_positions[-1, 0], forward_positions[-1, 1], 'bo', label='End')\n", + " plt.plot(backward_positions[-1, 0], backward_positions[-1, 1], 'ro', label='Reversed')\n", + " plt.xlabel('x')\n", + " plt.ylabel('y')\n", + " plt.title('Time Reversibility')\n", + " plt.legend()\n", + " plt.axis('equal')\n", + " plt.show()\n", + "\n", + "plot_reversibility(forward_states, backward_states)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Equivalence of Mass Matrix and Covariance Matrix\n", + "\n", + "Now, let's demonstrate an interesting property of Hamiltonian dynamics: using a mass matrix that is the inverse of the covariance matrix is equivalent to using a unit mass matrix on a Gaussian with unit covariance.\n", + "\n", + "This is a fundamental insight in Hamiltonian Monte Carlo, as it allows us to transform the problem into a simpler one with isotropic dynamics.\n", + "\n", + "### Mathematical Explanation\n", + "\n", + "For a Gaussian target with covariance matrix $\\Sigma$, the log-density is:\n", + "\n", + "$$\\log p(x) = -\\frac{1}{2}(x-\\mu)^T \\Sigma^{-1} (x-\\mu) + \\text{const}$$\n", + "\n", + "If we use a mass matrix $M = \\Sigma^{-1}$ (equivalently, an inverse mass matrix $M^{-1} = \\Sigma$), the Hamiltonian becomes, up to an additive constant:\n", + "\n", + "$$H(x, p) = -\\log p(x) + \\frac{1}{2}p^T M^{-1} p = \\frac{1}{2}(x-\\mu)^T \\Sigma^{-1} (x-\\mu) + \\frac{1}{2}p^T \\Sigma p$$\n", + "\n", + "Now, let's transform the variables:\n", + "\n", + "$$x' = \\Sigma^{-1/2}(x-\\mu)$$\n", + "$$p' = \\Sigma^{1/2}p$$\n", + "\n", + "The Hamiltonian in the 
transformed variables becomes:\n", + "\n", + "$$H(x', p') = \\frac{1}{2}x'^T x' + \\frac{1}{2}p'^T p'$$\n", + "\n", + "This is the Hamiltonian for a standard Gaussian with unit covariance and unit mass matrix.\n", + "\n", + "Let's demonstrate this equivalence numerically:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define a function to transform coordinates\n", + "def transform_coordinates(position, momentum, cov):\n", + " # Compute the Cholesky decomposition of the covariance matrix\n", + " L = jnp.linalg.cholesky(cov)\n", + " L_inv = jnp.linalg.inv(L)\n", + " \n", + " # Transform position and momentum\n", + " position_transformed = L_inv @ (position - mean)\n", + " momentum_transformed = L.T @ momentum\n", + " \n", + " return position_transformed, momentum_transformed\n", + "\n", + "# Define a function to inverse transform coordinates\n", + "def inverse_transform_coordinates(position_transformed, momentum_transformed, cov):\n", + " # Compute the Cholesky decomposition of the covariance matrix\n", + " L = jnp.linalg.cholesky(cov)\n", + " L_inv = jnp.linalg.inv(L)\n", + " \n", + " # Inverse transform position and momentum\n", + " position = L @ position_transformed + mean\n", + " momentum = L_inv.T @ momentum_transformed\n", + " \n", + " return position, momentum\n", + "\n", + "# Define a standard Gaussian log-density function (unit covariance)\n", + "def standard_logdensity_fn(position):\n", + " return stats.multivariate_normal.logpdf(position, jnp.zeros(2), jnp.eye(2))\n", + "\n", + "# Set up three integrators:\n", + "# 1. Original Gaussian with unit mass matrix\n", + "original_metric = metrics.default_metric(jnp.eye(2))\n", + "original_integrator = integrators.velocity_verlet(logdensity_fn, original_metric.kinetic_energy)\n", + "\n", + "# 2. 
Original Gaussian with inverse mass matrix = covariance matrix (i.e. mass matrix = inverse covariance)\n", + "cov_mass_metric = metrics.default_metric(cov)\n", + "cov_mass_integrator = integrators.velocity_verlet(logdensity_fn, cov_mass_metric.kinetic_energy)\n", + "\n", + "# 3. Standard Gaussian (unit covariance) with unit mass matrix\n", + "standard_metric = metrics.default_metric(jnp.eye(2))\n", + "standard_integrator = integrators.velocity_verlet(standard_logdensity_fn, standard_metric.kinetic_energy)\n", + "\n", + "# Set initial conditions\n", + "initial_position = jnp.array([2.0, 2.0])\n", + "initial_momentum = jnp.array([0.5, -0.3])\n", + "\n", + "# Transform initial conditions for the standard Gaussian\n", + "transformed_position, transformed_momentum = transform_coordinates(initial_position, initial_momentum, cov)\n", + "\n", + "# Create initial states\n", + "original_state = integrators.new_integrator_state(logdensity_fn, initial_position, initial_momentum)\n", + "cov_mass_state = integrators.new_integrator_state(logdensity_fn, initial_position, initial_momentum)\n", + "standard_state = integrators.new_integrator_state(standard_logdensity_fn, transformed_position, transformed_momentum)\n", + "\n", + "# Run the integrators\n", + "step_size = 0.1\n", + "num_steps = 50\n", + "\n", + "original_states = run_integrator(original_state, original_integrator, step_size, num_steps)\n", + "cov_mass_states = run_integrator(cov_mass_state, cov_mass_integrator, step_size, num_steps)\n", + "standard_states = run_integrator(standard_state, standard_integrator, step_size, num_steps)\n", + "\n", + "# Extract positions\n", + "original_positions = jnp.array([state.position for state in original_states])\n", + "cov_mass_positions = jnp.array([state.position for state in cov_mass_states])\n", + "standard_positions = jnp.array([state.position for state in standard_states])\n", + "\n", + "# Transform standard positions back to original space\n", + "transformed_standard_positions = jnp.array([\n", + " 
inverse_transform_coordinates(pos, jnp.zeros(2), cov)[0] for pos in standard_positions\n", + "])\n", + "\n", + "# Compare the trajectories\n", + "def plot_equivalent_trajectories():\n", + " x = np.linspace(-4, 4, 100)\n", + " y = np.linspace(-4, 4, 100)\n", + " X, Y = np.meshgrid(x, y)\n", + " Z = np.zeros((100, 100))\n", + " for i in range(100):\n", + " for j in range(100):\n", + " Z[i, j] = np.exp(logdensity_fn(np.array([X[i, j], Y[i, j]])))\n", + " \n", + " plt.figure(figsize=(15, 5))\n", + " \n", + " # Plot original Gaussian with unit mass matrix\n", + " plt.subplot(1, 3, 1)\n", + " plt.contour(X, Y, Z, levels=20)\n", + " plt.plot(original_positions[:, 0], original_positions[:, 1], 'r-', label='Trajectory')\n", + " plt.plot(original_positions[0, 0], original_positions[0, 1], 'go', label='Start')\n", + " plt.plot(original_positions[-1, 0], original_positions[-1, 1], 'bo', label='End')\n", + " plt.xlabel('x')\n", + " plt.ylabel('y')\n", + " plt.title('Original Gaussian\\nUnit Mass Matrix')\n", + " plt.legend()\n", + " plt.axis('equal')\n", + " \n", + " # Plot original Gaussian with mass matrix = covariance\n", + " plt.subplot(1, 3, 2)\n", + " plt.contour(X, Y, Z, levels=20)\n", + " plt.plot(cov_mass_positions[:, 0], cov_mass_positions[:, 1], 'r-', label='Trajectory')\n", + " plt.plot(cov_mass_positions[0, 0], cov_mass_positions[0, 1], 'go', label='Start')\n", + " plt.plot(cov_mass_positions[-1, 0], cov_mass_positions[-1, 1], 'bo', label='End')\n", + " plt.xlabel('x')\n", + " plt.ylabel('y')\n", + " plt.title('Original Gaussian\\nMass Matrix = Covariance')\n", + " plt.legend()\n", + " plt.axis('equal')\n", + " \n", + " # Plot standard Gaussian with unit mass matrix (transformed back)\n", + " plt.subplot(1, 3, 3)\n", + " plt.contour(X, Y, Z, levels=20)\n", + " plt.plot(transformed_standard_positions[:, 0], transformed_standard_positions[:, 1], 'r-', label='Trajectory')\n", + " plt.plot(transformed_standard_positions[0, 0], transformed_standard_positions[0, 1], 
'go', label='Start')\n", + " plt.plot(transformed_standard_positions[-1, 0], transformed_standard_positions[-1, 1], 'bo', label='End')\n", + " plt.xlabel('x')\n", + " plt.ylabel('y')\n", + " plt.title('Standard Gaussian (Transformed)\\nUnit Mass Matrix')\n", + " plt.legend()\n", + " plt.axis('equal')\n", + " \n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "plot_equivalent_trajectories()\n", + "\n", + "# Calculate the differences between trajectories\n", + "original_cov_mass_diff = jnp.mean(jnp.abs(original_positions - cov_mass_positions))\n", + "original_standard_diff = jnp.mean(jnp.abs(original_positions - transformed_standard_positions))\n", + "cov_mass_standard_diff = jnp.mean(jnp.abs(cov_mass_positions - transformed_standard_positions))\n", + "\n", + "print(f\"Average difference between original and cov_mass trajectories: {original_cov_mass_diff:.6f}\")\n", + "print(f\"Average difference between original and standard trajectories: {original_standard_diff:.6f}\")\n", + "print(f\"Average difference between cov_mass and standard trajectories: {cov_mass_standard_diff:.6f}\")\n", + "\n", + "# Plot the differences over time\n", + "plt.figure(figsize=(12, 6))\n", + "time = np.arange(num_steps) * step_size\n", + "plt.plot(time, jnp.abs(original_positions - cov_mass_positions).mean(axis=1), 'r-', label='Original vs Cov Mass')\n", + "plt.plot(time, jnp.abs(original_positions - transformed_standard_positions).mean(axis=1), 'b-', label='Original vs Standard')\n", + "plt.plot(time, jnp.abs(cov_mass_positions - transformed_standard_positions).mean(axis=1), 'g-', label='Cov Mass vs Standard')\n", + "plt.xlabel('Time')\n", + "plt.ylabel('Average Absolute Difference')\n", + "plt.title('Trajectory Differences Over Time')\n", + "plt.legend()\n", + "plt.grid(True)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "In this notebook, we've demonstrated how to use the velocity verlet integrator from 
Blackjax to simulate Hamiltonian dynamics for a Gaussian target distribution. We've shown that:\n", + "\n", + "1. The integrator can be used to simulate the trajectory of a particle in the potential energy landscape defined by the negative log-density of the target distribution.\n", + "2. The integrator approximately conserves energy, with small errors due to the numerical approximation.\n", + "3. The integrator is time-reversible, meaning that running it forward and then backward returns to the initial state (up to numerical errors).\n", + "4. Using a mass matrix that is the inverse of the covariance matrix is equivalent to using a unit mass matrix on a Gaussian with unit covariance.\n", + "\n", + "These properties make the velocity verlet integrator a good choice for Hamiltonian Monte Carlo, where we want to simulate Hamiltonian dynamics to propose new states in the Markov chain. The equivalence between mass matrices and covariance matrices is particularly useful for designing efficient samplers, as it allows us to transform complex target distributions into simpler ones with isotropic dynamics." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}