Skip to content

Commit 0faf374

Browse files
committed
Support shared vars in pymc
1 parent f2e5bf2 commit 0faf374

File tree

4 files changed

+497
-236
lines changed

4 files changed

+497
-236
lines changed

nutpie/compile_pymc.py

Lines changed: 256 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,165 @@
1+
from dataclasses import dataclass
2+
import dataclasses
3+
import functools
14
from math import prod
5+
from typing import Dict, List
26

37
import aesara
48
import aesara.tensor as at
9+
from numpy.typing import NDArray
510
import pymc as pm
611
import numpy as np
712
import numba
813
from aeppl.logprob import CheckParameterValue
914
import aesara.link.numba.dispatch
15+
from numba import literal_unroll
16+
from numba.cpython.unsafe.tuple import alloca_once, tuple_setitem
17+
import numba.core.ccallback
1018

1119
from .sample import CompiledModel
20+
from . import lib
1221

13-
# Provide a numba implementation for CheckParameterValue, which doesn't exist in aesara
14-
@aesara.link.numba.dispatch.numba_funcify.register(CheckParameterValue)
22+
# Provide a numba implementation for CheckParameterValue,
23+
# which doesn't exist in aesara
24+
@aesara.link.numba.dispatch.basic.numba_funcify.register(CheckParameterValue)
1525
def numba_functify_CheckParameterValue(op, **kwargs):
26+
msg = f"Invalid parameter value {str(op)}"
27+
1628
@aesara.link.numba.dispatch.basic.numba_njit
1729
def check(value, *conditions):
30+
for cond in literal_unroll(conditions):
31+
if not cond:
32+
raise ValueError(msg)
1833
return value
19-
20-
return check
21-
22-
# Overwrite the IncSubtensor op from aesara, see https://github.com/aesara-devs/aesara/issues/603
23-
@aesara.link.numba.dispatch.numba_funcify.register(at.subtensor.AdvancedIncSubtensor1)
24-
def numba_funcify_IncSubtensor(op, node, **kwargs):
2534

26-
def incsubtensor_fn(z, vals, idxs):
27-
z = z.copy()
28-
for idx, val in zip(idxs, vals):
29-
z[idx] += val
30-
return z
35+
return check
3136

32-
return aesara.link.numba.dispatch.basic.numba_njit(incsubtensor_fn)
3337

38+
@numba.extending.intrinsic
39+
def address_as_void_pointer(typingctx, src):
40+
"""returns a void pointer from a given memory address"""
41+
from numba.core import types, cgutils
42+
43+
sig = types.voidptr(src)
44+
45+
def codegen(cgctx, builder, sig, args):
46+
return builder.inttoptr(args[0], cgutils.voidptr_t)
47+
48+
return sig, codegen
49+
50+
51+
@dataclass(frozen=True)
52+
class CompiledPyMCModel(CompiledModel):
53+
compiled_logp_func: numba.core.ccallback.CFunc
54+
shared_data: Dict[str, NDArray]
55+
user_data: NDArray
56+
57+
def with_data(self, **updates):
58+
shared_data = self.shared_data.copy()
59+
user_data = self.user_data.copy()
60+
for name, new_val in updates.items():
61+
if name not in shared_data:
62+
raise KeyError(f"Unknown shared variable: {name}")
63+
old_val = shared_data[name]
64+
new_val = np.asarray(new_val).copy()
65+
new_val.flags.writeable = False
66+
if old_val.ndim != new_val.ndim:
67+
raise ValueError(
68+
f"Shared variable {name} must have rank {old_val.ndim}"
69+
)
70+
shared_data[name] = new_val
71+
user_data = update_user_data(user_data, shared_data)
72+
73+
logp_func_maker = self.logp_func_maker.with_arg(user_data.ctypes.data)
74+
expand_draw_fn = functools.partial(
75+
self.expand_draw_fn.func, shared_data=shared_data
76+
)
77+
return dataclasses.replace(
78+
self,
79+
shared_data=shared_data,
80+
user_data=user_data,
81+
logp_func_maker=logp_func_maker,
82+
expand_draw_fn=expand_draw_fn,
83+
)
84+
85+
86+
def update_user_data(user_data, user_data_storage):
87+
user_data = user_data[()]
88+
for name, val in user_data_storage.items():
89+
user_data["shared"]["data"][name] = val.ctypes.data
90+
user_data["shared"]["size"][name] = val.size
91+
user_data["shared"]["shape"][name] = val.shape
92+
return np.asarray(user_data)
93+
94+
95+
def make_user_data(func, shared_data):
96+
shared_vars = func.get_shared()
97+
record_dtype = np.dtype(
98+
[
99+
(
100+
"shared",
101+
[
102+
("data", [(var.name, np.uintp) for var in shared_vars]),
103+
("size", [(var.name, np.uintp) for var in shared_vars]),
104+
(
105+
"shape",
106+
[(var.name, np.uint, (var.ndim,)) for var in shared_vars],
107+
),
108+
],
109+
)
110+
],
111+
align=True,
112+
)
113+
user_data = np.zeros((), dtype=record_dtype)
114+
update_user_data(user_data, shared_data)
115+
return user_data
34116

35117

36118
def compile_pymc_model(model, **kwargs):
37119
"""Compile necessary functions for sampling a pymc model."""
38-
n_dim, logp_func, expanding_function, shape_info = _make_functions(model)
39-
logp_func = numba.njit(**kwargs)(logp_func)
40-
logp_numba_raw, c_sig = _make_c_logp_func(n_dim, logp_func)
120+
121+
n_dim, logp_fn_at, logp_fn, expand_fn, shared_expand, shape_info = _make_functions(
122+
model
123+
)
124+
125+
shared_data = {val.name: val.get_value().copy() for val in logp_fn_at.get_shared()}
126+
for val in shared_data.values():
127+
val.flags.writeable = False
128+
129+
shared_logp = [var.name for var in logp_fn_at.get_shared()]
130+
131+
user_data = make_user_data(logp_fn_at, shared_data)
132+
133+
logp_numba_raw, c_sig = _make_c_logp_func(
134+
n_dim, logp_fn, user_data, shared_logp, shared_data
135+
)
41136
logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
42137

43-
def expand_draw(x):
44-
return expanding_function(x)[0]
138+
def expand_draw(x, shared_data):
139+
return expand_fn(x, **{name: shared_data[name] for name in shared_expand})[0]
45140

46-
def make_user_data():
47-
return 0
141+
def make_logp_pyfn(data_ptr):
142+
return logp_numba.address, data_ptr, None
48143

49-
return CompiledModel(
50-
model,
144+
logp_func_maker = lib.PtrLogpFuncMaker(
145+
make_logp_pyfn,
146+
user_data.ctypes.data,
51147
n_dim,
52-
logp_numba.address,
53-
expand_draw,
54-
make_user_data,
55-
shape_info,
56-
model.RV_dims,
57-
model.coords,
58-
(logp_numba, logp_func),
148+
logp_numba,
149+
)
150+
151+
expand_draw_fn = functools.partial(expand_draw, shared_data=shared_data)
152+
153+
return CompiledPyMCModel(
154+
n_dim=n_dim,
155+
dims=model.RV_dims,
156+
coords=model.coords,
157+
shape_info=shape_info,
158+
logp_func_maker=logp_func_maker,
159+
expand_draw_fn=expand_draw_fn,
160+
compiled_logp_func=logp_numba,
161+
shared_data=shared_data,
162+
user_data=user_data,
59163
)
60164

61165

@@ -68,7 +172,7 @@ def _compute_shapes(model):
68172
if var not in model.observed_RVs + model.potentials
69173
}
70174

71-
shape_func = aesara.function(
175+
shape_func = aesara.compile.function.function(
72176
inputs=[],
73177
outputs=[var.shape for var in trace_vars.values()],
74178
givens=(
@@ -94,7 +198,7 @@ def _make_functions(model):
94198
value_vars = [model.rvs_to_values[var] for var in model.free_RVs]
95199

96200
logp = model.logp()
97-
grads = at.grad(logp, value_vars)
201+
grads = aesara.gradient.grad(logp, value_vars)
98202
grad = at.concatenate([grad.ravel() for grad in grads])
99203

100204
count = 0
@@ -118,11 +222,11 @@ def _make_functions(model):
118222
num_free_vars = count
119223

120224
# We should avoid compiling the function, and optimize only
121-
func = aesara.function(
225+
logp_fn_at = aesara.compile.function.function(
122226
(joined,), (logp, grad), givens=symbolic_sliced, mode=aesara.compile.NUMBA
123227
)
124228

125-
logp_func = func.vm.jit_fn.py_func
229+
logp_fn = logp_fn_at.vm.jit_fn
126230

127231
# Make function that computes remaining variables for the trace
128232
trace_vars = {
@@ -150,22 +254,124 @@ def _make_functions(model):
150254
count += length
151255

152256
allvars = at.concatenate([joined, *[var.ravel() for var in remaining_rvs]])
153-
func = aesara.function(
257+
expand_fn_at = aesara.compile.function.function(
154258
(joined,), (allvars,), givens=symbolic_sliced, mode=aesara.compile.NUMBA
155259
)
156-
func = func.vm.jit_fn.py_func
157-
expanding_function = numba.njit(func, fastmath=True, error_model="numpy")
158-
expanding_function(np.zeros(num_free_vars))
260+
expand_fn = expand_fn_at.vm.jit_fn
261+
# expand_fn = numba.njit(expand_fn, fastmath=True, error_model="numpy")
262+
# Trigger a compile
263+
expand_fn(np.zeros(num_free_vars))
159264

160265
return (
161266
num_free_vars,
162-
logp_func,
163-
expanding_function,
267+
logp_fn_at,
268+
logp_fn,
269+
expand_fn,
270+
[var.name for var in expand_fn_at.get_shared()],
164271
(all_names, all_slices, all_shapes),
165272
)
166273

167274

168-
def _make_c_logp_func(N, logp_func):
275+
def make_extraction_fn(inner, shared_data, shared_vars, record_dtype):
276+
if not shared_vars:
277+
278+
@numba.njit(inline="always")
279+
def extract_shared(x, user_data_):
280+
return inner(x)
281+
282+
return extract_shared
283+
284+
shared_metadata = tuple(
285+
[
286+
name,
287+
len(shared_data[name].shape),
288+
shared_data[name].shape,
289+
np.dtype(shared_data[name].dtype),
290+
]
291+
for name in shared_vars
292+
)
293+
294+
names = shared_vars
295+
indices = tuple(range(len(names)))
296+
shared_tuple = tuple(shared_data[name] for name in shared_vars)
297+
298+
@numba.extending.intrinsic
299+
def tuple_setitem_literal(typingctx, tup, idx, val):
300+
"""Return a copy of the tuple with item at *idx* replaced with *val*.
301+
"""
302+
if not isinstance(idx, numba.types.IntegerLiteral):
303+
return
304+
305+
idx_val = idx.literal_value
306+
assert idx_val >= 0
307+
assert idx_val < len(tup)
308+
309+
import llvmlite
310+
311+
def codegen(context, builder, signature, args):
312+
tup, idx, val = args
313+
stack = alloca_once(builder, tup.type)
314+
builder.store(tup, stack)
315+
# Unsafe load on unchecked bounds. Poison value maybe returned.
316+
tuple_idx = llvmlite.ir.IntType(32)(idx_val)
317+
offptr = builder.gep(stack, [idx.type(0), tuple_idx], inbounds=True)
318+
builder.store(val, offptr)
319+
return builder.load(stack)
320+
321+
sig = tup(tup, idx, tup[idx_val])
322+
return sig, codegen
323+
324+
def extract_array(user_data, index):
325+
pass
326+
327+
@numba.extending.overload(extract_array, inline="always")
328+
def impl_extract_array(user_data, index):
329+
if not isinstance(index, numba.types.Literal):
330+
return
331+
332+
index = index.literal_value
333+
334+
name, ndim, base_shape, dtype = shared_metadata[index]
335+
336+
ndim_range = tuple(range(ndim))
337+
338+
def impl(user_data, index):
339+
data_ptr = address_as_void_pointer(user_data["data"][name][()])
340+
data = numba.carray(data_ptr, int(user_data["size"][name][()]), dtype)
341+
342+
shape = user_data["shape"][name]
343+
344+
assert len(shape) == len(base_shape)
345+
346+
shape_ = base_shape
347+
348+
# For some reason I get typing errors without this if condition
349+
if ndim > 0:
350+
for i in range(ndim):
351+
shape_ = tuple_setitem(shape_, i, shape[i])
352+
353+
return data.reshape(shape_)
354+
355+
return impl
356+
357+
@numba.njit(inline="always")
358+
def extract_shared(x, user_data_):
359+
user_data = numba.carray(user_data_, (), record_dtype)
360+
361+
_shared_tuple = shared_tuple
362+
for index in literal_unroll(indices):
363+
dat = extract_array(user_data["shared"], index)
364+
_shared_tuple = tuple_setitem_literal(_shared_tuple, index, dat)
365+
366+
return inner(x, *_shared_tuple)
367+
368+
return extract_shared
369+
370+
371+
def _make_c_logp_func(n_dim, logp_fn, user_data, shared_logp, shared_data):
372+
373+
extract = make_extraction_fn(logp_fn, shared_data, shared_logp, user_data.dtype)
374+
169375
c_sig = numba.types.int64(
170376
numba.types.uint64,
171377
numba.types.CPointer(numba.types.double),
@@ -175,21 +381,24 @@ def _make_c_logp_func(N, logp_func):
175381
)
176382

177383
def logp_numba(dim, x_, out_, logp_, user_data_):
384+
if dim != n_dim:
385+
return -1
386+
178387
try:
179-
x = numba.carray(x_, (N,))
180-
out = numba.carray(out_, (N,))
388+
x = numba.carray(x_, (n_dim,))
389+
out = numba.carray(out_, (n_dim,))
181390
logp = numba.carray(logp_, ())
182391

183-
logp_val, grad = logp_func(x)
392+
logp_val, grad = extract(x, user_data_)
184393
logp[()] = logp_val
185394
out[...] = grad
186395

187396
if not np.all(np.isfinite(out)):
188-
return 2
189-
if not np.isfinite(logp_val):
190397
return 3
191-
if np.any(out == 0):
398+
if not np.isfinite(logp_val):
192399
return 4
400+
# if np.any(out == 0):
401+
# return 4
193402
except Exception:
194403
return 1
195404
return 0

0 commit comments

Comments (0)