-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Draft external sampler API #7880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright 2025 - present The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from pymc.sampling.external.base import ExternalSampler | ||
from pymc.sampling.external.jax import Blackjax, Numpyro | ||
from pymc.sampling.external.nutpie import Nutpie | ||
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,48 @@ | ||||||
# Copyright 2025 - present The PyMC Developers | ||||||
# | ||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
# you may not use this file except in compliance with the License. | ||||||
# You may obtain a copy of the License at | ||||||
# | ||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||
# | ||||||
# Unless required by applicable law or agreed to in writing, software | ||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
from abc import ABC, abstractmethod | ||||||
|
||||||
from pymc.model.core import modelcontext | ||||||
from pymc.util import get_value_vars_from_user_vars | ||||||
|
||||||
|
||||||
class ExternalSampler(ABC): | ||||||
def __init__(self, vars=None, model=None): | ||||||
model = modelcontext(model) | ||||||
if vars is None: | ||||||
vars = model.free_RVs | ||||||
else: | ||||||
vars = get_value_vars_from_user_vars(vars, model=model) | ||||||
if set(vars) != set(model.free_RVs): | ||||||
raise ValueError( | ||||||
"External samplers must sample all the model free_RVs, not just a subset" | ||||||
) | ||||||
self.vars = vars | ||||||
self.model = model | ||||||
|
||||||
@abstractmethod | ||||||
def sample( | ||||||
self, | ||||||
tune, | ||||||
draws, | ||||||
chains, | ||||||
initvals, | ||||||
random_seed, | ||||||
progressbar, | ||||||
var_names, | ||||||
idata_kwargs, | ||||||
compute_convergence_checks, | ||||||
**kwargs, | ||||||
): | ||||||
Comment on lines
+35
to
+47
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Golden opportunity to add type hints and return types early! |
||||||
pass | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
make it fail fast |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Copyright 2025 - present The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from collections.abc import Sequence | ||
from typing import Literal | ||
|
||
from arviz import InferenceData | ||
|
||
from pymc.sampling.external.base import ExternalSampler | ||
from pymc.util import RandomState | ||
|
||
|
||
class JAXSampler(ExternalSampler):
    """External sampler that delegates to ``pymc.sampling.jax.sample_jax_nuts``.

    Concrete subclasses pick the backend by overriding the ``nuts_sampler``
    class attribute; that attribute is what gets forwarded to
    ``sample_jax_nuts``.
    """

    # NOTE(review): reviewers pointed out that (a) it is probably best not to
    # assume every JAX sampler is a NUTS sampler, and (b) the "subclass must
    # define this" rule could be enforced with an abstract property.
    nuts_sampler = None  # Should be defined by subclass

    def __init__(
        self,
        vars=None,
        model=None,
        postprocessing_backend: Literal["cpu", "gpu"] | None = None,
        chain_method: Literal["parallel", "vectorized"] = "parallel",
        jitter: bool = True,
        keep_untransformed: bool = False,
        nuts_kwargs: dict | None = None,
    ):
        """Validate ``vars``/``model`` and store the sampler configuration.

        All keyword options are stored on the instance and forwarded
        unchanged to ``sample_jax_nuts`` when :meth:`sample` is called.
        ``nuts_kwargs=None`` is stored as an empty dict.
        """
        super().__init__(vars, model)
        self.postprocessing_backend = postprocessing_backend
        self.chain_method = chain_method
        self.jitter = jitter
        self.keep_untransformed = keep_untransformed
        self.nuts_kwargs = nuts_kwargs or {}

    def sample(
        self,
        *,
        tune: int = 1000,
        draws: int = 1000,
        chains: int = 4,
        initvals=None,
        random_seed: RandomState | None = None,
        progressbar: bool = True,
        var_names: Sequence[str] | None = None,
        idata_kwargs: dict | None = None,
        compute_convergence_checks: bool = True,
        target_accept: float = 0.8,
        nuts_sampler=None,
        **kwargs,
    ) -> InferenceData:
        """Sample the stored model through ``sample_jax_nuts``.

        Parameters
        ----------
        tune, draws, chains, initvals, random_seed, progressbar, var_names, \
idata_kwargs, compute_convergence_checks, target_accept
            Forwarded unchanged to ``sample_jax_nuts``.
        nuts_sampler : optional
            Accepted for call compatibility and ignored: the backend is
            always taken from the ``nuts_sampler`` class attribute. It used
            to be a required argument even though it was never read.
        **kwargs
            Additional keyword arguments forwarded to ``sample_jax_nuts``.

        Returns
        -------
        InferenceData
            Whatever ``sample_jax_nuts`` returns.

        Raises
        ------
        ValueError
            If the class attribute ``nuts_sampler`` has not been set by a
            subclass.
        """
        # Fail before any expensive work when the backend was never selected.
        if self.nuts_sampler is None:
            raise ValueError(
                f"{type(self).__name__} does not define `nuts_sampler`; "
                "use a concrete subclass such as Numpyro or Blackjax"
            )

        # Imported lazily so that JAX is only required when actually sampling.
        from pymc.sampling.jax import sample_jax_nuts

        return sample_jax_nuts(
            tune=tune,
            draws=draws,
            chains=chains,
            target_accept=target_accept,
            random_seed=random_seed,
            var_names=var_names,
            progressbar=progressbar,
            idata_kwargs=idata_kwargs,
            compute_convergence_checks=compute_convergence_checks,
            initvals=initvals,
            jitter=self.jitter,
            model=self.model,
            chain_method=self.chain_method,
            postprocessing_backend=self.postprocessing_backend,
            keep_untransformed=self.keep_untransformed,
            nuts_kwargs=self.nuts_kwargs,
            nuts_sampler=self.nuts_sampler,
            **kwargs,
        )
|
||
|
||
class Numpyro(JAXSampler):
    """JAX-based external sampler whose ``nuts_sampler`` backend is ``"numpyro"``."""

    nuts_sampler = "numpyro"
|
||
|
||
class Blackjax(JAXSampler):
    """JAX-based external sampler whose ``nuts_sampler`` backend is ``"blackjax"``."""

    nuts_sampler = "blackjax"
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# Copyright 2025 - present The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import warnings | ||
|
||
from arviz import InferenceData, dict_to_dataset | ||
from pytensor.scalar import discrete_dtypes | ||
|
||
from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations | ||
from pymc.sampling.external.base import ExternalSampler | ||
from pymc.stats.convergence import log_warnings, run_convergence_checks | ||
from pymc.util import _get_seeds_per_chain | ||
|
||
|
||
class Nutpie(ExternalSampler):
    """External sampler that compiles the model with ``nutpie`` and samples it there."""

    def __init__(
        self,
        vars=None,
        model=None,
        backend="numba",
        gradient_backend="pytensor",
        compile_kwargs=None,
        sample_kwargs=None,
    ):
        """Validate the variables and store compile/sample options.

        Parameters
        ----------
        vars, model
            Passed to ``ExternalSampler.__init__``.
        backend, gradient_backend
            Forwarded to ``nutpie.compile_pymc_model``.
        compile_kwargs : dict, optional
            Extra keyword arguments for ``nutpie.compile_pymc_model``.
        sample_kwargs : dict, optional
            Default keyword arguments for ``nutpie.sample``.

        Raises
        ------
        ValueError
            If any sampled variable has a discrete dtype.
        """
        super().__init__(vars, model)
        if any(var.dtype in discrete_dtypes for var in self.vars):
            raise ValueError("Nutpie can only sample continuous variables")
        self.backend = backend
        self.gradient_backend = gradient_backend
        self.compile_kwargs = compile_kwargs or {}
        self.sample_kwargs = sample_kwargs or {}

    def sample(
        self,
        *,
        tune,
        draws,
        chains,
        initvals,
        random_seed,
        progressbar,
        var_names,
        idata_kwargs,
        compute_convergence_checks,
        **kwargs,
    ):
        """Compile the model with nutpie, sample it and post-process the result.

        ``initvals`` and ``idata_kwargs`` are not used; a ``UserWarning`` is
        emitted when either is truthy. Keyword arguments given here are merged
        over the ``sample_kwargs`` supplied at construction time, so a
        call-time value overrides the stored default instead of triggering a
        duplicate-keyword ``TypeError``.

        Returns
        -------
        The post-processed ``InferenceData`` or, when ``nutpie.sample`` hands
        back a ``_BackgroundSampler`` (non-blocking sampling), a wrapper around
        it that applies the same post-processing when the trace is extracted.

        Raises
        ------
        ImportError
            If ``nutpie`` is not installed.
        """
        try:
            import nutpie
        except ImportError as err:
            raise ImportError(
                "nutpie not found. Install it with conda install -c conda-forge nutpie"
            ) from err

        from nutpie.sample import _BackgroundSampler

        if initvals:
            warnings.warn(
                "initvals are currently ignored by the nutpie sampler.",
                UserWarning,
                stacklevel=2,
            )
        if idata_kwargs:
            warnings.warn(
                "idata_kwargs are currently ignored by the nutpie sampler.",
                UserWarning,
                stacklevel=2,
            )

        compiled_model = nutpie.compile_pymc_model(
            self.model,
            var_names=var_names,
            backend=self.backend,
            gradient_backend=self.gradient_backend,
            **self.compile_kwargs,
        )

        # Call-time options take precedence over constructor defaults; passing
        # both dicts with ``**`` directly would raise on any shared key.
        sample_options = {**self.sample_kwargs, **kwargs}

        result = nutpie.sample(
            compiled_model,
            tune=tune,
            draws=draws,
            chains=chains,
            seed=_get_seeds_per_chain(random_seed, 1)[0],
            progress_bar=progressbar,
            **sample_options,
        )
        if isinstance(result, _BackgroundSampler):
            # Wrap _BackgroundSampler so that when sampling is finished we run post_process_sampler
            class NutpieBackgroundSamplerWrapper(_BackgroundSampler):
                def __init__(self, *args, pymc_model, compute_convergence_checks, **kwargs):
                    self.pymc_model = pymc_model
                    self.compute_convergence_checks = compute_convergence_checks
                    super().__init__(*args, **kwargs, return_raw_trace=False)

                def _extract(self, *args, **kwargs):
                    idata = super()._extract(*args, **kwargs)
                    return Nutpie._post_process_sample(
                        model=self.pymc_model,
                        idata=idata,
                        compute_convergence_checks=self.compute_convergence_checks,
                    )

            # non-blocked sampling
            return NutpieBackgroundSamplerWrapper(
                result,
                pymc_model=self.model,
                compute_convergence_checks=compute_convergence_checks,
            )
        else:
            return self._post_process_sample(self.model, result, compute_convergence_checks)

    @staticmethod
    def _post_process_sample(
        model, idata: InferenceData, compute_convergence_checks
    ) -> InferenceData:
        """Optionally run convergence checks, then attach model data to ``idata``.

        Adds ``constant_data`` and ``observed_data`` groups built from the
        model and returns the same ``idata`` object.
        """
        if compute_convergence_checks:
            log_warnings(run_convergence_checks(idata, model))

        # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed
        # gather observed and constant data as nutpie.sample() has no access to the PyMC model
        coords, dims = coords_and_dims_for_inferencedata(model)
        library = idata.attrs.get("library", None)
        constant_data = dict_to_dataset(
            find_constants(model),
            library=library,
            coords=coords,
            dims=dims,
            default_dims=[],
        )
        observed_data = dict_to_dataset(
            find_observations(model),
            library=library,
            coords=coords,
            dims=dims,
            default_dims=[],
        )
        idata.add_groups(
            {"constant_data": constant_data, "observed_data": observed_data},
            coords=coords,
            dims=dims,
        )
        return idata
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You probably don't need to export the ABC right?