diff --git a/pymc/__init__.py b/pymc/__init__.py index 69a29c97e..5dae1ef57 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -71,6 +71,7 @@ def __set_compiler_flags(): from pymc.printing import * from pymc.pytensorf import * from pymc.sampling import * +from pymc.sampling import external from pymc.smc import * from pymc.stats import * from pymc.step_methods import * diff --git a/pymc/sampling/external/__init__.py b/pymc/sampling/external/__init__.py new file mode 100644 index 000000000..95c56c67e --- /dev/null +++ b/pymc/sampling/external/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pymc.sampling.external.base import ExternalSampler +from pymc.sampling.external.jax import Blackjax, Numpyro +from pymc.sampling.external.nutpie import Nutpie diff --git a/pymc/sampling/external/base.py b/pymc/sampling/external/base.py new file mode 100644 index 000000000..29683e013 --- /dev/null +++ b/pymc/sampling/external/base.py @@ -0,0 +1,48 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + +from pymc.model.core import modelcontext +from pymc.util import get_value_vars_from_user_vars + + +class ExternalSampler(ABC): + def __init__(self, vars=None, model=None): + model = modelcontext(model) + if vars is None: + vars = model.free_RVs + else: + vars = get_value_vars_from_user_vars(vars, model=model) + if set(vars) != set(model.free_RVs): + raise ValueError( + "External samplers must sample all the model free_RVs, not just a subset" + ) + self.vars = vars + self.model = model + + @abstractmethod + def sample( + self, + tune, + draws, + chains, + initvals, + random_seed, + progressbar, + var_names, + idata_kwargs, + compute_convergence_checks, + **kwargs, + ): + pass diff --git a/pymc/sampling/external/jax.py b/pymc/sampling/external/jax.py new file mode 100644 index 000000000..2910cc190 --- /dev/null +++ b/pymc/sampling/external/jax.py @@ -0,0 +1,88 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence +from typing import Literal + +from arviz import InferenceData + +from pymc.sampling.external.base import ExternalSampler +from pymc.util import RandomState + + +class JAXSampler(ExternalSampler): + nuts_sampler = None # Should be defined by subclass + + def __init__( + self, + vars=None, + model=None, + postprocessing_backend: Literal["cpu", "gpu"] | None = None, + chain_method: Literal["parallel", "vectorized"] = "parallel", + jitter: bool = True, + keep_untransformed: bool = False, + nuts_kwargs: dict | None = None, + ): + super().__init__(vars, model) + self.postprocessing_backend = postprocessing_backend + self.chain_method = chain_method + self.jitter = jitter + self.keep_untransformed = keep_untransformed + self.nuts_kwargs = nuts_kwargs or {} + + def sample( + self, + *, + tune: int = 1000, + draws: int = 1000, + chains: int = 4, + initvals=None, + random_seed: RandomState | None = None, + progressbar: bool = True, + var_names: Sequence[str] | None = None, + idata_kwargs: dict | None = None, + compute_convergence_checks: bool = True, + target_accept: float = 0.8, + nuts_sampler, + **kwargs, + ) -> InferenceData: + from pymc.sampling.jax import sample_jax_nuts + + return sample_jax_nuts( + tune=tune, + draws=draws, + chains=chains, + target_accept=target_accept, + random_seed=random_seed, + var_names=var_names, + progressbar=progressbar, + idata_kwargs=idata_kwargs, + compute_convergence_checks=compute_convergence_checks, + initvals=initvals, + jitter=self.jitter, + model=self.model, + chain_method=self.chain_method, + postprocessing_backend=self.postprocessing_backend, + keep_untransformed=self.keep_untransformed, + nuts_kwargs=self.nuts_kwargs, + nuts_sampler=self.nuts_sampler, + **kwargs, + ) + + +class Numpyro(JAXSampler): + nuts_sampler = "numpyro" + + +class Blackjax(JAXSampler): + nuts_sampler = "blackjax" diff --git a/pymc/sampling/external/nutpie.py b/pymc/sampling/external/nutpie.py new file mode 100644 index 000000000..3a9ac4f38 --- /dev/null +++ b/pymc/sampling/external/nutpie.py @@ -0,0 +1,149 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import warnings + +from arviz import InferenceData, dict_to_dataset +from pytensor.scalar import discrete_dtypes + +from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations +from pymc.sampling.external.base import ExternalSampler +from pymc.stats.convergence import log_warnings, run_convergence_checks +from pymc.util import _get_seeds_per_chain + + +class Nutpie(ExternalSampler): + def __init__( + self, + vars=None, + model=None, + backend="numba", + gradient_backend="pytensor", + compile_kwargs=None, + sample_kwargs=None, + ): + super().__init__(vars, model) + if any(var.dtype in discrete_dtypes for var in self.vars): + raise ValueError("Nutpie can only sample continuous variables") + self.backend = backend + self.gradient_backend = gradient_backend + self.compile_kwargs = compile_kwargs or {} + self.sample_kwargs = sample_kwargs or {} + + def sample( + self, + *, + tune, + draws, + chains, + initvals, + random_seed, + progressbar, + var_names, + idata_kwargs, + compute_convergence_checks, + **kwargs, + ): + try: + import nutpie + except ImportError as err: + raise ImportError( + "nutpie not found. Install it with conda install -c conda-forge nutpie" + ) from err + + from nutpie.sample import _BackgroundSampler + + if initvals: + warnings.warn( + "initvals are currently ignored by the nutpie sampler.", + UserWarning, + ) + if idata_kwargs: + warnings.warn( + "idata_kwargs are currently ignored by the nutpie sampler.", + UserWarning, + ) + + compiled_model = nutpie.compile_pymc_model( + self.model, + var_names=var_names, + backend=self.backend, + gradient_backend=self.gradient_backend, + **self.compile_kwargs, + ) + + result = nutpie.sample( + compiled_model, + tune=tune, + draws=draws, + chains=chains, + seed=_get_seeds_per_chain(random_seed, 1)[0], + progress_bar=progressbar, + **self.sample_kwargs, + **kwargs, + ) + if isinstance(result, _BackgroundSampler): + # Wrap _BackgroundSampler so that when sampling is finished we run post_process_sampler + class NutpieBackgroundSamplerWrapper(_BackgroundSampler): + def __init__(self, *args, pymc_model, compute_convergence_checks, **kwargs): + self.pymc_model = pymc_model + self.compute_convergence_checks = compute_convergence_checks + super().__init__(*args, **kwargs, return_raw_trace=False) + + def _extract(self, *args, **kwargs): + idata = super()._extract(*args, **kwargs) + return Nutpie._post_process_sample( + model=self.pymc_model, + idata=idata, + compute_convergence_checks=self.compute_convergence_checks, + ) + + # non-blocked sampling + return NutpieBackgroundSamplerWrapper( + result, + pymc_model=self.model, + compute_convergence_checks=compute_convergence_checks, + ) + else: + return self._post_process_sample(self.model, result, compute_convergence_checks) + + @staticmethod + def _post_process_sample( + model, idata: InferenceData, compute_convergence_checks + ) -> InferenceData: + # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed + # gather observed and constant data as nutpie.sample() has no access to the PyMC model + if compute_convergence_checks: + log_warnings(run_convergence_checks(idata, model)) + + coords, dims = coords_and_dims_for_inferencedata(model) + constant_data = dict_to_dataset( + find_constants(model), + library=idata.attrs.get("library", None), + coords=coords, + dims=dims, + default_dims=[], + ) + observed_data = dict_to_dataset( + find_observations(model), + library=idata.attrs.get("library", None), + coords=coords, + dims=dims, + default_dims=[], + ) + idata.add_groups( + {"constant_data": constant_data, "observed_data": observed_data}, + coords=coords, + dims=dims, + ) + return idata diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 6c0119282..b1fc4c4aa 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -35,7 +35,6 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import clone_replace from pytensor.link.jax.dispatch import jax_funcify -from pytensor.raise_op import Assert from pytensor.tensor import TensorVariable from pytensor.tensor.random.type import RandomType @@ -47,7 +46,6 @@ ) from pymc.distributions.multivariate import PosDefMatrix from pymc.initial_point import StartDict -from pymc.logprob.utils import CheckParameterValue from pymc.sampling.mcmc import _init_jitter from pymc.stats.convergence import log_warnings, run_convergence_checks from pymc.util import ( @@ -71,19 +69,6 @@ ) -@jax_funcify.register(Assert) -@jax_funcify.register(CheckParameterValue) -def jax_funcify_Assert(op, **kwargs): - # Jax does not allow assert whose values aren't known during JIT compilation - # within it's JIT-ed code. Hence we need to make a simple pass through - # version of the Assert Op. - # https://github.com/google/jax/issues/2273#issuecomment-589098722 - def assert_fn(value, *inps): - return value - - return assert_fn - - @jax_funcify.register(PosDefMatrix) def jax_funcify_PosDefMatrix(op, **kwargs): def posdefmatrix_fn(value, *inps): @@ -520,8 +505,6 @@ def sample_jax_nuts( keep_untransformed: bool = False, chain_method: Literal["parallel", "vectorized"] = "parallel", postprocessing_backend: Literal["cpu", "gpu"] | None = None, - postprocessing_vectorize: Literal["vmap", "scan"] | None = None, - postprocessing_chunks=None, idata_kwargs: dict | None = None, compute_convergence_checks: bool = True, nuts_sampler: Literal["numpyro", "blackjax"], @@ -593,25 +576,6 @@ def sample_jax_nuts( with their respective sample stats and pointwise log likeihood values (unless skipped with ``idata_kwargs``). """ - if postprocessing_chunks is not None: - import warnings - - warnings.warn( - "postprocessing_chunks is deprecated due to being unstable, " - "using postprocessing_vectorize='scan' instead", - DeprecationWarning, - ) - - if postprocessing_vectorize is not None: - import warnings - - warnings.warn( - 'postprocessing_vectorize={"scan", "vmap"} will be removed in a future release.', - FutureWarning, - ) - else: - postprocessing_vectorize = "vmap" - model = modelcontext(model) if var_names is not None: @@ -674,7 +638,6 @@ def sample_jax_nuts( model, raw_mcmc_samples, backend=postprocessing_backend, - postprocessing_vectorize=postprocessing_vectorize, ) else: log_likelihood = None @@ -684,7 +647,6 @@ def sample_jax_nuts( jax_fn, raw_mcmc_samples, postprocessing_backend=postprocessing_backend, - postprocessing_vectorize=postprocessing_vectorize, donate_samples=True, ) del raw_mcmc_samples @@ -704,8 +666,8 @@ def sample_jax_nuts( dims.update(idata_kwargs.pop("dims")) # Use 'partial' to set default arguments before passing 'idata_kwargs' - to_trace = partial( - az.from_dict, + idata = az.from_dict( + posterior=mcmc_samples, log_likelihood=log_likelihood, observed_data=find_observations(model), constant_data=find_constants(model), @@ -714,14 +676,13 @@ def sample_jax_nuts( dims=dims, attrs=make_attrs(attrs, library=library), posterior_attrs=make_attrs(attrs, library=library), + **idata_kwargs, ) - az_trace = to_trace(posterior=mcmc_samples, **idata_kwargs) if compute_convergence_checks: - warns = run_convergence_checks(az_trace, model) - log_warnings(warns) + log_warnings(run_convergence_checks(idata, model)) - return az_trace + return idata sample_numpyro_nuts = partial(sample_jax_nuts, nuts_sampler="numpyro") diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 542797caa..dca1ca69b 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -33,8 +33,7 @@ import numpy as np import pytensor.gradient as tg -from arviz import InferenceData, dict_to_dataset -from arviz.data.base import make_attrs +from arviz import InferenceData from pytensor.graph.basic import Variable from rich.theme import Theme from threadpoolctl import threadpool_limits @@ -43,11 +42,6 @@ import pymc as pm from pymc.backends import RunType, TraceOrBackend, init_traces -from pymc.backends.arviz import ( - coords_and_dims_for_inferencedata, - find_constants, - find_observations, -) from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains from pymc.backends.zarr import ZarrChain, ZarrTrace from pymc.blocking import DictToArrayBijection @@ -55,6 +49,7 @@ from pymc.initial_point import PointType, StartDict, make_initial_point_fns_per_chain from pymc.model import Model, modelcontext from pymc.progress_bar import ProgressBarManager, ProgressBarType, default_progress_theme +from pymc.sampling.external.base import ExternalSampler from pymc.sampling.parallel import Draw, _cpu_count from pymc.sampling.population import _sample_population from pymc.stats.convergence import ( @@ -238,17 +233,18 @@ def assign_step_methods( ) assigned_vars = assigned_vars.union(set(step.vars)) - # Use competence classmethods to select step methods for remaining - # variables + # Use competence classmethods to select step methods for remaining variables methods_list: list[type[BlockedStep]] = list(methods or pm.STEP_METHODS) selected_steps: dict[type[BlockedStep], list] = {} - model_logp = model.logp() + model_logp = None for var in model.value_vars: if var not in assigned_vars: # determine if a gradient can be computed has_gradient = getattr(var, "dtype") not in discrete_types if has_gradient: + if model_logp is None: + model_logp = model.logp() try: tg.grad(model_logp, var) # type: ignore[arg-type] except (NotImplementedError, tg.NullTypeGradError): @@ -258,9 +254,7 @@ def assign_step_methods( rv_var = model.values_to_rvs[var] selected = max( methods_list, - key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence( # type: ignore[misc] - var, has_gradient - ), + key=lambda method: method._competence(rv_var, has_gradient), ) selected_steps.setdefault(selected, []).append(var) @@ -290,127 +284,6 @@ def all_continuous(vars): return True -def _sample_external_nuts( - sampler: Literal["nutpie", "numpyro", "blackjax"], - draws: int, - tune: int, - chains: int, - target_accept: float, - random_seed: RandomState | None, - initvals: StartDict | Sequence[StartDict | None] | None, - model: Model, - var_names: Sequence[str] | None, - progressbar: bool, - idata_kwargs: dict | None, - compute_convergence_checks: bool, - nuts_sampler_kwargs: dict | None, - **kwargs, -): - if nuts_sampler_kwargs is None: - nuts_sampler_kwargs = {} - - if sampler == "nutpie": - try: - import nutpie - except ImportError as err: - raise ImportError( - "nutpie not found. Install it with conda install -c conda-forge nutpie" - ) from err - - if initvals is not None: - warnings.warn( - "`initvals` are currently not passed to nutpie sampler. " - "Use `init_mean` kwarg following nutpie specification instead.", - UserWarning, - ) - - if idata_kwargs is not None: - warnings.warn( - "`idata_kwargs` are currently ignored by the nutpie sampler", - UserWarning, - ) - - compile_kwargs = {} - nuts_sampler_kwargs = nuts_sampler_kwargs.copy() - for kwarg in ("backend", "gradient_backend"): - if kwarg in nuts_sampler_kwargs: - compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg) - compiled_model = nutpie.compile_pymc_model( - model, - var_names=var_names, - **compile_kwargs, - ) - t_start = time.time() - idata = nutpie.sample( - compiled_model, - draws=draws, - tune=tune, - chains=chains, - target_accept=target_accept, - seed=_get_seeds_per_chain(random_seed, 1)[0], - progress_bar=progressbar, - **nuts_sampler_kwargs, - ) - t_sample = time.time() - t_start - # Temporary work-around. Revert once https://github.com/pymc-devs/nutpie/issues/74 is fixed - # gather observed and constant data as nutpie.sample() has no access to the PyMC model - coords, dims = coords_and_dims_for_inferencedata(model) - constant_data = dict_to_dataset( - find_constants(model), - library=pm, - coords=coords, - dims=dims, - default_dims=[], - ) - observed_data = dict_to_dataset( - find_observations(model), - library=pm, - coords=coords, - dims=dims, - default_dims=[], - ) - attrs = make_attrs( - { - "sampling_time": t_sample, - "tuning_steps": tune, - }, - library=nutpie, - ) - for k, v in attrs.items(): - idata.posterior.attrs[k] = v - idata.add_groups( - {"constant_data": constant_data, "observed_data": observed_data}, - coords=coords, - dims=dims, - ) - return idata - - elif sampler in ("numpyro", "blackjax"): - import pymc.sampling.jax as pymc_jax - - idata = pymc_jax.sample_jax_nuts( - draws=draws, - tune=tune, - chains=chains, - target_accept=target_accept, - random_seed=random_seed, - initvals=initvals, - model=model, - var_names=var_names, - progressbar=progressbar, - nuts_sampler=sampler, - idata_kwargs=idata_kwargs, - compute_convergence_checks=compute_convergence_checks, - **nuts_sampler_kwargs, - ) - return idata - - else: - raise ValueError( - f"Sampler {sampler} not found. Choose one of ['nutpie', 'numpyro', 'blackjax', 'pymc']." - ) - - @overload def sample( draws: int = 1000, @@ -477,9 +350,9 @@ def sample( def sample( - draws: int = 1000, + draws: int | None = None, *, - tune: int = 1000, + tune: int | None = None, chains: int | None = None, cores: int | None = None, random_seed: RandomState = None, @@ -487,7 +360,7 @@ def sample( progressbar_theme: Theme | None = None, step=None, var_names: Sequence[str] | None = None, - nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc", + nuts_sampler: None | Literal["pymc", "nutpie", "numpyro", "blackjax"] = None, initvals: StartDict | Sequence[StartDict | None] | None = None, init: str = "auto", jitter_max_retries: int = 10, @@ -688,17 +561,6 @@ def sample( mean sd hdi_3% hdi_97% p 0.609 0.047 0.528 0.699 """ - if "start" in kwargs: - if initvals is not None: - raise ValueError("Passing both `start` and `initvals` is not supported.") - warnings.warn( - "The `start` kwarg was renamed to `initvals` and can now do more. Please check the docstring.", - FutureWarning, - stacklevel=2, - ) - initvals = kwargs.pop("start") - if nuts_sampler_kwargs is None: - nuts_sampler_kwargs = {} if "target_accept" in kwargs: if "nuts" in kwargs and "target_accept" in kwargs["nuts"]: raise ValueError( @@ -708,12 +570,6 @@ def sample( kwargs["nuts"]["target_accept"] = kwargs.pop("target_accept") else: kwargs["nuts"] = {"target_accept": kwargs.pop("target_accept")} - if isinstance(trace, list): - raise ValueError("Please use `var_names` keyword argument for partial traces.") - - # progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and - # ADVI initialization expect just a bool. - progress_bool = bool(progressbar) model = modelcontext(model) if not model.free_RVs: @@ -749,75 +605,69 @@ def joined_blas_limiter(): f"Invalid argument `blas_cores`, must be int, 'auto' or None: {blas_cores}" ) - if random_seed == -1: - raise ValueError( - "Setting random_seed = -1 is not allowed. Pass `None` to not specify a seed." - ) - elif isinstance(random_seed, tuple | list): - warnings.warn( - "A list or tuple of random_seed no longer specifies the specific random_seed of each chain. " - "Use a single seed instead.", - UserWarning, - ) rngs = get_random_generator(random_seed).spawn(chains) random_seed_list = [rng.integers(2**30) for rng in rngs] - if not discard_tuned_samples and not return_inferencedata and not isinstance(trace, ZarrTrace): + if step is None and nuts_sampler not in (None, "pymc"): + # Temporarily instantiate external samplers for user, for backwards-compat warnings.warn( - "Tuning samples will be included in the returned `MultiTrace` object, which can lead to" - " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n" - "`pm.sample(..., return_inferencedata=True)`", - UserWarning, - stacklevel=2, + f"Setting `pm.sample(nuts_sampler='{nuts_sampler}, nuts_sampler_kwargs=...)'` is deprecated.\n" + f"Use `pm.sample(step=pm.external.{nuts_sampler.capitalize()}(**nuts_sampler_kwargs))` instead", + FutureWarning, ) + from pymc.sampling import external - # small trace warning - if draws == 0: - msg = "Tuning was enabled throughout the whole trace." - _log.warning(msg) - elif draws < 100: - msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate." - _log.warning(msg) - - provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS) - exclusive_nuts = ( - # User provided an instantiated NUTS step, and nothing else is needed - (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS)) - or - # Only automatically selected NUTS step is needed - ( - not provided_steps - and len(selected_steps) == 1 - and issubclass(next(iter(selected_steps)), NUTS) + step = getattr(external, nuts_sampler.capitalize())( + model=model, + **(nuts_sampler_kwargs or {}), ) - ) + nuts_sampler_kwargs = None - if nuts_sampler != "pymc": - if not exclusive_nuts: + if isinstance(step, list | tuple) and len(step) == 1: + [step] = step + + if isinstance(step, ExternalSampler): + if step.model is not model: + raise ValueError("External step model does not match model detected by sample") + if nuts_sampler_kwargs: raise ValueError( - "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability." + f"{nuts_sampler_kwargs=} should be passed when constructing external sampler" ) - + if "nuts" in kwargs: + kwargs.update(kwargs["nuts"].pop()) with joined_blas_limiter(): - return _sample_external_nuts( - sampler=nuts_sampler, - draws=draws, + return step.sample( tune=tune, + draws=draws, chains=chains, - target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8), - random_seed=random_seed, initvals=initvals, - model=model, + random_seed=random_seed, + progressbar=bool(progressbar), var_names=var_names, - progressbar=progress_bool, idata_kwargs=idata_kwargs, compute_convergence_checks=compute_convergence_checks, - nuts_sampler_kwargs=nuts_sampler_kwargs, **kwargs, ) - if exclusive_nuts and not provided_steps: - # Special path for NUTS initialization + # PyMC defaults + if tune is None: + tune = 1000 + if draws is None: + draws = 1000 + elif 0 < draws < 100: + _log.warning( + f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate." + ) + + provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS) + + if ( + not provided_steps + and len(selected_steps) == 1 + and issubclass(next(iter(selected_steps)), NUTS) + ): + # Special path for automatically NUTS initialization + # When no step sampler is provided and only NUTS is needed if "nuts" in kwargs: nuts_kwargs = kwargs.pop("nuts") [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] @@ -828,7 +678,7 @@ def joined_blas_limiter(): n_init=n_init, model=model, random_seed=random_seed_list, - progressbar=progress_bool, + progressbar=bool(progressbar), jitter_max_retries=jitter_max_retries, tune=tune, initvals=initvals, @@ -907,10 +757,6 @@ def joined_blas_limiter(): ) parallel = cores > 1 and chains > 1 and not has_population_samplers - # At some point it was decided that PyMC should not set a global seed by default, - # unless the user specified a seed. This is a symptom of the fact that PyMC samplers - # are built around global seeding. This branch makes sure we maintain this unspoken - # rule. See https://github.com/pymc-devs/pymc/pull/1395. if parallel: # For parallel sampling we can pass the list of random seeds directly, as # global seeding will only be called inside each process diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 090b76130..c009aaf57 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -332,6 +332,7 @@ def astep(self, q0): return draw, stats +@pytest.mark.filterwarnings("error") class TestSampleReturn: """Tests related to kwargs that parametrize how `pm.sample` results are returned.""" @@ -340,18 +341,17 @@ def test_sample_return_lengths(self): pm.Normal("n") # Get a MultiTrace with warmup - with pytest.warns(UserWarning, match="will be included"): - mtrace = pm.sample( - draws=100, - tune=50, - cores=1, - chains=3, - step=pm.Metropolis(), - return_inferencedata=False, - discard_tuned_samples=False, - ) - assert isinstance(mtrace, pm.backends.base.MultiTrace) - assert len(mtrace) == 150 + mtrace = pm.sample( + draws=100, + tune=50, + cores=1, + chains=3, + step=pm.Metropolis(), + return_inferencedata=False, + discard_tuned_samples=False, + ) + assert isinstance(mtrace, pm.backends.base.MultiTrace) + assert len(mtrace) == 150 # Now instead of running more MCMCs, we'll test the other return # options using the basetraces inside the MultiTrace. @@ -518,14 +518,6 @@ def test_blas_cores(): pm.sample(blas_cores=2, tune=10, cores=2, draws=10) -def test_partial_trace_with_trace_unsupported(): - with pm.Model() as model: - a = pm.Normal("a", mu=0, sigma=1) - b = pm.Normal("b", mu=0, sigma=1) - with pytest.raises(ValueError, match="var_names"): - pm.sample(trace=[a]) - - class TestNamedSampling: def test_shared_named(self): G_var = shared(value=np.atleast_2d(1.0), shape=(1, None), name="G")