Derive logprob for Split operation #7875
@@ -40,6 +40,7 @@
from pytensor import tensor as pt
from pytensor.graph import RewriteDatabaseQuery
from pytensor.tensor.random.type import random_generator_type
from scipy import stats as st

from pymc.logprob.basic import conditional_logp, logp
@@ -352,7 +353,7 @@ def test_measurable_dimshuffle(ds_order, multivariate):
    np.testing.assert_array_equal(ref_logp_fn(base_test_value), ds_logp_fn(ds_test_value))


-def test_unmeargeable_dimshuffles():
+def test_unmeasurable_dimshuffles():
    # Test that graphs with DimShuffles that cannot be lifted/merged fail

    # Initial support axis is at axis=-1
@@ -372,3 +373,155 @@ def test_unmeargeable_dimshuffles():
    # TODO: Check that logp is correct if this type of graphs is ever supported
    with pytest.raises(RuntimeError, match="could not be derived"):
        conditional_logp({w: w_vv})


class TestMeasurableSplit:
    def test_univariate(self):
        rng = np.random.default_rng(388)
        mu = np.arange(6)[:, None]
        sigma = np.arange(5) + 1

        x = pt.random.normal(mu, sigma, size=(6, 5), name="x")

        # axis=0
        x_parts = pt.split(x, splits_size=[2, 4], n_splits=2, axis=0)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu[:2], sigma),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu[2:], sigma),
        )

        # axis=1
        x_parts = pt.split(x, splits_size=[2, 1, 2], n_splits=3, axis=1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu, sigma[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu, sigma[2:3]),
        )
        np.testing.assert_allclose(
            logp_x3_eval,
            st.norm.logpdf(x_parts_test[2], mu, sigma[3:]),
        )

    def test_multivariate(self):
        @np.vectorize(signature="(n),(n)->()")
        def scipy_dirichlet_logpdf(x, alpha):
            """Compute the logpdf of a Dirichlet distribution using scipy."""
            return st.dirichlet.logpdf(x, alpha)

        # (3, 5) Dirichlet
        rng = np.random.default_rng(426)
        rng_pt = random_generator_type("rng")
        alpha = np.linspace(1, 10, 5) * np.array([1, 10, 100])[:, None]
        x = pt.random.dirichlet(alpha, rng=rng_pt)

        # axis=-2 (i.e., 0 - the batch dimension)
        x_parts = pt.split(x, splits_size=[2, 1], n_splits=2, axis=-2)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

[Review comment] Do I understand this correctly: each part is conditioned on the values of all other parts? Thinking about e.g. the MVN case, where if you split the vector and condition each split on the other, you get two new MVN distributions.

[Reply] There's no marginalization going on; you can't evaluate the logp of only one part without providing the remaining ones. The only thing we do is join the values, get the logp, and split it again. We could argue that we don't want to do this for multivariate variables split along the core dimension, since there's no way to split the logp (I did the weighting, but we can revert and raise NotImplementedError).
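To make the reply concrete, here is a minimal sketch of the "join the values, get the logp, split it again" idea in plain numpy/scipy. It is not the PR's actual rewrite, and the helper name naive_split_logp is made up for illustration. The elementwise re-split of the logp only works for splits along a batch dimension; along the core dimension of a multivariate RV there is no per-element logp, which is what the weighting mentioned above addresses.

import numpy as np
from scipy import stats as st

def naive_split_logp(value_parts, mu, sigma, axis=0):
    # Join the split values back into the base variable's shape
    joined = np.concatenate(value_parts, axis=axis)
    # Elementwise logp of the (univariate) base variable
    joined_logp = st.norm.logpdf(joined, mu, sigma)
    # Split the logp again, matching the original split sizes
    sizes = [part.shape[axis] for part in value_parts]
    return np.split(joined_logp, np.cumsum(sizes)[:-1], axis=axis)

rng = np.random.default_rng(0)
mu, sigma = np.arange(6), 1.0
parts = np.split(rng.normal(size=6), [2])  # values for splits of size [2, 4]
logp_parts = naive_split_logp(parts, mu, sigma)
assert [p.shape for p in logp_parts] == [(2,), (4,)]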
        assert logp_parts[0].type.shape == (2,)
        assert logp_parts[1].type.shape == (1,)

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            scipy_dirichlet_logpdf(x_parts_test[0], alpha[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            scipy_dirichlet_logpdf(x_parts_test[1], alpha[2:]),
        )

        # axis=-1 (i.e., 1 - the support dimension)
        x_parts = pt.split(x, splits_size=[2, 3], n_splits=2, axis=-1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        assert logp_parts[0].type.shape == (3,)
        assert logp_parts[1].type.shape == (3,)
        logp_fn = pytensor.function(x_parts_vv, logp_parts)

        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        # The joint logp is weighted by relative split size (2:3 along the
        # support dim), so logp_x1 / 2 == logp_x2 / 3
        np.testing.assert_allclose(logp_x1_eval * 3, logp_x2_eval * 2)
        logp_total = logp_x1_eval + logp_x2_eval
        np.testing.assert_allclose(
            logp_total,
            scipy_dirichlet_logpdf(np.concatenate(x_parts_test, axis=1), alpha),
        )

    @pytest.mark.xfail(
        reason="Rewrite from partial split to split on subtensor not implemented yet"
    )
    def test_not_all_splits_used(self):
        x = pt.random.normal(pt.arange(6), name="x")
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
            ::2
        ]  # Only use the first and last splits
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 2

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [x_part.eval() for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], loc=[0, 1]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], loc=[4, 5]),
        )

    def test_not_all_splits_used_core_dim(self):
        # TODO: We could support this for univariate/batch dimensions by rewriting
        # split(x, splits_size=[2, 2, 2], n_splits=3, axis=1)[:2]
        # as split(x[:-2], splits_size=[2, 2], n_splits=2, axis=1)
        # and letting logp infer the probability of x[:-2]
        x = pt.random.dirichlet(alphas=pt.ones(6), name="x")
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
            :2
        ]  # Only use the first two splits
        x_parts_vv = [x_part.clone() for x_part in x_parts]

        with pytest.raises(
            ValueError,
            match="Split logp requires the number of values to match the number of splits",
        ):
            conditional_logp(dict(zip(x_parts, x_parts_vv)))

    @pytest.mark.xfail(reason="Rewrite from subtensor to split not implemented yet")
    def test_subtensor_converted_to_splits(self):
        rng = np.random.default_rng(388)
        x = pt.random.normal(pt.arange(5), name="x")

        x_parts = [x[:2], x[2:3], x[3:]]
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 3
        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(logp_x1_eval, st.norm.logpdf(x_parts_test[0], loc=[0, 1]))
        np.testing.assert_allclose(logp_x2_eval, st.norm.logpdf(x_parts_test[1], loc=[2]))
        np.testing.assert_allclose(logp_x3_eval, st.norm.logpdf(x_parts_test[2], loc=[3, 4]))

[Review comment] Is this legit?

[Reply] I think so? In MarginalMixture we decided to set the whole logp on the first entry, and zero for the others; I like this approach more.
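For context, a small sketch of the two allocation schemes under discussion, with made-up helper names. Both preserve the total logp; they differ only in how it is attributed to the split parts:

import numpy as np

def weighted_allocation(total_logp, split_sizes):
    # Attribute the joint logp proportionally to each part's size along
    # the split core dimension (the weighting scheme used in this PR)
    n = sum(split_sizes)
    return [total_logp * size / n for size in split_sizes]

def first_entry_allocation(total_logp, split_sizes):
    # Attribute the whole joint logp to the first part, zero to the rest
    # (the convention described above for MarginalMixture)
    return [total_logp] + [np.zeros_like(total_logp)] * (len(split_sizes) - 1)

total_logp = np.array([-3.1, -2.4, -5.0])  # one joint logp per batch row
for scheme in (weighted_allocation, first_entry_allocation):
    parts = scheme(total_logp, split_sizes=[2, 3])
    np.testing.assert_allclose(sum(parts), total_logp)  # total is preserved

# With sizes [2, 3] the weighted scheme satisfies parts[0] * 3 == parts[1] * 2,
# which is exactly what the axis=-1 assertion in test_multivariate checks.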