Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 40 additions & 59 deletions blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Callable

import jax
Expand All @@ -21,37 +22,14 @@
import blackjax.smc.solver as solver
import blackjax.smc.tempered as tempered
from blackjax.base import SamplingAlgorithm
from blackjax.smc import from_mcmc as smc_from_mcmc
from blackjax.types import ArrayLikeTree, PRNGKey

__all__ = ["build_kernel", "init", "as_top_level_api"]


def build_kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
target_ess: float,
root_solver: Callable = solver.dichotomy,
**extra_parameters,
) -> Callable:
r"""Build a Tempered SMC step using an adaptive schedule.

Parameters
----------
logprior_fn: Callable
A function that computes the log-prior density.
loglikelihood_fn: Callable
A function that returns the log-likelihood density.
mcmc_kernel_factory: Callable
A callable function that creates a mcmc kernel from a log-probability
density function.
make_mcmc_state: Callable
A function that creates a new mcmc state from a position and a
log-probability density function.
resampling_fn: Callable
A random function that resamples generated particles based of weights
def build_kernel(loglikelihood_fn, target_ess, root_solver, tempered_kernel):
"""
target_ess: float
The target ESS for the adaptive MCMC tempering
root_solver: Callable, optional
Expand All @@ -60,13 +38,6 @@ def build_kernel(
use_log_ess: bool, optional
Use ESS in log space to solve for delta, default is `True`.
This is usually more stable when using gradient based solvers.

Returns
-------
A callable that takes a rng_key and a TemperedSMCState that contains the current state
of the chain and that returns a new state of the chain along with
information about the transition.

"""

def compute_delta(state: tempered.TemperedSMCState) -> float:
Expand All @@ -83,24 +54,13 @@ def compute_delta(state: tempered.TemperedSMCState) -> float:

return delta

tempered_kernel = tempered.build_kernel(
logprior_fn,
loglikelihood_fn,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
**extra_parameters,
)

def kernel(
rng_key: PRNGKey,
state: tempered.TemperedSMCState,
num_mcmc_steps: int,
mcmc_parameters: dict,
) -> tuple[tempered.TemperedSMCState, base.SMCInfo]:
delta = compute_delta(state)
lmbda = delta + state.lmbda
return tempered_kernel(rng_key, state, num_mcmc_steps, lmbda, mcmc_parameters)
return tempered_kernel(rng_key, state, lmbda)

return kernel

Expand All @@ -118,6 +78,8 @@ def as_top_level_api(
target_ess: float,
root_solver: Callable = solver.dichotomy,
num_mcmc_steps: int = 10,
mcmc_run_strategy=None,
mutation_step=None,
**extra_parameters,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Adaptive Tempered SMC kernel.
Expand Down Expand Up @@ -148,29 +110,48 @@ def as_top_level_api(
Returns
-------
A ``SamplingAlgorithm``.

"""
kernel = build_kernel(
if num_mcmc_steps is not None:
mcmc_run_strategy = functools.partial(
base.update_and_take_last, num_mcmc_steps=num_mcmc_steps
)
mutation_step = smc_from_mcmc.build_kernel(
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
mcmc_parameters,
mcmc_run_strategy,
)

elif mcmc_run_strategy is not None:
mutation_step = smc_from_mcmc.build_kernel(
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
mcmc_parameters,
mcmc_run_strategy,
)

elif mutation_step is not None:
mutation_step = mutation_step
else:
raise ValueError(
"You must either supply num_mcmc_steps, or mcmc_run_strategy or mutation_step"
)

tempered_kernel = tempered.build_kernel(
logprior_fn,
loglikelihood_fn,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
target_ess,
root_solver,
**extra_parameters,
mutation_step,
)

kernel = build_kernel(loglikelihood_fn, target_ess, root_solver, tempered_kernel)

def init_fn(position: ArrayLikeTree, rng_key=None):
del rng_key
return init(position)

def step_fn(rng_key: PRNGKey, state):
return kernel(
rng_key,
state,
num_mcmc_steps,
mcmc_parameters,
)
return kernel(rng_key, state)

return SamplingAlgorithm(init_fn, step_fn)
Loading