Skip to content

Commit cf8e8a0

Browse files
committed
Second iteration
1 parent 41b786b commit cf8e8a0

File tree

7 files changed

+61
-61
lines changed

7 files changed

+61
-61
lines changed

pymc/sampling/external/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,5 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from pymc.sampling.external.base import ExternalSampler
1514
from pymc.sampling.external.jax import Blackjax, Numpyro
1615
from pymc.sampling.external.nutpie import Nutpie

pymc/sampling/external/base.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,37 +12,40 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from abc import ABC, abstractmethod
15+
from collections.abc import Sequence
16+
from typing import Any
17+
18+
from pytensor.scalar import discrete_dtypes
1519

1620
from pymc.model.core import modelcontext
17-
from pymc.util import get_value_vars_from_user_vars
21+
from pymc.util import RandomSeed
1822

1923

2024
class ExternalSampler(ABC):
21-
def __init__(self, vars=None, model=None):
25+
def __init__(self, model=None):
2226
model = modelcontext(model)
23-
if vars is None:
24-
vars = model.free_RVs
25-
else:
26-
vars = get_value_vars_from_user_vars(vars, model=model)
27-
if set(vars) != set(model.free_RVs):
28-
raise ValueError(
29-
"External samplers must sample all the model free_RVs, not just a subset"
30-
)
31-
self.vars = vars
3227
self.model = model
3328

3429
@abstractmethod
3530
def sample(
3631
self,
37-
tune,
38-
draws,
39-
chains,
40-
initvals,
41-
random_seed,
42-
progressbar,
43-
var_names,
44-
idata_kwargs,
45-
compute_convergence_checks,
32+
*,
33+
tune: int,
34+
draws: int,
35+
chains: int,
36+
initvals: dict[str, Any] | Sequence[dict[str, Any]],
37+
random_seed: RandomSeed,
38+
progressbar: bool,
39+
var_names: Sequence[str] | None = None,
40+
idata_kwargs: dict[str, Any] | None = None,
41+
compute_convergence_checks: bool,
4642
**kwargs,
4743
):
4844
pass
45+
46+
47+
class NUTSExternalSampler(ExternalSampler):
48+
def __init__(self, model=None):
49+
super().__init__(model)
50+
if any(var.dtype in discrete_dtypes for var in model.free_RVs):
51+
raise ValueError("External NUTS samplers can only sample continuous variables")

pymc/sampling/external/jax.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,23 @@
1616

1717
from arviz import InferenceData
1818

19-
from pymc.sampling.external.base import ExternalSampler
19+
from pymc.sampling.external.base import NUTSExternalSampler
2020
from pymc.util import RandomState
2121

2222

23-
class JAXSampler(ExternalSampler):
24-
nuts_sampler = None # Should be defined by subclass
23+
class JAXNUTSSampler(NUTSExternalSampler):
24+
nuts_sampler: Literal["numpyro", "blackjax"]
2525

2626
def __init__(
2727
self,
28-
vars=None,
2928
model=None,
3029
postprocessing_backend: Literal["cpu", "gpu"] | None = None,
3130
chain_method: Literal["parallel", "vectorized"] = "parallel",
3231
jitter: bool = True,
3332
keep_untransformed: bool = False,
3433
nuts_kwargs: dict | None = None,
3534
):
36-
super().__init__(vars, model)
35+
super().__init__(model)
3736
self.postprocessing_backend = postprocessing_backend
3837
self.chain_method = chain_method
3938
self.jitter = jitter
@@ -53,7 +52,6 @@ def sample(
5352
idata_kwargs: dict | None = None,
5453
compute_convergence_checks: bool = True,
5554
target_accept: float = 0.8,
56-
nuts_sampler,
5755
**kwargs,
5856
) -> InferenceData:
5957
from pymc.sampling.jax import sample_jax_nuts
@@ -80,9 +78,9 @@ def sample(
8078
)
8179

8280

83-
class Numpyro(JAXSampler):
81+
class Numpyro(JAXNUTSSampler):
8482
nuts_sampler = "numpyro"
8583

8684

87-
class Blackjax(JAXSampler):
85+
class Blackjax(JAXNUTSSampler):
8886
nuts_sampler = "blackjax"

pymc/sampling/external/nutpie.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,28 @@
1414
import warnings
1515

1616
from arviz import InferenceData, dict_to_dataset
17-
from pytensor.scalar import discrete_dtypes
1817

1918
from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations
20-
from pymc.sampling.external.base import ExternalSampler
19+
from pymc.sampling.external.base import NUTSExternalSampler
2120
from pymc.stats.convergence import log_warnings, run_convergence_checks
2221
from pymc.util import _get_seeds_per_chain
2322

2423

25-
class Nutpie(ExternalSampler):
24+
class Nutpie(NUTSExternalSampler):
2625
def __init__(
2726
self,
28-
vars=None,
2927
model=None,
3028
backend="numba",
3129
gradient_backend="pytensor",
3230
compile_kwargs=None,
3331
sample_kwargs=None,
3432
):
35-
super().__init__(vars, model)
36-
if any(var.dtype in discrete_dtypes for var in self.vars):
37-
raise ValueError("Nutpie can only sample continuous variables")
33+
super().__init__(model)
3834
self.backend = backend
3935
self.gradient_backend = gradient_backend
4036
self.compile_kwargs = compile_kwargs or {}
4137
self.sample_kwargs = sample_kwargs or {}
38+
self.compiled_model = None
4239

4340
def sample(
4441
self,

pymc/sampling/mcmc.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,11 @@ def sample(
292292
chains: int | None = None,
293293
cores: int | None = None,
294294
random_seed: RandomState = None,
295+
step=None,
296+
external_sampler: ExternalSampler | None = None,
295297
progressbar: bool | ProgressBarType = True,
296298
progressbar_theme: Theme | None = default_progress_theme,
297-
step=None,
298299
var_names: Sequence[str] | None = None,
299-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
300300
initvals: StartDict | Sequence[StartDict | None] | None = None,
301301
init: str = "auto",
302302
jitter_max_retries: int = 10,
@@ -324,11 +324,11 @@ def sample(
324324
chains: int | None = None,
325325
cores: int | None = None,
326326
random_seed: RandomState = None,
327+
step=None,
328+
external_sampler: ExternalSampler | None = None,
327329
progressbar: bool | ProgressBarType = True,
328330
progressbar_theme: Theme | None = default_progress_theme,
329-
step=None,
330331
var_names: Sequence[str] | None = None,
331-
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
332332
initvals: StartDict | Sequence[StartDict | None] | None = None,
333333
init: str = "auto",
334334
jitter_max_retries: int = 10,
@@ -356,11 +356,11 @@ def sample(
356356
chains: int | None = None,
357357
cores: int | None = None,
358358
random_seed: RandomState = None,
359+
step=None,
360+
external_sampler: ExternalSampler | None = None,
359361
progressbar: bool | ProgressBarType = True,
360362
progressbar_theme: Theme | None = None,
361-
step=None,
362363
var_names: Sequence[str] | None = None,
363-
nuts_sampler: None | Literal["pymc", "nutpie", "numpyro", "blackjax"] = None,
364364
initvals: StartDict | Sequence[StartDict | None] | None = None,
365365
init: str = "auto",
366366
jitter_max_retries: int = 10,
@@ -407,6 +407,12 @@ def sample(
407407
A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
408408
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
409409
easy spawning of new independent random streams that are needed by the step methods.
410+
step : function or iterable of functions, optional
411+
A step function or collection of functions. If there are variables without step methods,
412+
step methods for those variables will be assigned automatically. By default the NUTS step
413+
method will be used, if appropriate to the model. Not compatible with external_sampler
414+
external_sampler: ExternalSampler, optional
415+
An external sampler to sample the whole model. Not compatible with step.
410416
progressbar: bool or ProgressType, optional
411417
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
412418
for one of the following:
@@ -419,10 +425,6 @@ def sample(
419425
are also displayed.
420426
421427
If True, the default is "split+stats" is used.
422-
step : function or iterable of functions
423-
A step function or collection of functions. If there are variables without step methods,
424-
step methods for those variables will be assigned automatically. By default the NUTS step
425-
method will be used, if appropriate to the model.
426428
var_names : list of str, optional
427429
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
428430
nuts_sampler : str
@@ -608,35 +610,38 @@ def joined_blas_limiter():
608610
rngs = get_random_generator(random_seed).spawn(chains)
609611
random_seed_list = [rng.integers(2**30) for rng in rngs]
610612

611-
if step is None and nuts_sampler not in (None, "pymc"):
612-
# Temporarily instantiate external samplers for user, for backwards-compat
613+
if "nuts_sampler" in kwargs:
614+
# Transition backwards-compatibility
615+
nuts_sampler = kwargs.pop("nuts_sampler")
613616
warnings.warn(
614617
f"Setting `pm.sample(nuts_sampler='{nuts_sampler}, nuts_sampler_kwargs=...)'` is deprecated.\n"
615-
f"Use `pm.sample(step=pm.external.{nuts_sampler.capitalize()}(**nuts_sampler_kwargs))` instead",
618+
f"Use `pm.sample(external_sampler=pm.external.{nuts_sampler.capitalize()}(**nuts_sampler_kwargs))` instead",
616619
FutureWarning,
617620
)
618621
from pymc.sampling import external
619622

620-
step = getattr(external, nuts_sampler.capitalize())(
623+
external_sampler = getattr(external, nuts_sampler.capitalize())(
621624
model=model,
622625
**(nuts_sampler_kwargs or {}),
623626
)
624627
nuts_sampler_kwargs = None
625628

626-
if isinstance(step, list | tuple) and len(step) == 1:
627-
[step] = step
629+
if external_sampler is not None:
630+
if step is not None:
631+
raise ValueError("`step` and `external_sampler` cannot be used together")
628632

629-
if isinstance(step, ExternalSampler):
630-
if step.model is not model:
631-
raise ValueError("External step model does not match model detected by sample")
633+
if external_sampler.model is not model:
634+
raise ValueError(
635+
"External sampler model does not match model detected by sample function"
636+
)
632637
if nuts_sampler_kwargs:
633638
raise ValueError(
634639
f"{nuts_sampler_kwargs=} should be passed when constructing external sampler"
635640
)
636641
if "nuts" in kwargs:
637-
kwargs.update(kwargs["nuts"].pop())
642+
kwargs.update(kwargs.pop("nuts"))
638643
with joined_blas_limiter():
639-
return step.sample(
644+
return external_sampler.sample(
640645
tune=tune,
641646
draws=draws,
642647
chains=chains,

tests/sampling/test_jax.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ def test_jax_PosDefMatrix():
8282
),
8383
],
8484
)
85-
@pytest.mark.parametrize("postprocessing_vectorize", ["scan", "vmap"])
86-
def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_vectorize):
85+
def test_transform_samples(sampler, postprocessing_backend, chains):
8786
pytensor.config.on_opt_error = "raise"
8887
np.random.seed(13244)
8988

@@ -99,7 +98,6 @@ def test_transform_samples(sampler, postprocessing_backend, chains, postprocessi
9998
random_seed=1322,
10099
keep_untransformed=True,
101100
postprocessing_backend=postprocessing_backend,
102-
postprocessing_vectorize=postprocessing_vectorize,
103101
)
104102

105103
log_vals = trace.posterior["sigma_log__"].values

tests/sampling/test_mcmc_external.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_external_nuts_sampler(recwarn, nuts_sampler):
6060
(
6161
UserWarning,
6262
"`initvals` are currently not passed to nutpie sampler. "
63-
"Use `init_mean` kwarg following nutpie specification instead.",
63+
"Use `init_mean` kwarg following nutpie specification instead.'",
6464
)
6565
)
6666
assert warns == expected

0 commit comments

Comments
 (0)