Draft

77 commits
fa158ea  langevin structs (reubenharry, Dec 11, 2024)
8eed424  add static adjusted mclmc (reubenharry, Jan 15, 2025)
9dd6bdb  add static adjusted mclmc (reubenharry, Jan 15, 2025)
a49bb35  add static adjusted mclmc (reubenharry, Jan 15, 2025)
35d71ff  draft (reubenharry, Jan 15, 2025)
0f3df53  draft (reubenharry, Jan 15, 2025)
04522f5  change order of parameters (reubenharry, Jan 15, 2025)
6099c01  Merge branch 'adjusted_mclmc_static' into emaus (reubenharry, Jan 15, 2025)
a7c99b9  draft (reubenharry, Jan 16, 2025)
6972f23  mid cleanup (reubenharry, Feb 3, 2025)
c96a8e8  fix while loop (reubenharry, Feb 4, 2025)
805113a  test passes (reubenharry, Feb 6, 2025)
16841c6  precommit (reubenharry, Feb 6, 2025)
d203427  update (reubenharry, Feb 6, 2025)
52ce7ad  update (reubenharry, Feb 6, 2025)
cc7bfbd  docstrings (reubenharry, Feb 10, 2025)
58d9920  remove debug statements (reubenharry, Feb 10, 2025)
4274b07  precommit (reubenharry, Feb 10, 2025)
d951a44  modify test (reubenharry, Feb 11, 2025)
ba8f6eb  modify test (reubenharry, Feb 11, 2025)
29dcd54  modify test (reubenharry, Feb 11, 2025)
4b7d8b0  modify test (reubenharry, Feb 11, 2025)
67cfd71  clean up and bug fix (reubenharry, Feb 22, 2025)
1877cc5  Merge branch 'main' into emaus (reubenharry, Feb 24, 2025)
5ee0b8a  Update blackjax/adaptation/step_size.py (reubenharry, Feb 27, 2025)
2b33625  clean up and bug fix (reubenharry, Feb 27, 2025)
4aec6e9  Merge branch 'emaus' of github.com:reubenharry/blackjax into emaus (reubenharry, Feb 27, 2025)
cc5e09a  clean up and bug fix (reubenharry, Feb 27, 2025)
bd40cf9  bug present in minimal_repro_3.py (reubenharry, Mar 5, 2025)
f7f8d86  wip (Mar 6, 2025)
8242ef0  merge (Mar 6, 2025)
a5eb4f4  wip (Mar 6, 2025)
f404898  bug fix (reubenharry, Mar 10, 2025)
e64b7f4  Merge branch 'emaus' of github.com:reubenharry/blackjax into emaus (reubenharry, Mar 10, 2025)
9b00e28  bug fix (reubenharry, Mar 10, 2025)
f35f98e  bug fix (reubenharry, Mar 10, 2025)
b55ab0d  bug fix (reubenharry, Mar 10, 2025)
e6da5c2  changes (Mar 10, 2025)
ea92b1c  small changes (Mar 10, 2025)
0bd1414  bug fix (reubenharry, Mar 10, 2025)
13a375c  bug fix (reubenharry, Mar 10, 2025)
8834188  langevin (Mar 24, 2025)
04e5b61  first attempt at langevin (Mar 29, 2025)
3906c0e  emaus diagnostics (reubenharry, Apr 1, 2025)
3d465f2  tuning for hmc (Apr 4, 2025)
e17b1cc  energy error monitoring (reubenharry, Apr 9, 2025)
1015774  energy error monitoring (reubenharry, Apr 9, 2025)
fe28c3c  preconditioning (Apr 22, 2025)
a164c63  windows for unadjusted (Apr 23, 2025)
6448735  add preconditioning for ulmc (Apr 25, 2025)
fb72f34  Merge branch 'emaus' into working_branch (Apr 25, 2025)
a612d39  fix emaus code (Apr 26, 2025)
0ced55f  FOR NEURIPS (May 15, 2025)
89ccc8b  updates (Jun 3, 2025)
6fe3963  update (JakobRobnik, Jun 4, 2025)
1ea9ac4  fixed unadjusted phase (removed diagonal precond), tried out nuts-sty… (JakobRobnik, Jun 5, 2025)
82ed6da  working branch (Jun 12, 2025)
c44dce3  attempted to change the nuts angle (JakobRobnik, Jun 19, 2025)
abc0cb7  Revert "attempted to change the nuts angle" (JakobRobnik, Jun 19, 2025)
2d33a1f  working branch (Jun 23, 2025)
4b074b1  Merge branch 'working_branch' of github.com:reubenharry/blackjax into… (Jun 23, 2025)
26f4528  fixed laps (JakobRobnik, Jun 23, 2025)
b65a664  working old version of eca_step with splitR (JakobRobnik, Jun 27, 2025)
28cabf0  new adaptation (Jun 30, 2025)
dc1bde7  Merge branch 'working_branch' of github.com:reubenharry/blackjax into… (Jun 30, 2025)
9fdc500  checkpoint (Aug 23, 2025)
3ee1c14  fix laps (Aug 27, 2025)
ce2fe53  tryinmg to fix (overwrite later) (JakobRobnik, Aug 28, 2025)
9eaa489  Merge remote-tracking branch 'refs/remotes/origin/working_branch' int… (JakobRobnik, Aug 28, 2025)
649a316  las (Sep 16, 2025)
18209ca  las (Sep 17, 2025)
c6e99ce  las (Sep 17, 2025)
029c78b  las comments (JakobRobnik, Sep 18, 2025)
35ba313  submission (JakobRobnik, Oct 3, 2025)
5e28ec5  pseudofermion (Oct 3, 2025)
861f147  las (Oct 3, 2025)
bffe235  fix (Oct 4, 2025)
blackjax/__init__.py (21 additions, 7 deletions)

@@ -3,25 +3,32 @@

from blackjax._version import __version__

from .adaptation.adjusted_mclmc_adaptation import adjusted_mclmc_find_L_and_step_size
from .adaptation.chees_adaptation import chees_adaptation
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
from .adaptation.meads_adaptation import meads_adaptation
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
from .adaptation.window_adaptation import window_adaptation
from .adaptation.unadjusted_alba import unadjusted_alba
from .adaptation.unadjusted_step_size import robnik_step_size_tuning
from .adaptation.adjusted_alba import adjusted_alba
from .adaptation.las import las
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .mcmc import adjusted_mclmc as _adjusted_mclmc
from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic
from .mcmc import barker
from .mcmc import dynamic_hmc as _dynamic_hmc
from .mcmc import dynamic_malt as _dynamic_malt
from .mcmc import elliptical_slice as _elliptical_slice
from .mcmc import ghmc as _ghmc
from .mcmc import hmc as _hmc
from .mcmc import uhmc as _uhmc
from .mcmc import malt as _malt
from .mcmc import mala as _mala
from .mcmc import pseudofermion as _pseudofermion
from .mcmc import marginal_latent_gaussian
from .mcmc import mclmc as _mclmc
from .mcmc import mchmc as _mchmc
from .mcmc import underdamped_langevin as _langevin
from .mcmc import nuts as _nuts
from .mcmc import periodic_orbital, random_walk
from .mcmc import rmhmc as _rmhmc
@@ -96,12 +103,15 @@ def generate_top_level_api_from(module):

# MCMC
hmc = generate_top_level_api_from(_hmc)
uhmc = generate_top_level_api_from(_uhmc)
malt = generate_top_level_api_from(_malt)
nuts = generate_top_level_api_from(_nuts)
rmh = GenerateSamplingAPI(rmh_as_top_level_api, random_walk.init, random_walk.build_rmh)
irmh = GenerateSamplingAPI(
    irmh_as_top_level_api, random_walk.init, random_walk.build_irmh
)
dynamic_hmc = generate_top_level_api_from(_dynamic_hmc)
dynamic_malt = generate_top_level_api_from(_dynamic_malt)
rmhmc = generate_top_level_api_from(_rmhmc)
mala = generate_top_level_api_from(_mala)
mgrad_gaussian = generate_top_level_api_from(marginal_latent_gaussian)
@@ -114,12 +124,14 @@ def generate_top_level_api_from(module):
additive_step_random_walk.register_factory("normal_random_walk", normal_random_walk)

mclmc = generate_top_level_api_from(_mclmc)
mchmc = generate_top_level_api_from(_mchmc)
langevin = generate_top_level_api_from(_langevin)
adjusted_mclmc_dynamic = generate_top_level_api_from(_adjusted_mclmc_dynamic)
adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc)
# adjusted_mclmc = generate_top_level_api_from(_adjusted_mclmc)
elliptical_slice = generate_top_level_api_from(_elliptical_slice)
ghmc = generate_top_level_api_from(_ghmc)
barker_proposal = generate_top_level_api_from(barker)

pseudofermion = generate_top_level_api_from(_pseudofermion)
hmc_family = [hmc, nuts]

# SMC
@@ -165,8 +177,10 @@ def generate_top_level_api_from(module):
"meads_adaptation",
"chees_adaptation",
"pathfinder_adaptation",
"mclmc_find_L_and_step_size", # mclmc adaptation
"adjusted_mclmc_find_L_and_step_size", # adjusted mclmc adaptation
"ess", # diagnostics
"rhat",
"unadjusted_alba",
"robnik_step_size_tuning",
"adjusted_alba",
"las",
]
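
The newly registered samplers are exposed at the top level alongside the
existing ones. A minimal sketch of the resulting API, using mclmc (whose
signature is unchanged here; the L and step_size values are illustrative):

    import jax
    import jax.numpy as jnp
    import blackjax

    logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)  # standard Gaussian target

    # new names such as blackjax.uhmc, blackjax.malt, blackjax.mchmc,
    # blackjax.langevin and blackjax.pseudofermion are registered the same way
    algo = blackjax.mclmc(logdensity_fn, L=1.0, step_size=0.1)
    state = algo.init(jnp.zeros(2), jax.random.PRNGKey(0))
    state, info = algo.step(jax.random.PRNGKey(1), state)
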
blackjax/adaptation/__init__.py (4 additions, 0 deletions)

@@ -4,6 +4,8 @@
    meads_adaptation,
    pathfinder_adaptation,
    window_adaptation,
    unadjusted_alba,
    unadjusted_step_size,
)

__all__ = [
@@ -12,4 +14,6 @@
"window_adaptation",
"pathfinder_adaptation",
"mclmc_adaptation",
"unadjusted_alba",
"robnik_step_size_tuning",
]
blackjax/adaptation/adjusted_alba.py (new file, 156 additions)

@@ -0,0 +1,156 @@
from typing import Callable

import jax
import jax.numpy as jnp

import blackjax
from blackjax.adaptation.step_size import dual_averaging_adaptation
from blackjax.adaptation.unadjusted_alba import unadjusted_alba
from blackjax.base import AdaptationAlgorithm
from blackjax.mcmc.adjusted_mclmc_dynamic import rescale
from blackjax.types import ArrayLikeTree, PRNGKey


def make_random_trajectory_length_fn(random_trajectory_length: bool):
    """Return a factory mapping an average number of integration steps to a
    per-step function: randomized (at least one step) or fixed."""
    if random_trajectory_length:

        def integration_steps_fn(avg_num_integration_steps):
            def sample_steps(k):
                steps = jnp.ceil(
                    jax.random.uniform(k) * rescale(avg_num_integration_steps)
                )
                # guard against a zero-length trajectory
                return jnp.where(steps == 0, 1, steps).astype("int32")

            return sample_steps

    else:

        def integration_steps_fn(avg_num_integration_steps):
            return lambda _: jnp.ceil(avg_num_integration_steps).astype("int32")

    return integration_steps_fn
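
# Example (a sketch): with random_trajectory_length=True, each call draws a
# fresh number of integration steps whose average is avg_num_integration_steps:
#
#     steps_fn = make_random_trajectory_length_fn(True)(10.0)
#     n = steps_fn(jax.random.PRNGKey(0))  # int32 >= 1, mean approximately 10
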

def da_adaptation(
    algorithm,
    logdensity_fn: Callable,
    integration_steps_fn: Callable,
    inverse_mass_matrix,
    initial_step_size: float = 1.0,
    target_acceptance_rate: float = 0.80,
    initial_L: float = 1.0,
    integrator=blackjax.mcmc.integrators.velocity_verlet,
    L_proposal_factor=jnp.inf,
):
    """Dual-averaging step-size adaptation for an adjusted kernel, with the
    trajectory length L held fixed."""
    da_init, da_update, da_final = dual_averaging_adaptation(target_acceptance_rate)
    kernel = algorithm.build_kernel(
        integrator=integrator, L_proposal_factor=L_proposal_factor
    )

    def step(state, key):
        (adaptation_state, kernel_state), L = state
        new_kernel_state, info = kernel(
            rng_key=key,
            state=kernel_state,
            logdensity_fn=logdensity_fn,
            step_size=jnp.exp(adaptation_state.log_step_size),
            inverse_mass_matrix=inverse_mass_matrix,
            # hold L fixed: the number of steps shrinks as the step size grows
            integration_steps_fn=integration_steps_fn(
                L / jnp.exp(adaptation_state.log_step_size)
            ),
        )

        new_adaptation_state = da_update(adaptation_state, info.acceptance_rate)

        return ((new_adaptation_state, new_kernel_state), L), None

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
        init_key, rng_key = jax.random.split(rng_key)

        init_kernel_state = algorithm.init(
            position=position,
            logdensity_fn=logdensity_fn,
            random_generator_arg=init_key,
        )

        keys = jax.random.split(rng_key, num_steps)
        init_state = da_init(initial_step_size), init_kernel_state
        ((adaptation_state, kernel_state), L), info = jax.lax.scan(
            step, (init_state, initial_L), keys
        )
        step_size = da_final(adaptation_state)
        return (
            kernel_state,
            {
                "step_size": step_size,
                "inverse_mass_matrix": inverse_mass_matrix,
                "L": L,
            },
            info,
        )

    return AdaptationAlgorithm(run)
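
# Example usage (a sketch; the algorithm and the values below are illustrative,
# not prescribed by this module):
#
#     warmup = da_adaptation(
#         algorithm=some_adjusted_algorithm,  # must expose build_kernel and init
#         logdensity_fn=logdensity_fn,
#         integration_steps_fn=make_random_trajectory_length_fn(True),
#         inverse_mass_matrix=jnp.ones(dim),
#         initial_L=5.0,
#     )
#     final_state, params, info = warmup.run(rng_key, position, num_steps=1000)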


def adjusted_alba(
    unadjusted_algorithm,
    logdensity_fn: Callable,
    target_eevpd,
    v,
    adjusted_algorithm,
    integrator,
    target_acceptance_rate: float = 0.80,
    num_alba_steps: int = 500,
    alba_factor: float = 0.4,
    preconditioning: bool = True,
    L_proposal_factor=jnp.inf,
    **extra_parameters,
):
    """Two-phase warmup: an unadjusted ALBA run to find L and the inverse mass
    matrix, followed by dual-averaging step-size adaptation of the adjusted
    algorithm."""
    unadjusted_warmup = unadjusted_alba(
        algorithm=unadjusted_algorithm,
        logdensity_fn=logdensity_fn,
        target_eevpd=target_eevpd,
        v=v,
        integrator=integrator,
        num_alba_steps=num_alba_steps,
        alba_factor=alba_factor,
        preconditioning=preconditioning,
        **extra_parameters,
    )

    def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
        unadjusted_warmup_key, adjusted_warmup_key = jax.random.split(rng_key)

        num_unadjusted_steps = 20000

        (state, params), adaptation_info = unadjusted_warmup.run(
            unadjusted_warmup_key, position, num_unadjusted_steps
        )

        integration_steps_fn = make_random_trajectory_length_fn(
            random_trajectory_length=True
        )

        adjusted_warmup = da_adaptation(
            algorithm=adjusted_algorithm,
            logdensity_fn=logdensity_fn,
            integration_steps_fn=integration_steps_fn,
            initial_L=params["L"],
            initial_step_size=params["step_size"],
            target_acceptance_rate=target_acceptance_rate,
            inverse_mass_matrix=params["inverse_mass_matrix"],
            integrator=integrator,
            L_proposal_factor=L_proposal_factor,
            **extra_parameters,
        )

        state, params, adaptation_info = adjusted_warmup.run(
            adjusted_warmup_key, state.position, num_steps
        )
        return state, params, adaptation_info

    return AdaptationAlgorithm(run)
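
Downstream usage of the new two-phase warmup might look as follows. This is a
sketch: the choice of blackjax.langevin and blackjax.dynamic_malt as the
unadjusted/adjusted pair, and the target_eevpd and v values, are illustrative
assumptions, not defaults of this module:

    import jax
    import jax.numpy as jnp
    import blackjax

    logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

    warmup = blackjax.adjusted_alba(
        unadjusted_algorithm=blackjax.langevin,    # assumption
        logdensity_fn=logdensity_fn,
        target_eevpd=1e-4,                         # illustrative tuning target
        v=1.0,                                     # illustrative
        adjusted_algorithm=blackjax.dynamic_malt,  # assumption
        integrator=blackjax.mcmc.integrators.velocity_verlet,
    )
    state, params, info = warmup.run(
        jax.random.PRNGKey(0), jnp.zeros(10), num_steps=1000
    )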


blackjax/adaptation/adjusted_mclmc_adaptation.py (47 additions, 7 deletions)

@@ -28,6 +28,7 @@ def adjusted_mclmc_find_L_and_step_size(
max="avg",
num_windows=1,
tuning_factor=1.3,
euclidean=False,
):
"""
Finds the optimal value of the parameters for the MH-MCHMC algorithm.
@@ -73,7 +74,14 @@

    dim = pytree_size(state.position)
    if params is None:
        params = MCLMCAdaptationState(
        if euclidean:
            params = MCLMCAdaptationState(
                1.0, 0.2, inverse_mass_matrix=jnp.ones((dim,))
            )
        else:
            params = MCLMCAdaptationState(
                jnp.sqrt(dim), jnp.sqrt(dim) * 0.2, inverse_mass_matrix=jnp.ones((dim,))
            )

@@ -96,6 +104,7 @@ def adjusted_mclmc_find_L_and_step_size(
            diagonal_preconditioning=diagonal_preconditioning,
            max=max,
            tuning_factor=tuning_factor,
            euclidean=euclidean,
        )(
            state, params, num_steps, window_key
        )
@@ -113,7 +122,7 @@ def adjusted_mclmc_find_L_and_step_size(
        ) = adjusted_mclmc_make_adaptation_L(
            mclmc_kernel,
            frac=frac_tune3,
            Lfactor=0.5,
            Lfactor=0.3,
            max=max,
            eigenvector=eigenvector,
        )(
@@ -156,6 +165,7 @@ def adjusted_mclmc_make_L_step_size_adaptation(
    fix_L_first_da=False,
    max="avg",
    tuning_factor=1.0,
    euclidean=False,
):
    """Adapts the stepsize and L of the MCLMC kernel. Designed for adjusted MCLMC"""

@@ -207,6 +217,7 @@ def step(iteration_state, weight_and_key):
        step_size = jax.lax.clamp(
            1e-5, jnp.exp(adaptive_state.log_step_size), params.L / 1.1
        )
        # jax.debug.print("step size in adaptation {x}", x=step_size)
        adaptive_state = adaptive_state._replace(log_step_size=jnp.log(step_size))

        x = ravel_pytree(state.position)[0]
@@ -256,7 +267,10 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
            num_steps * frac_tune2
        )

        check_key, rng_key = jax.random.split(rng_key, 2)
        # jax.debug.print("num steps1 {x}", x=num_steps1)
        # jax.debug.print("num steps 2 {x}", x=num_steps2)

        # check_key, rng_key = jax.random.split(rng_key, 2)

        rng_key_pass1, rng_key_pass2 = jax.random.split(rng_key, 2)
        L_step_size_adaptation_keys_pass1 = jax.random.split(
@@ -293,24 +307,49 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
        variances = x_squared_average - jnp.square(x_average)

        if max == "max":
            contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor
            if euclidean:
                # in euclidean mode the tuned L is rescaled by 1/sqrt(dim)
                contract = lambda x: (
                    jnp.sqrt(jnp.max(x) * dim) * tuning_factor
                ) / jnp.sqrt(dim)
            else:
                contract = lambda x: jnp.sqrt(jnp.max(x) * dim) * tuning_factor

        elif max == "avg":
            contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor
            if euclidean:
                contract = lambda x: (jnp.sqrt(jnp.sum(x)) * tuning_factor) / jnp.sqrt(dim)
            else:
                contract = lambda x: jnp.sqrt(jnp.sum(x)) * tuning_factor

        else:
            raise ValueError("max should be either 'max' or 'avg'")

        new_L = params.L

        change = jax.lax.clamp(
            Lratio_lowerbound,
            contract(variances) / params.L,
            contract(variances) / new_L,
            Lratio_upperbound,
        )
        # if euclidean:
        #     # new_L /= jnp.sqrt(dim)
        #     change /= jnp.sqrt(dim)

        params = params._replace(
            L=params.L * change, step_size=params.step_size * change
        )
        if diagonal_preconditioning:
            params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim))
            if euclidean:
                params = params._replace(inverse_mass_matrix=variances, L=1.0)
            else:
                params = params._replace(inverse_mass_matrix=variances, L=jnp.sqrt(dim))

        # else:
        #     if euclidean:
        #         params = params._replace(L = params.L / jnp.sqrt(dim))

        # jax.debug.print("params L {x}", x=(params.L, contract(variances), jnp.sum(variances), tuning_factor))

        initial_da, update_da, final_da = dual_averaging_adaptation(target=target)
        (
@@ -330,6 +369,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):

        params = params._replace(step_size=final_da(dual_avg_state))

        return state, params, eigenvector, num_tuning_integrator_steps

    return L_step_size_adaptation
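
Finally, a hedged sketch of calling the tuner with the new euclidean flag; the
kernel construction, the target acceptance value, and the unpacking of the
return value follow the existing adjusted MCLMC tuning API and are assumptions
here rather than part of this diff:

    from blackjax.adaptation.adjusted_mclmc_adaptation import (
        adjusted_mclmc_find_L_and_step_size,
    )

    state, params, num_tuning_steps = adjusted_mclmc_find_L_and_step_size(
        mclmc_kernel=kernel,   # assumed: kernel built for the adjusted sampler
        num_steps=10_000,
        state=initial_state,
        rng_key=jax.random.PRNGKey(0),
        target=0.9,
        euclidean=True,        # new: initializes params at L=1.0, step_size=0.2
    )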