
Improve make_transform_adapt docstring #234


Draft · wants to merge 3 commits into main
19 changes: 16 additions & 3 deletions python/nutpie/compiled_pyfunc.py
@@ -1,13 +1,27 @@
import dataclasses
from dataclasses import dataclass
from functools import partial
from functools import partial, wraps
from typing import Any, Callable

import numpy as np

from nutpie import _lib # type: ignore
from nutpie.sample import CompiledModel

from importlib.util import find_spec

# Importing from transform_adapter requires flowjax, which not all users will have
# installed. If flowjax is missing, the user can't use the with_transform_adapt method
# anyway, so we define a dummy function so that the docstring wrapper below is always valid.
if find_spec("flowjax") is not None:
from nutpie.transform_adapter import make_transform_adapter
else:

def make_transform_adapter(*args, **kwargs):
"""Normalizing flow adaption not available. Install flowjax to use."""
pass


SeedType = int


@@ -44,6 +58,7 @@ def with_data(self, **updates):
updated.update(**updates)
return dataclasses.replace(self, _shared_data=updated)

@wraps(make_transform_adapter)
def with_transform_adapt(self, **kwargs):
return dataclasses.replace(self, _transform_adapt_args=kwargs)
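Side note on the pattern above: `functools.wraps` copies `__doc__` (and other metadata) from whichever object it is given, so the flowjax-less stub keeps the decorator valid even when the real `make_transform_adapter` cannot be imported. A minimal, self-contained sketch of the idea (not nutpie code; the module and function names are illustrative):

```python
from functools import wraps
from importlib.util import find_spec

if find_spec("some_optional_dep") is not None:
    from some_optional_dep import fancy_configure  # hypothetical optional import
else:
    def fancy_configure(*args, **kwargs):
        """Fancy configuration not available. Install some_optional_dep to use it."""

@wraps(fancy_configure)
def configure(**kwargs):
    # The docstring shown for `configure` is inherited from whichever
    # `fancy_configure` was bound above, stub or real.
    return kwargs

print(configure.__doc__)
```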

@@ -71,8 +86,6 @@ def make_expand_func(seed1, seed2, chain):
outer_kwargs = {}

def make_adapter(*args, **kwargs):
from nutpie.transform_adapter import make_transform_adapter

return make_transform_adapter(**outer_kwargs)(
*args, **kwargs, logp_fn=self._raw_logp_fn
)
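For context, `with_transform_adapt` is the user-facing entry point that the wrapped docstring is meant to document. A hedged usage sketch (assuming flowjax is installed and that `nutpie.sample` accepts a `transform_adapt` flag, as in recent nutpie releases; the PyMC model, the `backend="jax"` choice, and the keyword values are illustrative):

```python
import nutpie
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x", shape=3)
    pm.Normal("y", mu=x.sum(), sigma=1.0, observed=2.0)

compiled = nutpie.compile_pymc_model(model, backend="jax")
# Keyword arguments are forwarded to make_transform_adapter.
tuned = compiled.with_transform_adapt(num_layers=4, show_progress=False)
trace = nutpie.sample(tuned, transform_adapt=True, chains=2)
```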
174 changes: 140 additions & 34 deletions python/nutpie/transform_adapter.py
@@ -869,43 +869,149 @@ def inv_transform(self, position, gradient):

def make_transform_adapter(
*,
verbose=False,
window_size=600,
show_progress=False,
nn_depth=None,
nn_width=None,
num_layers=8,
num_diag_windows=6,
learning_rate=5e-4,
verbose: bool = False,
window_size: int = 600,
show_progress: bool = False,
nn_depth: int | None = None,
nn_width: int | None = None,
num_layers: int = 8,
num_diag_windows: int = 6,
learning_rate: float = 5e-4,
untransformed_dim=None,
zero_init=True,
batch_size=128,
reuse_opt_state=False,
max_patience=20,
householder_layer=False,
dct_layer=False,
zero_init: bool = True,
batch_size: int = 128,
reuse_opt_state: bool = False,
max_patience: int = 20,
householder_layer: bool = False,
dct_layer: bool = False,
gamma=None,
log_inside_batch=False,
initial_skip=120,
extension_windows=None,
extend_dct=False,
extension_var_count=4,
extension_var_trafo_count=2,
debug_save_bijection=False,
make_optimizer=None,
coupling_type="masked",
mvscale_layer=False,
num_project=None,
num_embed=None,
num_householder=8,
twin_layers=False,
activation=None,
max_epochs=200,
affine_transformer=False,
contract_transformer=True,
asymmetric_transformer=False,
reuse_embed=True,
log_inside_batch: bool = False,
initial_skip: int = 120,
extension_windows: list[int] | None = None,
extend_dct: bool = False,
extension_var_count: int = 4,
extension_var_trafo_count: int = 2,
debug_save_bijection: bool = False,
make_optimizer: Callable | None = None,
coupling_type: str = "masked",
mvscale_layer: bool = False,
num_project: int | None = None,
num_embed: int | None = None,
num_householder: int = 8,
twin_layers: bool = False,
activation: str | None = None,
max_epochs: int = 200,
affine_transformer: bool = False,
contract_transformer: bool = True,
asymmetric_transformer: bool = False,
reuse_embed: bool = True,
):
"""
Create a TransformAdapter instance with the specified parameters.

A TransformAdapter is a utility for parameterizing the normalizing flow model used inside MCMC sampling. For more
details, see the nutpie documentation.

Parameters
----------
verbose: bool, default False
If True, print debug information, including random seed, available points, and loss value, to the terminal
during training.
window_size: int, default 600
???
show_progress: bool, default False
If True, show a tqdm progress bar while the flow network is trained. Note that with multiple chains,
the interleaved progress bars quickly clutter the output.
nn_depth: int | None, default None
Number of layers in the neural network used for the flow. If None, defaults to 1.
nn_width: int | None, default None
Number of hidden units in each layer of the flow network. If None, defaults to 32.
num_layers: int, default 8
Number of flow layers to use in the flow network. Each layer is parameterized according to nn_depth and
nn_width.
num_diag_windows: int, default 6
Number of diagonal mass matrix updates to perform before starting the flow training.
learning_rate: float, default 5e-4
Learning rate passed to the optimizer used to train the flow network. If a custom optimizer is provided via
the make_optimizer argument, this is ignored.
untransformed_dim: int | None, default None
???
zero_init: bool, default True
If True, all weights in the flow network are initialized to zero. Otherwise, initialization follows the
default flax initialization scheme (lecun_normal).
batch_size: int, default 128
Number of samples to use in each training batch.
reuse_opt_state: bool, default False
If True, the optimizer state (gradients and optimizer parameters) are stored and reused between updates.
Otherwise, training is restarted from scratch at each training update.
max_patience: int, default 20
Number of consecutive epochs with no validation loss improvement after which training is terminated.
householder_layer: bool, default False
If True, insert Householder transformation layers into the flow network. For more details, see the householder
layer documentation.
dct_layer: bool, default False
If True, insert discrete cosine transformation (DCT) layers into the flow network. For more details, see the
DCT layer documentation.
gamma: float | None, default None
???
log_inside_batch: bool, default False
???
initial_skip: int, default 120
Number of initial samples to completely ignore before flow training. Initial samples are often not sufficiently
representative of the target distribution, and ignoring them can help the flow network to converge.
extension_windows: list[int] | None, default None
???
extend_dct: bool, default False
???
extension_var_count: int, default 4
???
extension_var_trafo_count: int, default 2
???
debug_save_bijection: bool, default False
???
make_optimizer: Callable | None, default None
A function with no arguments that returns an optax optimizer. The default is optax.adamw(learning_rate),
wrapped by optax.apply_if_finite.
coupling_type: str, default "masked"
One of "subset", "masked", "flowjax_coupling", or "twin". This determines the type of coupling layer used
to construct the normalizing flow. For more details, see the coupling type documentation.
mvscale_layer: bool, default False
If True, re-scale parameters using a mean vector and covariance matrix. Ignored if coupling_type is not
"masked". Currently unused.
num_project: int | None, default None
??? If None, defaults to 2 * nn_width.
num_embed: int | None, default None
??? If None, defaults to 2 * nn_width.
num_householder: int, default 8
If greater than 0, the number of Householder layers to use in the flow network. Layers added in this way are
distinct from the (single) Householder layer inserted when householder_layer is True. Ignored if coupling_type
is not "subset" or "twin".
twin_layers: bool, default False
If True, use twin layers in the flow network. This doubles the number of flow layers by masking each layer
twice, where each mask is the inverse of the previous. This should allow more expressive flows to be learned.
activation: str | None, default None
Nonlinearity to insert between flow layers. One of "relu", "leaky_relu", "gelu", "tanh", or "sigmoid". If None,
"leaky_relu" is used.
max_epochs: int, default 200
Maximum number of training epochs to perform when training the flow network.
affine_transformer: bool, default False
If True, parameters are added to the flow network to learn the location and scale of each sample. This can be
seen as a latent non-centered parameterization, or a type of batch norm. Ignored if coupling_type is not
"masked" or "twin".
contract_transformer: bool, default True
??? Ignored if coupling_type is not "masked" or "twin".
asymmetric_transformer: bool, default False
If True, parameters are added to the flow network to learn the location and scale of each sample. Unlike
affine_transformer, the asymmetric transformer learns two scales, one for positive inputs and one for negative
inputs. Ignored if coupling_type is not "masked" or "twin".
reuse_embed: bool, default True
???

Returns
-------
configured_adapter: TransformAdapter
A partially initialized TransformAdapter with the specified parameters.
"""

if extension_windows is None:
extension_windows = []
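To make the make_optimizer parameter concrete, here is a minimal sketch of a custom optimizer factory that mirrors the documented default (optax.adamw wrapped in optax.apply_if_finite); the learning rate and error budget are illustrative choices rather than nutpie defaults:

```python
import optax

from nutpie.transform_adapter import make_transform_adapter


def make_optimizer():
    # Skip updates whose gradients are not finite, like the documented default.
    return optax.apply_if_finite(optax.adamw(1e-3), max_consecutive_errors=8)


adapter_factory = make_transform_adapter(
    make_optimizer=make_optimizer,
    num_layers=4,
    show_progress=False,
)
```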
