Skip to content

Commit 6a73230

Browse files
committed
Second iteration
1 parent 41b786b commit 6a73230

File tree

8 files changed

+72
-76
lines changed

8 files changed

+72
-76
lines changed

pymc/sampling/external/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,5 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from pymc.sampling.external.base import ExternalSampler
1514
from pymc.sampling.external.jax import Blackjax, Numpyro
1615
from pymc.sampling.external.nutpie import Nutpie

pymc/sampling/external/base.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,40 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from abc import ABC, abstractmethod
15+
from collections.abc import Sequence
16+
from typing import Any
17+
18+
from pytensor.scalar import discrete_dtypes
1519

1620
from pymc.model.core import modelcontext
17-
from pymc.util import get_value_vars_from_user_vars
21+
from pymc.util import RandomSeed
1822

1923

2024
class ExternalSampler(ABC):
21-
def __init__(self, vars=None, model=None):
25+
def __init__(self, model=None):
2226
model = modelcontext(model)
23-
if vars is None:
24-
vars = model.free_RVs
25-
else:
26-
vars = get_value_vars_from_user_vars(vars, model=model)
27-
if set(vars) != set(model.free_RVs):
28-
raise ValueError(
29-
"External samplers must sample all the model free_RVs, not just a subset"
30-
)
31-
self.vars = vars
3227
self.model = model
3328

3429
@abstractmethod
3530
def sample(
3631
self,
37-
tune,
38-
draws,
39-
chains,
40-
initvals,
41-
random_seed,
42-
progressbar,
43-
var_names,
44-
idata_kwargs,
45-
compute_convergence_checks,
32+
*,
33+
tune: int,
34+
draws: int,
35+
chains: int,
36+
initvals: dict[str, Any] | Sequence[dict[str, Any]],
37+
random_seed: RandomSeed,
38+
progressbar: bool,
39+
var_names: Sequence[str] | None = None,
40+
idata_kwargs: dict[str, Any] | None = None,
41+
compute_convergence_checks: bool,
4642
**kwargs,
4743
):
4844
pass
45+
46+
47+
class NUTSExternalSampler(ExternalSampler):
48+
def __init__(self, model=None):
49+
super().__init__(model)
50+
if any(var.dtype in discrete_dtypes for var in model.free_RVs):
51+
raise ValueError("External NUTS samplers can only sample continuous variables")

pymc/sampling/external/jax.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,23 @@
1616

1717
from arviz import InferenceData
1818

19-
from pymc.sampling.external.base import ExternalSampler
19+
from pymc.sampling.external.base import NUTSExternalSampler
2020
from pymc.util import RandomState
2121

2222

23-
class JAXSampler(ExternalSampler):
24-
nuts_sampler = None # Should be defined by subclass
23+
class JAXNUTSSampler(NUTSExternalSampler):
24+
nuts_sampler: Literal["numpyro", "blackjax"]
2525

2626
def __init__(
2727
self,
28-
vars=None,
2928
model=None,
3029
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
3130
chain_method: Literal["parallel", "vectorized"] = "parallel",
3231
jitter: bool = True,
3332
keep_untransformed: bool = False,
3433
nuts_kwargs: dict | None = None,
3534
):
36-
super().__init__(vars, model)
35+
super().__init__(model)
3736
self.postprocessing_backend = postprocessing_backend
3837
self.chain_method = chain_method
3938
self.jitter = jitter
@@ -53,7 +52,6 @@ def sample(
5352
idata_kwargs: dict | None = None,
5453
compute_convergence_checks: bool = True,
5554
target_accept: float = 0.8,
56-
nuts_sampler,
5755
**kwargs,
5856
) -> InferenceData:
5957
from pymc.sampling.jax import sample_jax_nuts
@@ -80,9 +78,9 @@ def sample(
8078
)
8179

8280

83-
class Numpyro(JAXSampler):
81+
class Numpyro(JAXNUTSSampler):
8482
nuts_sampler = "numpyro"
8583

8684

87-
class Blackjax(JAXSampler):
85+
class Blackjax(JAXNUTSSampler):
8886
nuts_sampler = "blackjax"

pymc/sampling/external/nutpie.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,28 @@
1414
import warnings
1515

1616
from arviz import InferenceData, dict_to_dataset
17-
from pytensor.scalar import discrete_dtypes
1817

1918
from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations
20-
from pymc.sampling.external.base import ExternalSampler
19+
from pymc.sampling.external.base import NUTSExternalSampler
2120
from pymc.stats.convergence import log_warnings, run_convergence_checks
2221
from pymc.util import _get_seeds_per_chain
2322

2423

25-
class Nutpie(ExternalSampler):
24+
class Nutpie(NUTSExternalSampler):
2625
def __init__(
2726
self,
28-
vars=None,
2927
model=None,
3028
backend="numba",
3129
gradient_backend="pytensor",
3230
compile_kwargs=None,
3331
sample_kwargs=None,
3432
):
35-
super().__init__(vars, model)
36-
if any(var.dtype in discrete_dtypes for var in self.vars):
37-
raise ValueError("Nutpie can only sample continuous variables")
33+
super().__init__(model)
3834
self.backend = backend
3935
self.gradient_backend = gradient_backend
4036
self.compile_kwargs = compile_kwargs or {}
4137
self.sample_kwargs = sample_kwargs or {}
38+
self.compiled_model = None
4239

4340
def sample(
4441
self,

pymc/sampling/jax.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,6 @@ def sample_jax_nuts(
665665
if "dims" in idata_kwargs:
666666
dims.update(idata_kwargs.pop("dims"))
667667

668-
# Use 'partial' to set default arguments before passing 'idata_kwargs'
669668
idata = az.from_dict(
670669
posterior=mcmc_samples,
671670
log_likelihood=log_likelihood,

pymc/sampling/mcmc.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,11 @@ def sample(
292292
chains: int | None = None,
293293
cores: int | None = None,
294294
random_seed: RandomState = None,
295+
step=None,
296+
external_sampler: ExternalSampler | None = None,
295297
progressbar: bool | ProgressBarType = True,
296298
progressbar_theme: Theme | None = default_progress_theme,
297-
step=None,
298299
var_names: Sequence[str] | None = None,
299-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
300300
initvals: StartDict | Sequence[StartDict | None] | None = None,
301301
init: str = "auto",
302302
jitter_max_retries: int = 10,
@@ -324,11 +324,11 @@ def sample(
324324
chains: int | None = None,
325325
cores: int | None = None,
326326
random_seed: RandomState = None,
327+
step=None,
328+
external_sampler: ExternalSampler | None = None,
327329
progressbar: bool | ProgressBarType = True,
328330
progressbar_theme: Theme | None = default_progress_theme,
329-
step=None,
330331
var_names: Sequence[str] | None = None,
331-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
332332
initvals: StartDict | Sequence[StartDict | None] | None = None,
333333
init: str = "auto",
334334
jitter_max_retries: int = 10,
@@ -356,11 +356,11 @@ def sample(
356356
chains: int | None = None,
357357
cores: int | None = None,
358358
random_seed: RandomState = None,
359+
step=None,
360+
external_sampler: ExternalSampler | None = None,
359361
progressbar: bool | ProgressBarType = True,
360362
progressbar_theme: Theme | None = None,
361-
step=None,
362363
var_names: Sequence[str] | None = None,
363-
nuts_sampler: None | Literal["pymc", "nutpie", "numpyro", "blackjax"] = None,
364364
initvals: StartDict | Sequence[StartDict | None] | None = None,
365365
init: str = "auto",
366366
jitter_max_retries: int = 10,
@@ -407,6 +407,12 @@ def sample(
407407
A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
408408
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
409409
easy spawning of new independent random streams that are needed by the step methods.
410+
step : function or iterable of functions, optional
411+
A step function or collection of functions. If there are variables without step methods,
412+
step methods for those variables will be assigned automatically. By default the NUTS step
413+
method will be used, if appropriate to the model. Not compatible with external_sampler
414+
external_sampler: ExternalSampler, optional
415+
An external sampler to sample the whole model. Not compatible with step.
410416
progressbar: bool or ProgressType, optional
411417
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
412418
for one of the following:
@@ -419,16 +425,8 @@ def sample(
419425
are also displayed.
420426
421427
If True, the default is "split+stats" is used.
422-
step : function or iterable of functions
423-
A step function or collection of functions. If there are variables without step methods,
424-
step methods for those variables will be assigned automatically. By default the NUTS step
425-
method will be used, if appropriate to the model.
426428
var_names : list of str, optional
427429
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
428-
nuts_sampler : str
429-
Which NUTS implementation to run. One of ["pymc", "nutpie", "blackjax", "numpyro"].
430-
This requires the chosen sampler to be installed.
431-
All samplers, except "pymc", require the full model to be continuous.
432430
blas_cores: int or "auto" or None, default = "auto"
433431
The total number of threads blas and openmp functions should use during sampling.
434432
Setting it to "auto" will ensure that the total number of active blas threads is the
@@ -608,35 +606,40 @@ def joined_blas_limiter():
608606
rngs = get_random_generator(random_seed).spawn(chains)
609607
random_seed_list = [rng.integers(2**30) for rng in rngs]
610608

611-
if step is None and nuts_sampler not in (None, "pymc"):
612-
# Temporarily instantiate external samplers for user, for backwards-compat
613-
warnings.warn(
614-
f"Setting `pm.sample(nuts_sampler='{nuts_sampler}, nuts_sampler_kwargs=...)'` is deprecated.\n"
615-
f"Use `pm.sample(step=pm.external.{nuts_sampler.capitalize()}(**nuts_sampler_kwargs))` instead",
616-
FutureWarning,
617-
)
618-
from pymc.sampling import external
609+
if "nuts_sampler" in kwargs:
610+
# Transition backwards-compatibility
611+
nuts_sampler = kwargs.pop("nuts_sampler")
612+
if nuts_sampler != "pymc":
613+
warnings.warn(
614+
f"Setting `pm.sample(nuts_sampler='{nuts_sampler}, nuts_sampler_kwargs=...)'` is deprecated.\n"
615+
f"Use `pm.sample(external_sampler=pm.external.{nuts_sampler.capitalize()}(**nuts_sampler_kwargs))` instead",
616+
FutureWarning,
617+
)
618+
from pymc.sampling import external
619619

620-
step = getattr(external, nuts_sampler.capitalize())(
621-
model=model,
622-
**(nuts_sampler_kwargs or {}),
623-
)
624-
nuts_sampler_kwargs = None
620+
external_sampler = getattr(external, nuts_sampler.capitalize())(
621+
model=model,
622+
**(nuts_sampler_kwargs or {}),
623+
**(kwargs.pop("nuts") or {}),
624+
)
625+
nuts_sampler_kwargs = None
625626

626-
if isinstance(step, list | tuple) and len(step) == 1:
627-
[step] = step
627+
if external_sampler is not None:
628+
if step is not None:
629+
raise ValueError("`step` and `external_sampler` cannot be used together")
628630

629-
if isinstance(step, ExternalSampler):
630-
if step.model is not model:
631-
raise ValueError("External step model does not match model detected by sample")
631+
if external_sampler.model is not model:
632+
raise ValueError(
633+
"External sampler model does not match model detected by sample function"
634+
)
632635
if nuts_sampler_kwargs:
633636
raise ValueError(
634637
f"{nuts_sampler_kwargs=} should be passed when constructing external sampler"
635638
)
636639
if "nuts" in kwargs:
637-
kwargs.update(kwargs["nuts"].pop())
640+
kwargs.update(kwargs.pop("nuts"))
638641
with joined_blas_limiter():
639-
return step.sample(
642+
return external_sampler.sample(
640643
tune=tune,
641644
draws=draws,
642645
chains=chains,

tests/sampling/test_jax.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ def test_jax_PosDefMatrix():
8282
),
8383
],
8484
)
85-
@pytest.mark.parametrize("postprocessing_vectorize", ["scan", "vmap"])
86-
def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_vectorize):
85+
def test_transform_samples(sampler, postprocessing_backend, chains):
8786
pytensor.config.on_opt_error = "raise"
8887
np.random.seed(13244)
8988

@@ -99,7 +98,6 @@ def test_transform_samples(sampler, postprocessing_backend, chains, postprocessi
9998
random_seed=1322,
10099
keep_untransformed=True,
101100
postprocessing_backend=postprocessing_backend,
102-
postprocessing_vectorize=postprocessing_vectorize,
103101
)
104102

105103
log_vals = trace.posterior["sigma_log__"].values

tests/sampling/test_mcmc_external.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
5959
expected.add(
6060
(
6161
UserWarning,
62-
"`initvals` are currently not passed to nutpie sampler. "
63-
"Use `init_mean` kwarg following nutpie specification instead.",
62+
"initvals are currently ignored by the nutpie sampler.",
6463
)
6564
)
6665
assert warns == expected

0 commit comments

Comments
 (0)