11import dataclasses
22import itertools
3+ import warnings
34from dataclasses import dataclass
45from importlib .util import find_spec
56from math import prod
@@ -200,12 +201,26 @@ def compile_pymc_model(model: "pm.Model", **kwargs) -> CompiledPyMCModel:
200201 logp_numba_raw , c_sig = _make_c_logp_func (
201202 n_dim , logp_fn , user_data , shared_logp , shared_data
202203 )
203- logp_numba = numba .cfunc (c_sig , ** kwargs )(logp_numba_raw )
204+ with warnings .catch_warnings ():
205+ warnings .filterwarnings (
206+ "ignore" ,
207+ message = "Cannot cache compiled function .* as it uses dynamic globals" ,
208+ category = numba .NumbaWarning ,
209+ )
210+
211+ logp_numba = numba .cfunc (c_sig , ** kwargs )(logp_numba_raw )
204212
205213 expand_numba_raw , c_sig_expand = _make_c_expand_func (
206214 n_dim , n_expanded , expand_fn , user_data , shared_expand , shared_data
207215 )
208- expand_numba = numba .cfunc (c_sig_expand , ** kwargs )(expand_numba_raw )
216+ with warnings .catch_warnings ():
217+ warnings .filterwarnings (
218+ "ignore" ,
219+ message = "Cannot cache compiled function .* as it uses dynamic globals" ,
220+ category = numba .NumbaWarning ,
221+ )
222+
223+ expand_numba = numba .cfunc (c_sig_expand , ** kwargs )(expand_numba_raw )
209224
210225 coords = {}
211226 for name , vals in model .coords .items ():
@@ -276,6 +291,7 @@ def _make_functions(model):
276291 import pytensor
277292 import pytensor .link .numba .dispatch
278293 import pytensor .tensor as pt
294+ from pymc .pytensorf import compile_pymc
279295
280296 shapes = _compute_shapes (model )
281297
@@ -340,9 +356,8 @@ def _make_functions(model):
340356 (logp , grad ) = pytensor .graph_replace ([logp , grad ], replacements )
341357
342358 # We should avoid compiling the function, and optimize only
343- logp_fn_pt = pytensor .compile .function .function (
344- (joined ,), (logp , grad ), mode = pytensor .compile .NUMBA
345- )
359+ with model :
360+ logp_fn_pt = compile_pymc ((joined ,), (logp , grad ), mode = pytensor .compile .NUMBA )
346361
347362 logp_fn = logp_fn_pt .vm .jit_fn
348363
@@ -368,12 +383,13 @@ def _make_functions(model):
368383 num_expanded = count
369384
370385 allvars = pt .concatenate ([joined , * [var .ravel () for var in remaining_rvs ]])
371- expand_fn_pt = pytensor .compile .function .function (
372- (joined ,),
373- (allvars ,),
374- givens = list (replacements .items ()),
375- mode = pytensor .compile .NUMBA ,
376- )
386+ with model :
387+ expand_fn_pt = compile_pymc (
388+ (joined ,),
389+ (allvars ,),
390+ givens = list (replacements .items ()),
391+ mode = pytensor .compile .NUMBA ,
392+ )
377393 expand_fn = expand_fn_pt .vm .jit_fn
378394
379395 return (
0 commit comments