@@ -292,11 +292,11 @@ def sample(
292
292
chains : int | None = None ,
293
293
cores : int | None = None ,
294
294
random_seed : RandomState = None ,
295
+ step = None ,
296
+ external_sampler : ExternalSampler | None = None ,
295
297
progressbar : bool | ProgressBarType = True ,
296
298
progressbar_theme : Theme | None = default_progress_theme ,
297
- step = None ,
298
299
var_names : Sequence [str ] | None = None ,
299
- nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
300
300
initvals : StartDict | Sequence [StartDict | None ] | None = None ,
301
301
init : str = "auto" ,
302
302
jitter_max_retries : int = 10 ,
@@ -324,11 +324,11 @@ def sample(
324
324
chains : int | None = None ,
325
325
cores : int | None = None ,
326
326
random_seed : RandomState = None ,
327
+ step = None ,
328
+ external_sampler : ExternalSampler | None = None ,
327
329
progressbar : bool | ProgressBarType = True ,
328
330
progressbar_theme : Theme | None = default_progress_theme ,
329
- step = None ,
330
331
var_names : Sequence [str ] | None = None ,
331
- nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
332
332
initvals : StartDict | Sequence [StartDict | None ] | None = None ,
333
333
init : str = "auto" ,
334
334
jitter_max_retries : int = 10 ,
@@ -356,11 +356,11 @@ def sample(
356
356
chains : int | None = None ,
357
357
cores : int | None = None ,
358
358
random_seed : RandomState = None ,
359
+ step = None ,
360
+ external_sampler : ExternalSampler | None = None ,
359
361
progressbar : bool | ProgressBarType = True ,
360
362
progressbar_theme : Theme | None = None ,
361
- step = None ,
362
363
var_names : Sequence [str ] | None = None ,
363
- nuts_sampler : None | Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = None ,
364
364
initvals : StartDict | Sequence [StartDict | None ] | None = None ,
365
365
init : str = "auto" ,
366
366
jitter_max_retries : int = 10 ,
@@ -407,6 +407,12 @@ def sample(
407
407
A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
408
408
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
409
409
easy spawning of new independent random streams that are needed by the step methods.
410
+ step : function or iterable of functions, optional
411
+ A step function or collection of functions. If there are variables without step methods,
412
+ step methods for those variables will be assigned automatically. By default the NUTS step
413
+ method will be used, if appropriate to the model. Not compatible with external_sampler
414
+ external_sampler: ExternalSampler, optional
415
+ An external sampler to sample the whole model. Not compatible with step.
410
416
progressbar: bool or ProgressType, optional
411
417
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
412
418
for one of the following:
@@ -419,10 +425,6 @@ def sample(
419
425
are also displayed.
420
426
421
427
If True, the default is "split+stats" is used.
422
- step : function or iterable of functions
423
- A step function or collection of functions. If there are variables without step methods,
424
- step methods for those variables will be assigned automatically. By default the NUTS step
425
- method will be used, if appropriate to the model.
426
428
var_names : list of str, optional
427
429
Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
428
430
nuts_sampler : str
@@ -608,35 +610,38 @@ def joined_blas_limiter():
608
610
rngs = get_random_generator (random_seed ).spawn (chains )
609
611
random_seed_list = [rng .integers (2 ** 30 ) for rng in rngs ]
610
612
611
- if step is None and nuts_sampler not in (None , "pymc" ):
612
- # Temporarily instantiate external samplers for user, for backwards-compat
613
+ if "nuts_sampler" in kwargs :
614
+ # Transition backwards-compatibility
615
+ nuts_sampler = kwargs .pop ("nuts_sampler" )
613
616
warnings .warn (
614
617
f"Setting `pm.sample(nuts_sampler='{ nuts_sampler } , nuts_sampler_kwargs=...)'` is deprecated.\n "
615
- f"Use `pm.sample(step =pm.external.{ nuts_sampler .capitalize ()} (**nuts_sampler_kwargs))` instead" ,
618
+ f"Use `pm.sample(external_sampler =pm.external.{ nuts_sampler .capitalize ()} (**nuts_sampler_kwargs))` instead" ,
616
619
FutureWarning ,
617
620
)
618
621
from pymc .sampling import external
619
622
620
- step = getattr (external , nuts_sampler .capitalize ())(
623
+ external_sampler = getattr (external , nuts_sampler .capitalize ())(
621
624
model = model ,
622
625
** (nuts_sampler_kwargs or {}),
623
626
)
624
627
nuts_sampler_kwargs = None
625
628
626
- if isinstance (step , list | tuple ) and len (step ) == 1 :
627
- [step ] = step
629
+ if external_sampler is not None :
630
+ if step is not None :
631
+ raise ValueError ("`step` and `external_sampler` cannot be used together" )
628
632
629
- if isinstance (step , ExternalSampler ):
630
- if step .model is not model :
631
- raise ValueError ("External step model does not match model detected by sample" )
633
+ if external_sampler .model is not model :
634
+ raise ValueError (
635
+ "External sampler model does not match model detected by sample function"
636
+ )
632
637
if nuts_sampler_kwargs :
633
638
raise ValueError (
634
639
f"{ nuts_sampler_kwargs = } should be passed when constructing external sampler"
635
640
)
636
641
if "nuts" in kwargs :
637
- kwargs .update (kwargs [ "nuts" ] .pop ())
642
+ kwargs .update (kwargs .pop ("nuts" ))
638
643
with joined_blas_limiter ():
639
- return step .sample (
644
+ return external_sampler .sample (
640
645
tune = tune ,
641
646
draws = draws ,
642
647
chains = chains ,
0 commit comments