1+ from dataclasses import dataclass
2+ import dataclasses
3+ import functools
14from math import prod
5+ from typing import Dict , List
26
37import aesara
48import aesara .tensor as at
9+ from numpy .typing import NDArray
510import pymc as pm
611import numpy as np
712import numba
813from aeppl .logprob import CheckParameterValue
914import aesara .link .numba .dispatch
15+ from numba import literal_unroll
16+ from numba .cpython .unsafe .tuple import alloca_once , tuple_setitem
17+ import numba .core .ccallback
1018
1119from .sample import CompiledModel
20+ from . import lib
1221
# Provide a numba implementation for CheckParameterValue,
# which doesn't exist in aesara
@aesara.link.numba.dispatch.basic.numba_funcify.register(CheckParameterValue)
def numba_functify_CheckParameterValue(op, **kwargs):
    # Build the message once at funcify time so the jitted closure only
    # captures a plain string, not the op object.
    msg = f"Invalid parameter value {str(op)}"

    @aesara.link.numba.dispatch.basic.numba_njit
    def check(value, *conditions):
        # literal_unroll lets numba iterate over a heterogeneous tuple of
        # condition values inside nopython code.
        for cond in literal_unroll(conditions):
            if not cond:
                raise ValueError(msg)
        return value

    return check
@numba.extending.intrinsic
def address_as_void_pointer(typingctx, src):
    """returns a void pointer from a given memory address"""
    from numba.core import types, cgutils

    # Signature: integer address in, raw void* out.
    sig = types.voidptr(src)

    def codegen(cgctx, builder, sig, args):
        # Reinterpret the integer argument as a pointer at the LLVM level.
        return builder.inttoptr(args[0], cgutils.voidptr_t)

    return sig, codegen
49+
50+
@dataclass(frozen=True)
class CompiledPyMCModel(CompiledModel):
    """A compiled PyMC model together with its numba logp cfunc and the
    storage backing the aesara shared variables."""

    # numba-compiled C callback evaluating logp and its gradient
    compiled_logp_func: numba.core.ccallback.CFunc
    # name -> read-only numpy array backing each shared variable
    shared_data: Dict[str, NDArray]
    # 0-d record array holding pointer/size/shape of every shared array;
    # handed to the logp cfunc as opaque user data
    user_data: NDArray

    def with_data(self, **updates):
        """Return a copy of the compiled model with shared variables replaced.

        Raises KeyError for an unknown variable name and ValueError when the
        replacement's rank differs from the current value's rank.
        """
        shared_data = self.shared_data.copy()
        user_data = self.user_data.copy()
        for name, new_val in updates.items():
            if name not in shared_data:
                raise KeyError(f"Unknown shared variable: {name}")
            old_val = shared_data[name]
            # Copy so later caller-side mutation can't alias the model, and
            # freeze because the logp cfunc reads it through a raw pointer.
            new_val = np.asarray(new_val).copy()
            new_val.flags.writeable = False
            if old_val.ndim != new_val.ndim:
                raise ValueError(
                    f"Shared variable {name} must have rank {old_val.ndim}"
                )
            shared_data[name] = new_val
        # Refresh the pointers/sizes/shapes recorded for the cfunc.
        user_data = update_user_data(user_data, shared_data)

        logp_func_maker = self.logp_func_maker.with_arg(user_data.ctypes.data)
        expand_draw_fn = functools.partial(
            self.expand_draw_fn.func, shared_data=shared_data
        )
        return dataclasses.replace(
            self,
            shared_data=shared_data,
            user_data=user_data,
            logp_func_maker=logp_func_maker,
            expand_draw_fn=expand_draw_fn,
        )
84+
85+
def update_user_data(user_data, user_data_storage):
    """Record each shared array's data pointer, size and shape in *user_data*.

    *user_data* is a 0-d structured array (see ``make_user_data``); the
    update happens in place and the refreshed array is returned.
    """
    record = user_data[()]
    for name, values in user_data_storage.items():
        record["shared"]["data"][name] = values.ctypes.data
        record["shared"]["size"][name] = values.size
        record["shared"]["shape"][name] = values.shape
    return np.asarray(record)
93+
94+
def make_user_data(func, shared_data):
    """Allocate the 0-d record array describing *func*'s shared variables.

    The record has one ``data`` (pointer), ``size`` and ``shape`` slot per
    shared variable; it is filled from *shared_data* before being returned.
    """
    shared_vars = func.get_shared()
    pointer_fields = [(var.name, np.uintp) for var in shared_vars]
    size_fields = [(var.name, np.uintp) for var in shared_vars]
    shape_fields = [(var.name, np.uint, (var.ndim,)) for var in shared_vars]
    record_dtype = np.dtype(
        [
            (
                "shared",
                [
                    ("data", pointer_fields),
                    ("size", size_fields),
                    ("shape", shape_fields),
                ],
            )
        ],
        align=True,
    )
    user_data = np.zeros((), dtype=record_dtype)
    update_user_data(user_data, shared_data)
    return user_data
34116
35117
def compile_pymc_model(model, **kwargs):
    """Compile necessary functions for sampling a pymc model."""

    n_dim, logp_fn_at, logp_fn, expand_fn, shared_expand, shape_info = _make_functions(
        model
    )

    # Snapshot current shared-variable values; freeze them because the logp
    # cfunc will read them through raw pointers.
    shared_data = {val.name: val.get_value().copy() for val in logp_fn_at.get_shared()}
    for val in shared_data.values():
        val.flags.writeable = False

    shared_logp = [var.name for var in logp_fn_at.get_shared()]

    # Record array carrying pointer/size/shape of every shared array.
    user_data = make_user_data(logp_fn_at, shared_data)

    logp_numba_raw, c_sig = _make_c_logp_func(
        n_dim, logp_fn, user_data, shared_logp, shared_data
    )
    logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)

    def expand_draw(x, shared_data):
        # Feed the shared arrays needed by the expand function as kwargs.
        return expand_fn(x, **{name: shared_data[name] for name in shared_expand})[0]

    def make_logp_pyfn(data_ptr):
        # (function address, user-data pointer, keep-alive) triple for lib.
        return logp_numba.address, data_ptr, None

    logp_func_maker = lib.PtrLogpFuncMaker(
        make_logp_pyfn,
        user_data.ctypes.data,
        n_dim,
        logp_numba,
    )

    expand_draw_fn = functools.partial(expand_draw, shared_data=shared_data)

    return CompiledPyMCModel(
        n_dim=n_dim,
        dims=model.RV_dims,
        coords=model.coords,
        shape_info=shape_info,
        logp_func_maker=logp_func_maker,
        expand_draw_fn=expand_draw_fn,
        compiled_logp_func=logp_numba,
        shared_data=shared_data,
        user_data=user_data,
    )
60164
61165
@@ -68,7 +172,7 @@ def _compute_shapes(model):
68172 if var not in model .observed_RVs + model .potentials
69173 }
70174
71- shape_func = aesara .function (
175+ shape_func = aesara .compile . function . function (
72176 inputs = [],
73177 outputs = [var .shape for var in trace_vars .values ()],
74178 givens = (
@@ -94,7 +198,7 @@ def _make_functions(model):
94198 value_vars = [model .rvs_to_values [var ] for var in model .free_RVs ]
95199
96200 logp = model .logp ()
97- grads = at .grad (logp , value_vars )
201+ grads = aesara . gradient .grad (logp , value_vars )
98202 grad = at .concatenate ([grad .ravel () for grad in grads ])
99203
100204 count = 0
@@ -118,11 +222,11 @@ def _make_functions(model):
118222 num_free_vars = count
119223
120224 # We should avoid compiling the function, and optimize only
121- func = aesara .function (
225+ logp_fn_at = aesara . compile . function .function (
122226 (joined ,), (logp , grad ), givens = symbolic_sliced , mode = aesara .compile .NUMBA
123227 )
124228
125- logp_func = func .vm .jit_fn . py_func
229+ logp_fn = logp_fn_at .vm .jit_fn
126230
127231 # Make function that computes remaining variables for the trace
128232 trace_vars = {
@@ -150,22 +254,124 @@ def _make_functions(model):
150254 count += length
151255
152256 allvars = at .concatenate ([joined , * [var .ravel () for var in remaining_rvs ]])
153- func = aesara .function (
257+ expand_fn_at = aesara . compile . function .function (
154258 (joined ,), (allvars ,), givens = symbolic_sliced , mode = aesara .compile .NUMBA
155259 )
156- func = func .vm .jit_fn .py_func
157- expanding_function = numba .njit (func , fastmath = True , error_model = "numpy" )
158- expanding_function (np .zeros (num_free_vars ))
260+ expand_fn = expand_fn_at .vm .jit_fn
261+ # expand_fn = numba.njit(expand_fn, fastmath=True, error_model="numpy")
262+ # Trigger a compile
263+ expand_fn (np .zeros (num_free_vars ))
159264
160265 return (
161266 num_free_vars ,
162- logp_func ,
163- expanding_function ,
267+ logp_fn_at ,
268+ logp_fn ,
269+ expand_fn ,
270+ [var .name for var in expand_fn_at .get_shared ()],
164271 (all_names , all_slices , all_shapes ),
165272 )
166273
167274
168- def _make_c_logp_func (N , logp_func ):
def make_extraction_fn(inner, shared_data, shared_vars, record_dtype):
    """Wrap *inner* so shared arrays are rebuilt from the user-data record.

    Returns a jitted ``(x, user_data_ptr)`` function that decodes the
    pointers, sizes and shapes stored in the record and calls
    ``inner(x, *shared_arrays)``. With no shared variables it simply
    forwards to ``inner(x)``.
    """
    if not shared_vars:

        @numba.njit(inline="always")
        def extract_shared(x, user_data_):
            return inner(x)

        return extract_shared

    # Static per-variable metadata captured at compile time:
    # (name, rank, initial shape, dtype).
    shared_metadata = tuple(
        [
            name,
            len(shared_data[name].shape),
            shared_data[name].shape,
            np.dtype(shared_data[name].dtype),
        ]
        for name in shared_vars
    )

    names = shared_vars
    indices = tuple(range(len(names)))
    shared_tuple = tuple(shared_data[name] for name in shared_vars)

    @numba.extending.intrinsic
    def tuple_setitem_literal(typingctx, tup, idx, val):
        """Return a copy of the tuple with item at *idx* replaced with *val*.
        """
        # Only applicable when the index is a compile-time literal.
        if not isinstance(idx, numba.types.IntegerLiteral):
            return

        idx_val = idx.literal_value
        assert idx_val >= 0
        assert idx_val < len(tup)

        import llvmlite

        def codegen(context, builder, signature, args):
            tup, idx, val = args
            # Spill the tuple to the stack so we can write one element.
            stack = alloca_once(builder, tup.type)
            builder.store(tup, stack)
            # Unsafe load on unchecked bounds. Poison value maybe returned.
            tuple_idx = llvmlite.ir.IntType(32)(idx_val)
            offptr = builder.gep(stack, [idx.type(0), tuple_idx], inbounds=True)
            builder.store(val, offptr)
            return builder.load(stack)

        sig = tup(tup, idx, tup[idx_val])
        return sig, codegen

    def extract_array(user_data, index):
        # Stub; the jitted implementation is supplied by the overload below.
        pass

    @numba.extending.overload(extract_array, inline="always")
    def impl_extract_array(user_data, index):
        if not isinstance(index, numba.types.Literal):
            return

        index = index.literal_value

        name, ndim, base_shape, dtype = shared_metadata[index]

        ndim_range = tuple(range(ndim))

        def impl(user_data, index):
            # Rebuild the flat array from the raw pointer in the record.
            data_ptr = address_as_void_pointer(user_data["data"][name][()])
            data = numba.carray(data_ptr, int(user_data["size"][name][()]), dtype)

            shape = user_data["shape"][name]

            assert len(shape) == len(base_shape)

            shape_ = base_shape

            # For some reason I get typing errors without this if condition
            if ndim > 0:
                for i in range(ndim):
                    shape_ = tuple_setitem(shape_, i, shape[i])

            return data.reshape(shape_)

        return impl

    @numba.njit(inline="always")
    def extract_shared(x, user_data_):
        user_data = numba.carray(user_data_, (), record_dtype)

        # Start from the compile-time tuple and overwrite each slot with the
        # array decoded from the record.
        _shared_tuple = shared_tuple
        for index in literal_unroll(indices):
            dat = extract_array(user_data["shared"], index)
            _shared_tuple = tuple_setitem_literal(_shared_tuple, index, dat)

        return inner(x, *_shared_tuple)

    return extract_shared
369+
370+
371+ def _make_c_logp_func (n_dim , logp_fn , user_data , shared_logp , shared_data ):
372+
373+ extract = make_extraction_fn (logp_fn , shared_data , shared_logp , user_data .dtype )
374+
169375 c_sig = numba .types .int64 (
170376 numba .types .uint64 ,
171377 numba .types .CPointer (numba .types .double ),
@@ -175,21 +381,24 @@ def _make_c_logp_func(N, logp_func):
175381 )
176382
    def logp_numba(dim, x_, out_, logp_, user_data_):
        # Return codes: 0 ok, -1 dimension mismatch, 1 exception while
        # evaluating, 3 non-finite gradient, 4 non-finite logp.
        if dim != n_dim:
            return -1

        try:
            x = numba.carray(x_, (n_dim,))
            out = numba.carray(out_, (n_dim,))
            logp = numba.carray(logp_, ())

            # `extract` rebuilds the shared arrays from user_data_ and
            # forwards to the compiled logp function.
            logp_val, grad = extract(x, user_data_)
            logp[()] = logp_val
            out[...] = grad

            if not np.all(np.isfinite(out)):
                return 3
            if not np.isfinite(logp_val):
                return 4
            # if np.any(out == 0):
            #     return 4
        except Exception:
            return 1
        return 0
0 commit comments