Skip to content

Draft external sampler API #7880

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __set_compiler_flags():
from pymc.printing import *
from pymc.pytensorf import *
from pymc.sampling import *
from pymc.sampling import external
from pymc.smc import *
from pymc.stats import *
from pymc.step_methods import *
Expand Down
16 changes: 16 additions & 0 deletions pymc/sampling/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pymc.sampling.external.base import ExternalSampler
from pymc.sampling.external.jax import Blackjax, Numpyro
from pymc.sampling.external.nutpie import Nutpie
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from pymc.sampling.external.nutpie import Nutpie
from pymc.sampling.external.nutpie import Nutpie
__all__ = ["Blackjax", "Numpyro", "Nutpie"]

You probably don't need to export the ABC right?

48 changes: 48 additions & 0 deletions pymc/sampling/external/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod

from pymc.model.core import modelcontext
from pymc.util import get_value_vars_from_user_vars


class ExternalSampler(ABC):
def __init__(self, vars=None, model=None):
model = modelcontext(model)
if vars is None:
vars = model.free_RVs
else:
vars = get_value_vars_from_user_vars(vars, model=model)
if set(vars) != set(model.free_RVs):
raise ValueError(
"External samplers must sample all the model free_RVs, not just a subset"
)
self.vars = vars
self.model = model

@abstractmethod
def sample(
self,
tune,
draws,
chains,
initvals,
random_seed,
progressbar,
var_names,
idata_kwargs,
compute_convergence_checks,
**kwargs,
):
Comment on lines +35 to +47
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Golden opportunity to add type hints and return types early!

pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pass
raise NotImplementedError

make it fail fast

88 changes: 88 additions & 0 deletions pymc/sampling/external/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from typing import Literal

from arviz import InferenceData

from pymc.sampling.external.base import ExternalSampler
from pymc.util import RandomState


class JAXSampler(ExternalSampler):
    """External sampler that delegates to ``pymc.sampling.jax.sample_jax_nuts``.

    Subclasses select the concrete backend by setting the ``nuts_sampler``
    class attribute.

    Parameters
    ----------
    vars, model
        Forwarded to ``ExternalSampler.__init__``.
    postprocessing_backend : {"cpu", "gpu"} or None
        Forwarded to ``sample_jax_nuts``.
    chain_method : {"parallel", "vectorized"}
        Forwarded to ``sample_jax_nuts``.
    jitter : bool
        Forwarded to ``sample_jax_nuts``.
    keep_untransformed : bool
        Forwarded to ``sample_jax_nuts``.
    nuts_kwargs : dict or None
        Forwarded to ``sample_jax_nuts``; ``None`` is stored as an empty dict.
    """

    # Review feedback on this draft: it is probably best not to assume that
    # every JAXSampler is a NUTS sampler.
    # Review feedback on this draft: the requirement below could be enforced by
    # making this an abstract property that inheritors must implement.
    nuts_sampler = None  # Should be defined by subclass

    def __init__(
        self,
        vars=None,
        model=None,
        postprocessing_backend: Literal["cpu", "gpu"] | None = None,
        chain_method: Literal["parallel", "vectorized"] = "parallel",
        jitter: bool = True,
        keep_untransformed: bool = False,
        nuts_kwargs: dict | None = None,
    ):
        super().__init__(vars, model)
        self.postprocessing_backend = postprocessing_backend
        self.chain_method = chain_method
        self.jitter = jitter
        self.keep_untransformed = keep_untransformed
        self.nuts_kwargs = nuts_kwargs or {}

    def sample(
        self,
        *,
        tune: int = 1000,
        draws: int = 1000,
        chains: int = 4,
        initvals=None,
        random_seed: RandomState | None = None,
        progressbar: bool = True,
        var_names: Sequence[str] | None = None,
        idata_kwargs: dict | None = None,
        compute_convergence_checks: bool = True,
        target_accept: float = 0.8,
        nuts_sampler=None,
        **kwargs,
    ) -> InferenceData:
        """Sample ``self.model`` with the JAX backend named by ``self.nuts_sampler``.

        All sampling arguments are forwarded to ``sample_jax_nuts`` together
        with the options stored at construction time.

        Parameters
        ----------
        nuts_sampler : optional
            Accepted for call compatibility only and ignored; the class
            attribute ``nuts_sampler`` decides which backend is used. It used
            to be a required argument even though it was never read, which
            forced callers to pass a value that had no effect.
        **kwargs
            Additional keyword arguments forwarded to ``sample_jax_nuts``.

        Returns
        -------
        InferenceData
            Whatever ``sample_jax_nuts`` returns.
        """
        # Imported lazily so that importing this module does not require JAX.
        from pymc.sampling.jax import sample_jax_nuts

        return sample_jax_nuts(
            tune=tune,
            draws=draws,
            chains=chains,
            target_accept=target_accept,
            random_seed=random_seed,
            var_names=var_names,
            progressbar=progressbar,
            idata_kwargs=idata_kwargs,
            compute_convergence_checks=compute_convergence_checks,
            initvals=initvals,
            jitter=self.jitter,
            model=self.model,
            chain_method=self.chain_method,
            postprocessing_backend=self.postprocessing_backend,
            keep_untransformed=self.keep_untransformed,
            nuts_kwargs=self.nuts_kwargs,
            nuts_sampler=self.nuts_sampler,
            **kwargs,
        )


class Numpyro(JAXSampler):
    """JAX-based external sampler whose ``nuts_sampler`` is ``"numpyro"``."""

    nuts_sampler = "numpyro"


class Blackjax(JAXSampler):
    """JAX-based external sampler whose ``nuts_sampler`` is ``"blackjax"``."""

    nuts_sampler = "blackjax"
149 changes: 149 additions & 0 deletions pymc/sampling/external/nutpie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2025 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from arviz import InferenceData, dict_to_dataset
from pytensor.scalar import discrete_dtypes

from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations
from pymc.sampling.external.base import ExternalSampler
from pymc.stats.convergence import log_warnings, run_convergence_checks
from pymc.util import _get_seeds_per_chain


class Nutpie(ExternalSampler):
    """External sampler that delegates to the ``nutpie`` package.

    Parameters
    ----------
    vars, model
        Forwarded to ``ExternalSampler.__init__``.
    backend
        Passed as ``backend`` to ``nutpie.compile_pymc_model``.
    gradient_backend
        Passed as ``gradient_backend`` to ``nutpie.compile_pymc_model``.
    compile_kwargs : dict or None
        Extra keyword arguments for ``nutpie.compile_pymc_model``.
    sample_kwargs : dict or None
        Extra keyword arguments for ``nutpie.sample``.

    Raises
    ------
    ValueError
        If any of the sampled variables has a discrete dtype.
    """

    def __init__(
        self,
        vars=None,
        model=None,
        backend="numba",
        gradient_backend="pytensor",
        compile_kwargs=None,
        sample_kwargs=None,
    ):
        super().__init__(vars, model)
        if any(var.dtype in discrete_dtypes for var in self.vars):
            raise ValueError("Nutpie can only sample continuous variables")
        self.backend = backend
        self.gradient_backend = gradient_backend
        self.compile_kwargs = compile_kwargs or {}
        self.sample_kwargs = sample_kwargs or {}

    def sample(
        self,
        *,
        tune,
        draws,
        chains,
        initvals,
        random_seed,
        progressbar,
        var_names,
        idata_kwargs,
        compute_convergence_checks,
        **kwargs,
    ):
        """Compile ``self.model`` with nutpie, sample it and post-process the result.

        ``initvals`` and ``idata_kwargs`` are not supported; a ``UserWarning``
        is emitted when either is truthy. Keyword arguments given here are
        merged over the ``sample_kwargs`` given at construction time, so a
        call-time value wins when both define the same key.

        Returns
        -------
        The post-processed result of ``nutpie.sample``. If nutpie returns a
        ``_BackgroundSampler``, a wrapper around it is returned instead, which
        applies the same post-processing when the trace is extracted.

        Raises
        ------
        ImportError
            If ``nutpie`` is not installed.
        """
        try:
            import nutpie
        except ImportError as err:
            raise ImportError(
                "nutpie not found. Install it with conda install -c conda-forge nutpie"
            ) from err

        from nutpie.sample import _BackgroundSampler

        if initvals:
            warnings.warn(
                "initvals are currently ignored by the nutpie sampler.",
                UserWarning,
            )
        if idata_kwargs:
            warnings.warn(
                "idata_kwargs are currently ignored by the nutpie sampler.",
                UserWarning,
            )

        compiled_model = nutpie.compile_pymc_model(
            self.model,
            var_names=var_names,
            backend=self.backend,
            gradient_backend=self.gradient_backend,
            **self.compile_kwargs,
        )

        # Merge instead of unpacking both dicts into the call: unpacking them
        # separately raises TypeError as soon as they share a key, whereas the
        # merge lets call-time arguments override the construction-time ones.
        sample_kwargs = {**self.sample_kwargs, **kwargs}

        result = nutpie.sample(
            compiled_model,
            tune=tune,
            draws=draws,
            chains=chains,
            seed=_get_seeds_per_chain(random_seed, 1)[0],
            progress_bar=progressbar,
            **sample_kwargs,
        )
        if isinstance(result, _BackgroundSampler):
            # Wrap _BackgroundSampler so that when sampling is finished we run post_process_sampler
            class NutpieBackgroundSamplerWrapper(_BackgroundSampler):
                def __init__(self, *args, pymc_model, compute_convergence_checks, **kwargs):
                    self.pymc_model = pymc_model
                    self.compute_convergence_checks = compute_convergence_checks
                    super().__init__(*args, **kwargs, return_raw_trace=False)

                def _extract(self, *args, **kwargs):
                    idata = super()._extract(*args, **kwargs)
                    return Nutpie._post_process_sample(
                        model=self.pymc_model,
                        idata=idata,
                        compute_convergence_checks=self.compute_convergence_checks,
                    )

            # non-blocked sampling
            return NutpieBackgroundSamplerWrapper(
                result,
                pymc_model=self.model,
                compute_convergence_checks=compute_convergence_checks,
            )
        else:
            return self._post_process_sample(self.model, result, compute_convergence_checks)

    @staticmethod
    def _post_process_sample(
        model, idata: InferenceData, compute_convergence_checks
    ) -> InferenceData:
        """Run optional convergence checks and attach model data groups to ``idata``.

        Adds ``constant_data`` and ``observed_data`` groups built from ``model``
        to ``idata`` and returns the same ``idata`` object.
        """
        if compute_convergence_checks:
            log_warnings(run_convergence_checks(idata, model))

        # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
        # gather observed and constant data as nutpie.sample() has no access to the PyMC model
        coords, dims = coords_and_dims_for_inferencedata(model)
        constant_data = dict_to_dataset(
            find_constants(model),
            library=idata.attrs.get("library", None),
            coords=coords,
            dims=dims,
            default_dims=[],
        )
        observed_data = dict_to_dataset(
            find_observations(model),
            library=idata.attrs.get("library", None),
            coords=coords,
            dims=dims,
            default_dims=[],
        )
        idata.add_groups(
            {"constant_data": constant_data, "observed_data": observed_data},
            coords=coords,
            dims=dims,
        )
        return idata
Loading
Loading