Skip to content

Commit ee0969e

Browse files
committed
Derive logprob for Split operation
1 parent dc7cfee commit ee0969e

File tree

2 files changed

+223
-3
lines changed

2 files changed

+223
-3
lines changed

pymc/logprob/tensor.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@
3939
from pytensor import tensor as pt
4040
from pytensor.graph.fg import FunctionGraph
4141
from pytensor.graph.rewriting.basic import node_rewriter
42+
from pytensor.npy_2_compat import normalize_axis_index
4243
from pytensor.tensor import TensorVariable
43-
from pytensor.tensor.basic import Join, MakeVector
44+
from pytensor.tensor.basic import Join, MakeVector, Split
4445
from pytensor.tensor.elemwise import DimShuffle, Elemwise
4546
from pytensor.tensor.random.op import RandomVariable
4647
from pytensor.tensor.random.rewriting import (
4748
local_dimshuffle_rv_lift,
4849
)
4950

51+
from pymc.exceptions import NotConstantValueError
5052
from pymc.logprob.abstract import (
5153
MeasurableOp,
5254
ValuedRV,
@@ -70,7 +72,7 @@
7072

7173

7274
class MeasurableMakeVector(MeasurableOp, MakeVector):
73-
"""A placeholder used to specify a log-likelihood for a cumsum sub-graph."""
75+
"""A placeholder used to specify a log-likelihood for a make_vector sub-graph."""
7476

7577

7678
@_logprob.register(MeasurableMakeVector)
@@ -183,6 +185,64 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
183185
return [measurable_stack]
184186

185187

188+
class MeasurableSplit(MeasurableOp, Split):
    """Placeholder `Split` used to mark a split sub-graph whose log-likelihood can be derived."""
190+
191+
192+
@node_rewriter([Split])
def find_measurable_splits(fgraph, node) -> list[TensorVariable] | None:
    """Replace a `Split` node by an equivalent `MeasurableSplit` node.

    Returns ``None`` (no replacement) when the node's op is already a
    `MeasurableOp`, or when `filter_measurable_variables` rejects the variable
    being split. Otherwise returns the list of outputs of a `MeasurableSplit`
    applied to the same inputs.
    """
    if isinstance(node.op, MeasurableOp):
        return None

    x, axis, splits = node.inputs
    if not filter_measurable_variables([x]):
        return None

    # `Op.__call__` unwraps a single-output node into a bare variable, but a node
    # rewriter must return a list of replacements. Requesting the list form
    # explicitly keeps the return type correct when `len_splits == 1`.
    return MeasurableSplit(node.op.len_splits)(x, axis, splits, return_list=True)
202+
203+
204+
@_logprob.register(MeasurableSplit)
def logprob_split(op: MeasurableSplit, values, x, axis, splits, **kwargs):
    """Compute the log-likelihood graph for a `MeasurableSplit`.

    Parameters
    ----------
    op
        The `MeasurableSplit` op; ``op.len_splits`` is the number of parts.
    values
        The value variables, one for each output of the split.
    x
        The variable that was split.
    axis
        The axis along which ``x`` was split.
    splits
        The sizes of the parts along ``axis``.

    Returns
    -------
    A sequence with one log-probability term per split value.

    Raises
    ------
    ValueError
        If the number of values differs from ``op.len_splits``.
    NotImplementedError
        If the logp of ``x`` has fewer dimensions than the joined value and
        ``axis`` cannot be constant folded.
    """
    if len(values) != op.len_splits:
        # TODO: Don't rewrite the split in the first place if not all parts are linked to value variables
        # This also allows handling some cases where not all splits are used
        raise ValueError("Split logp requires the number of values to match the number of splits")

    # Reverse the effects of split on the value variable
    join_value = pt.join(axis, *values)

    join_logp = _logprob_helper(x, join_value)

    reduced_dims = join_value.ndim - join_logp.ndim

    if reduced_dims:
        # This happens for multivariate distributions
        try:
            [constant_axis] = constant_fold([axis])
        except NotConstantValueError as err:
            # Chain the original error so the reason the axis is not constant is not lost
            raise NotImplementedError(
                "Cannot split multivariate logp with non-constant axis"
            ) from err

        constant_axis = normalize_axis_index(constant_axis, join_value.ndim)
        if constant_axis >= join_logp.ndim:
            # The split axis is a dimension that was reduced in the logp (multivariate logp),
            # so the logp cannot be split into distinct entries: the mapping between
            # values and densities breaks.
            # Instead, every part gets the joint logp weighted by its relative split size,
            # so that the parts add up to the joint logp.
            split_weights = splits / pt.sum(splits)
            return [join_logp * split_weights[i] for i in range(op.len_splits)]
        else:
            # Otherwise the logp can be split as if the split were over batch dimensions.
            # We just need to be sure to use the positive axis index.
            axis = constant_axis

    return pt.split(
        join_logp,
        splits_size=splits,
        n_splits=op.len_splits,
        axis=axis,
    )
244+
245+
186246
class MeasurableDimShuffle(MeasurableOp, DimShuffle):
187247
"""A placeholder used to specify a log-likelihood for a dimshuffle sub-graph."""
188248

@@ -308,3 +368,10 @@ def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
308368
"basic",
309369
"tensor",
310370
)
371+
372+
# Add the `Split` rewrite to the measurable IR rewrites database
measurable_ir_rewrites_db.register(
    "find_measurable_splits", find_measurable_splits, "basic", "tensor"
)

tests/logprob/test_tensor.py

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
from pytensor import tensor as pt
4242
from pytensor.graph import RewriteDatabaseQuery
43+
from pytensor.tensor.random.type import random_generator_type
4344
from scipy import stats as st
4445

4546
from pymc.logprob.basic import conditional_logp, logp
@@ -352,7 +353,7 @@ def test_measurable_dimshuffle(ds_order, multivariate):
352353
np.testing.assert_array_equal(ref_logp_fn(base_test_value), ds_logp_fn(ds_test_value))
353354

354355

355-
def test_unmeargeable_dimshuffles():
356+
def test_unmeasurable_dimshuffles():
356357
# Test that graphs with DimShuffles that cannot be lifted/merged fail
357358

358359
# Initial support axis is at axis=-1
@@ -372,3 +373,155 @@ def test_unmeargeable_dimshuffles():
372373
# TODO: Check that logp is correct if this type of graphs is ever supported
373374
with pytest.raises(RuntimeError, match="could not be derived"):
374375
conditional_logp({w: w_vv})
376+
377+
378+
class TestMeasurableSplit:
    """Tests for the logp derived for `pt.split` applied to random variables."""

    def test_univariate(self):
        """Each part's logp equals the normal logpdf under the matching slice of parameters."""
        rng = np.random.default_rng(388)
        mu = np.arange(6)[:, None]
        sigma = np.arange(5) + 1

        x = pt.random.normal(mu, sigma, size=(6, 5), name="x")

        # axis=0
        x_parts = pt.split(x, splits_size=[2, 4], n_splits=2, axis=0)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu[:2], sigma),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu[2:], sigma),
        )

        # axis=1
        x_parts = pt.split(x, splits_size=[2, 1, 2], n_splits=3, axis=1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu, sigma[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu, sigma[2:3]),
        )
        np.testing.assert_allclose(
            logp_x3_eval,
            st.norm.logpdf(x_parts_test[2], mu, sigma[3:]),
        )

    def test_multivariate(self):
        """Check splits of a Dirichlet over its batch dimension and over its support dimension."""

        @np.vectorize(signature=("(n),(n)->()"))
        def scipy_dirichlet_logpdf(x, alpha):
            """Compute the logpdf of a Dirichlet distribution using scipy."""
            return st.dirichlet.logpdf(x, alpha)

        # (3, 5) Dirichlet
        rng = np.random.default_rng(426)
        rng_pt = random_generator_type("rng")
        alpha = np.linspace(1, 10, 5) * np.array([1, 10, 100])[:, None]
        x = pt.random.dirichlet(alpha, rng=rng_pt)

        # axis=-2 (i.e., 0, - batch dimension)
        x_parts = pt.split(x, splits_size=[2, 1], n_splits=2, axis=-2)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert logp_parts[0].type.shape == (2,)
        assert logp_parts[1].type.shape == (1,)

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            scipy_dirichlet_logpdf(x_parts_test[0], alpha[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            scipy_dirichlet_logpdf(x_parts_test[1], alpha[2:]),
        )

        # axis=-1 (i.e., 1, - support dimension)
        x_parts = pt.split(x, splits_size=[2, 3], n_splits=2, axis=-1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        assert logp_parts[0].type.shape == (3,)
        assert logp_parts[1].type.shape == (3,)
        logp_fn = pytensor.function(x_parts_vv, logp_parts)

        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        # The joint logp is shared between the parts in proportion to the split sizes (2 and 3)
        np.testing.assert_allclose(logp_x1_eval * 3, logp_x2_eval * 2)
        logp_total = logp_x1_eval + logp_x2_eval
        np.testing.assert_allclose(
            logp_total,
            scipy_dirichlet_logpdf(np.concatenate(x_parts_test, axis=1), alpha),
        )

    @pytest.mark.xfail(
        reason="Rewrite from partial split to split on subtensor not implemented yet"
    )
    def test_not_all_splits_used(self):
        """Only some of the split outputs are given value variables."""
        x = pt.random.normal(mu=pt.arange(6), name="x")
        # Only use the first and last splits
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[::2]
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 2

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        # Draw test values from the random parts themselves: the cloned value
        # variables are ownerless inputs and cannot be evaluated on their own.
        x_parts_test = [x_part.eval() for x_part in x_parts]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], loc=[0, 1]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], loc=[4, 5]),
        )

    def test_not_all_splits_used_core_dim(self):
        """A partial split over the support dimension of a Dirichlet raises a `ValueError`."""
        # TODO: We could support this for univariate/batch dimensions by rewriting as
        # split(x, splits_size=[2, 2, 2], n_splits=3, axis=1)[:2] -> split(x[:-2], splits_size=[2, 2], n_splits=2, axis=1)
        # And letting logp infer the probability of x[:-2]
        x = pt.random.dirichlet(alphas=pt.ones(6), name="x")
        # Only use first two splits
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[:2]
        x_parts_vv = [x_part.clone() for x_part in x_parts]

        with pytest.raises(
            ValueError,
            match="Split logp requires the number of values to match the number of splits",
        ):
            conditional_logp(dict(zip(x_parts, x_parts_vv)))

    @pytest.mark.xfail(reason="Rewrite from subtensor to split not implemented yet")
    def test_subtensor_converted_to_splits(self):
        """Contiguous slices covering ``x`` should get the same logp as an explicit split."""
        rng = np.random.default_rng(388)
        x = pt.random.normal(mu=pt.arange(5), name="x")

        x_parts = [x[:2], x[2:3], x[3:]]
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 3
        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(logp_x1_eval, st.norm.logpdf(x_parts_test[0], loc=[0, 1]))
        np.testing.assert_allclose(logp_x2_eval, st.norm.logpdf(x_parts_test[1], loc=[2]))
        np.testing.assert_allclose(logp_x3_eval, st.norm.logpdf(x_parts_test[2], loc=[3, 4]))

0 commit comments

Comments
 (0)