Derive logprob for Split operation #7875

Open · wants to merge 1 commit into main

72 changes: 70 additions & 2 deletions pymc/logprob/tensor.py
@@ -33,20 +33,23 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import typing

from pathlib import Path

from pytensor import tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.npy_2_compat import normalize_axis_index
from pytensor.tensor import TensorVariable
- from pytensor.tensor.basic import Join, MakeVector
+ from pytensor.tensor.basic import Join, MakeVector, Split
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.rewriting import (
local_dimshuffle_rv_lift,
)

from pymc.exceptions import NotConstantValueError
from pymc.logprob.abstract import (
MeasurableOp,
ValuedRV,
@@ -70,7 +70,7 @@


class MeasurableMakeVector(MeasurableOp, MakeVector):
"""A placeholder used to specify a log-likelihood for a cumsum sub-graph."""
"""A placeholder used to specify a log-likelihood for a make_vector sub-graph."""


@_logprob.register(MeasurableMakeVector)
@@ -183,6 +186,64 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
return [measurable_stack]


class MeasurableSplit(MeasurableOp, Split):
"""A placeholder used to specify a log-likelihood for a split sub-graph."""


@node_rewriter([Split])
def find_measurable_splits(fgraph, node) -> list[TensorVariable] | None:
if isinstance(node.op, MeasurableOp):
return None

x, axis, splits = node.inputs
if not filter_measurable_variables([x]):
return None

return MeasurableSplit(node.op.len_splits).make_node(x, axis, splits).outputs


@_logprob.register(MeasurableSplit)
def logprob_split(op: MeasurableSplit, values, x, axis, splits, **kwargs):
"""Compute the log-likelihood graph for a `MeasurableSplit`."""
if len(values) != op.len_splits:
# TODO: Don't rewrite the split in the first place if not all parts are linked to value variables
# This also allows handling some cases where not all splits are used
raise ValueError("Split logp requires the number of values to match the number of splits")

# Reverse the effects of split on the value variable
join_value = pt.join(axis, *values)

join_logp = _logprob_helper(x, join_value)

reduced_dims = join_value.ndim - join_logp.ndim

if reduced_dims:
# This happens for multivariate distributions
try:
[constant_axis] = constant_fold([axis])
except NotConstantValueError:
raise NotImplementedError("Cannot split multivariate logp with non-constant axis")

constant_axis = normalize_axis_index(constant_axis, join_value.ndim) # type: ignore[arg-type, assignment]
if constant_axis >= join_logp.ndim:
# If the axis is over a dimension that was reduced in the logp (multivariate logp),
# we cannot split it into distinct entries: the mapping between values and densities breaks.
# We return the logp weighted by the split sizes. This is as good a solution as any?
split_weights = splits / pt.sum(splits)
Member: Is this legit?

Member (Author): I think so? In MarginalMixture we decided to set the whole logp on the first entry and zero for the others; I like this approach more.
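A quick numeric sketch of the weighting being discussed (hypothetical numbers; only the `splits / pt.sum(splits)` expression above is from the diff):

```python
import numpy as np

# A length-5 core dimension split into chunks of sizes [2, 3].
splits = np.array([2, 3])
join_logp = -4.2  # made-up joint logp of the full value

split_weights = splits / splits.sum()   # -> [0.4, 0.6]
part_logps = join_logp * split_weights  # -> [-1.68, -2.52]

# The weighted parts sum back to the joint logp.
assert np.isclose(part_logps.sum(), join_logp)
```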

return [join_logp * split_weights[i] for i in range(typing.cast(int, op.len_splits))]
else:
# Otherwise we can split the logp as if the split were over batch dimensions
# We just need to be sure to use the positive axis index
axis = constant_axis

return pt.split(
join_logp,
splits_size=splits,
n_splits=op.len_splits,
axis=axis,
)
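
For reference, the identity the function above relies on (an informal sketch, not part of the diff): `split` is deterministic and invertible given the split sizes, so the joint density of the parts equals the density of the joined value,

$$\log p(x_1, \dots, x_k) = \log p\big(\operatorname{join}(x_1, \dots, x_k; \text{axis})\big).$$

When `axis` indexes a batch dimension, the joint logp has one entry per batch element and can itself be split along `axis`; when `axis` indexes a reduced (core) dimension, the per-part entries cannot be separated, so the branch above distributes the joint logp in proportion to the split sizes.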


class MeasurableDimShuffle(MeasurableOp, DimShuffle):
"""A placeholder used to specify a log-likelihood for a dimshuffle sub-graph."""

@@ -308,3 +369,10 @@ def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
"basic",
"tensor",
)

measurable_ir_rewrites_db.register(
"find_measurable_splits",
find_measurable_splits,
"basic",
"tensor",
)
155 changes: 154 additions & 1 deletion tests/logprob/test_tensor.py
@@ -40,6 +40,7 @@

from pytensor import tensor as pt
from pytensor.graph import RewriteDatabaseQuery
from pytensor.tensor.random.type import random_generator_type
from scipy import stats as st

from pymc.logprob.basic import conditional_logp, logp
@@ -352,7 +353,7 @@ def test_measurable_dimshuffle(ds_order, multivariate):
np.testing.assert_array_equal(ref_logp_fn(base_test_value), ds_logp_fn(ds_test_value))


- def test_unmeargeable_dimshuffles():
+ def test_unmeasurable_dimshuffles():
# Test that graphs with DimShuffles that cannot be lifted/merged fail

# Initial support axis is at axis=-1
@@ -372,3 +373,155 @@ def test_unmeargeable_dimshuffles():
# TODO: Check that logp is correct if this type of graphs is ever supported
with pytest.raises(RuntimeError, match="could not be derived"):
conditional_logp({w: w_vv})


class TestMeasurableSplit:
def test_univariate(self):
rng = np.random.default_rng(388)
mu = np.arange(6)[:, None]
sigma = np.arange(5) + 1

x = pt.random.normal(mu, sigma, size=(6, 5), name="x")

# axis=0
x_parts = pt.split(x, splits_size=[2, 4], n_splits=2, axis=0)
x_parts_vv = [x_part.clone() for x_part in x_parts]
logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

logp_fn = pytensor.function(x_parts_vv, logp_parts)
x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
np.testing.assert_allclose(
logp_x1_eval,
st.norm.logpdf(x_parts_test[0], mu[:2], sigma),
)
np.testing.assert_allclose(
logp_x2_eval,
st.norm.logpdf(x_parts_test[1], mu[2:], sigma),
)

# axis=1
x_parts = pt.split(x, splits_size=[2, 1, 2], n_splits=3, axis=1)
x_parts_vv = [x_part.clone() for x_part in x_parts]
logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

logp_fn = pytensor.function(x_parts_vv, logp_parts)
x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
np.testing.assert_allclose(
logp_x1_eval,
st.norm.logpdf(x_parts_test[0], mu, sigma[:2]),
)
np.testing.assert_allclose(
logp_x2_eval,
st.norm.logpdf(x_parts_test[1], mu, sigma[2:3]),
)
np.testing.assert_allclose(
logp_x3_eval,
st.norm.logpdf(x_parts_test[2], mu, sigma[3:]),
)

def test_multivariate(self):
@np.vectorize(signature="(n),(n)->()")
def scipy_dirichlet_logpdf(x, alpha):
"""Compute the logpdf of a Dirichlet distribution using scipy."""
return st.dirichlet.logpdf(x, alpha)

# (3, 5) Dirichlet
rng = np.random.default_rng(426)
rng_pt = random_generator_type("rng")
alpha = np.linspace(1, 10, 5) * np.array([1, 10, 100])[:, None]
x = pt.random.dirichlet(alpha, rng=rng_pt)

# axis=-2 (i.e., axis 0, the batch dimension)
x_parts = pt.split(x, splits_size=[2, 1], n_splits=2, axis=-2)
x_parts_vv = [x_part.clone() for x_part in x_parts]
logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
Member: Do I understand this correctly that each part is conditioned on the values of all other parts?

Thinking about e.g. the MVN case, where if you split the vector and condition each split on the other, you get two new MVN distributions.

Member (Author): There's no marginalization going on; you can't evaluate the logp of only one part without providing the remaining ones. The only thing we do is join the value, get the logp, and split it again. We could argue that we don't want to do this for multivariate variables split along the core dimension, since there's no way you can split the logp (I did the weighting, but we can revert and raise NotImplementedError).
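A minimal sketch of the "join the value, get the logp, split it again" mechanics for a batch-axis split (plain NumPy/SciPy for illustration; none of these names come from the diff):

```python
import numpy as np
from scipy import stats as st

rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=(2, 5)), rng.normal(size=(4, 5))  # the two split values

# Join the values, evaluate the joint logp, then split the result again.
joined = np.concatenate([x1, x2], axis=0)
joint_logp = st.norm.logpdf(joined)
logp_x1, logp_x2 = joint_logp[:2], joint_logp[2:]

# For a batch axis this is exact: each entry of the logp depends only on the
# corresponding entry of the value, so the parts separate cleanly.
np.testing.assert_allclose(logp_x1, st.norm.logpdf(x1))
```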

assert logp_parts[0].type.shape == (2,)
assert logp_parts[1].type.shape == (1,)

logp_fn = pytensor.function(x_parts_vv, logp_parts)
x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
np.testing.assert_allclose(
logp_x1_eval,
scipy_dirichlet_logpdf(x_parts_test[0], alpha[:2]),
)
np.testing.assert_allclose(
logp_x2_eval,
scipy_dirichlet_logpdf(x_parts_test[1], alpha[2:]),
)

# axis=-1 (i.e., axis 1, the support dimension)
x_parts = pt.split(x, splits_size=[2, 3], n_splits=2, axis=-1)
x_parts_vv = [x_part.clone() for x_part in x_parts]
logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

assert logp_parts[0].type.shape == (3,)
assert logp_parts[1].type.shape == (3,)
logp_fn = pytensor.function(x_parts_vv, logp_parts)

x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
np.testing.assert_allclose(logp_x1_eval * 3, logp_x2_eval * 2)
logp_total = logp_x1_eval + logp_x2_eval
np.testing.assert_allclose(
logp_total,
scipy_dirichlet_logpdf(np.concatenate(x_parts_test, axis=1), alpha),
)

@pytest.mark.xfail(
reason="Rewrite from partial split to split on subtensor not implemented yet"
)
def test_not_all_splits_used(self):
x = pt.random.normal(loc=pt.arange(6), name="x")
x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
::2
] # Only use the first and third splits
x_parts_vv = [x_part.clone() for x_part in x_parts]
logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
assert len(logp_parts) == 2

logp_fn = pytensor.function(x_parts_vv, logp_parts)
x_parts_test = [x_part.eval() for x_part in x_parts_vv]
logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
np.testing.assert_allclose(
logp_x1_eval,
st.norm.logpdf(x_parts_test[0], loc=[0, 1]),
)
np.testing.assert_allclose(
logp_x2_eval,
st.norm.logpdf(x_parts_test[1], loc=[4, 5]),
)

def test_not_all_splits_used_core_dim(self):
# TODO: We could support this for univariate/batch dimensions by rewriting as
# split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[:2] -> split(x[:-2], splits_size=[2, 2], n_splits=2, axis=0)
# And letting logp infer the probability of x[:-2]
x = pt.random.dirichlet(alphas=pt.ones(6), name="x")
x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
:2
] # Only use first two splits
x_parts_vv = [x_part.clone() for x_part in x_parts]

with pytest.raises(
ValueError,
match="Split logp requires the number of values to match the number of splits",
):
conditional_logp(dict(zip(x_parts, x_parts_vv)))

@pytest.mark.xfail(reason="Rewrite from subtensor to split not implemented yet")
def test_subtensor_converted_to_splits(self):
rng = np.random.default_rng(388)
x = pt.random.normal(loc=pt.arange(5), name="x")

x_parts = [x[:2], x[2:3], x[3:]]
x_parts_vv = [x_part.clone() for x_part in x_parts]
logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
assert len(logp_parts) == 3
logp_fn = pytensor.function(x_parts_vv, logp_parts)
x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
np.testing.assert_allclose(logp_x1_eval, st.norm.logpdf(x_parts_test[0], loc=[0, 1]))
np.testing.assert_allclose(logp_x2_eval, st.norm.logpdf(x_parts_test[1], loc=[2]))
np.testing.assert_allclose(logp_x3_eval, st.norm.logpdf(x_parts_test[2], loc=[3, 4]))