Commit da89b9d

Implement pack/unpack helpers
1 parent 79444a3 commit da89b9d

File tree

2 files changed: +143 -3 lines changed

pytensor/tensor/extra_ops.py

Lines changed: 71 additions & 2 deletions
@@ -28,7 +28,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, join, second
+from pytensor.tensor.basic import alloc, arange, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -47,7 +47,7 @@
 from pytensor.tensor.math import max as pt_max
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.shape import Shape_i
-from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
+from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor, take
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
 from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import TensorVariable
@@ -2074,6 +2074,73 @@ def concat_with_broadcast(tensor_list, axis=0):
     return join(axis, *bcast_tensor_inputs)
 
 
+def pack(
+    *tensors: TensorVariable,
+) -> tuple[TensorVariable, list[tuple[TensorVariable | int, ...]]]:
+    """
+    Ravel and concatenate tensors of varying shapes and dimensions into a single 1-D vector.
+
+    Parameters
+    ----------
+    tensors: TensorVariable
+        Tensors to be packed into a single vector.
+
+    Returns
+    -------
+    flat_tensor: TensorVariable
+        A new symbolic variable representing the concatenated 1-D vector of all tensor inputs.
+    packed_shapes: list of tuple of int or TensorVariable
+        One entry per input tensor: its static shape where fully known, otherwise its
+        symbolic shape.
+    """
+    if not tensors:
+        raise ValueError("Cannot pack an empty list of tensors.")
+
+    # Record each input's shape: static where fully known, symbolic otherwise
+    packed_shapes = [
+        t.type.shape if not any(s is None for s in t.type.shape) else t.shape
+        for t in tensors
+    ]
+
+    # Flatten each tensor and concatenate them into a single 1D vector
+    flat_tensor = join(0, *[t.ravel() for t in tensors])
+
+    return flat_tensor, packed_shapes
+
+
+def unpack(
+    flat_tensor: TensorVariable, packed_shapes: list[tuple[TensorVariable | int, ...]]
+) -> tuple[TensorVariable, ...]:
+    """
+    Unpack a flat tensor into its original shapes based on the provided packed shapes.
+
+    Parameters
+    ----------
+    flat_tensor: TensorVariable
+        A 1D tensor that contains the concatenated values of the original tensors.
+    packed_shapes: list of tuple of int or TensorVariable
+        A list of shapes, one per original tensor, as returned by pack.
+
+    Returns
+    -------
+    unpacked_tensors: tuple of TensorVariable
+        A tuple containing the unpacked tensors with their original shapes.
+    """
+    if not packed_shapes:
+        raise ValueError("Cannot unpack an empty list of shapes.")
+
+    start = 0
+    unpacked_tensors = []
+    for shape in packed_shapes:
+        # Each tensor occupies a contiguous slice of the flat vector;
+        # slice it out and reshape it to the stored shape
+        size = prod(shape, no_zeros_in_input=True)
+        end = start + size
+        unpacked_tensors.append(
+            take(flat_tensor, arange(start, end, dtype="int").reshape(shape))
+        )
+        start = end
+
+    return tuple(unpacked_tensors)
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
@@ -2096,4 +2163,6 @@ def concat_with_broadcast(tensor_list, axis=0):
     "logspace",
     "linspace",
     "broadcast_arrays",
+    "pack",
+    "unpack",
 ]
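
For orientation, here is a minimal round-trip sketch of the two helpers added above. It is illustrative only: the variable names and values are made up, and pack/unpack are imported directly from pytensor.tensor.extra_ops, where this commit defines them.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.extra_ops import pack, unpack

    # Two inputs of different rank; both shapes are fully static here
    a = pt.tensor("a", shape=(2, 3))
    b = pt.tensor("b", shape=(4,))

    # pack ravels each input and joins them into one 1-D vector,
    # recording each original shape alongside
    flat, shapes = pack(a, b)  # flat has static length 2 * 3 + 4 = 10

    # unpack slices the vector back apart and reshapes each piece
    a2, b2 = unpack(flat, shapes)

    fn = pytensor.function([a, b], [a2, b2])
    a_val = np.arange(6, dtype=pytensor.config.floatX).reshape(2, 3)
    b_val = np.arange(4, dtype=pytensor.config.floatX)
    out_a, out_b = fn(a_val, b_val)
    np.testing.assert_allclose(out_a, a_val)  # the round trip is lossless
    np.testing.assert_allclose(out_b, b_val)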

tests/tensor/test_extra_ops.py

Lines changed: 72 additions & 1 deletion
@@ -9,7 +9,7 @@
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.graph import rewrite_graph
-from pytensor.graph.basic import Constant, applys_between, equal_computations
+from pytensor.graph.basic import Constant, Variable, applys_between, equal_computations
 from pytensor.npy_2_compat import old_np_unique
 from pytensor.raise_op import Assert
 from pytensor.tensor import alloc
@@ -37,11 +37,13 @@
     diff,
     fill_diagonal,
     fill_diagonal_offset,
+    pack,
     ravel_multi_index,
     repeat,
     searchsorted,
     squeeze,
     to_one_hot,
+    unpack,
     unravel_index,
 )
 from pytensor.tensor.type import (
@@ -1378,3 +1380,72 @@ def test_concat_with_broadcast():
     a = pt.tensor("a", shape=(1, 3, 5))
     b = pt.tensor("b", shape=(3, 5))
     pt.concat_with_broadcast([a, b], axis=1)
+
+
+@pytest.mark.parametrize(
+    "shapes, expected_flat_shape",
+    [([(), (5,), (3, 3)], 15), ([(), (None,), (None, None)], None)],
+    ids=["static", "symbolic"],
+)
+def test_pack(shapes, expected_flat_shape):
+    rng = np.random.default_rng()
+
+    x = pt.tensor("x", shape=shapes[0])
+    y = pt.tensor("y", shape=shapes[1])
+    z = pt.tensor("z", shape=shapes[2])
+
+    has_static_shape = [not any(s is None for s in shape) for shape in shapes]
+
+    flat_packed, packed_shapes = pack(x, y, z)
+
+    assert flat_packed.type.shape[0] == expected_flat_shape
+
+    for i, (packed_shape, has_static) in enumerate(
+        zip(packed_shapes, has_static_shape)
+    ):
+        if has_static:
+            assert packed_shape == shapes[i]
+        else:
+            assert isinstance(packed_shape, Variable)
+
+    new_outputs = unpack(flat_packed, packed_shapes)
+
+    assert len(new_outputs) == 3
+    assert all(
+        out.type.shape == var.type.shape for out, var in zip(new_outputs, [x, y, z])
+    )
+
+    fn = function([x, y, z], new_outputs, mode="FAST_COMPILE")
+
+    input_vals = [
+        rng.normal(size=shape).astype(config.floatX)
+        for shape in [(), (5,), (3, 3)]
+    ]
+    new_output_vals = fn(*input_vals)
+    for input, output in zip(input_vals, new_output_vals):
+        np.testing.assert_allclose(input, output)
+
+
+def test_make_replacements_with_pack_unpack():
+    rng = np.random.default_rng()
+
+    x = pt.tensor("x", shape=())
+    y = pt.tensor("y", shape=(5,))
+    z = pt.tensor("z", shape=(3, 3))
+
+    loss = (x + y.sum() + z.sum()) ** 2
+
+    flat_packed, packed_shapes = pack(x, y, z)
+    new_input = flat_packed.type()
+    new_outputs = unpack(new_input, packed_shapes)
+
+    loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
+    fn = pytensor.function([new_input], loss, mode="FAST_COMPILE")
+
+    input_vals = [
+        rng.normal(size=var.type.shape).astype(config.floatX) for var in [x, y, z]
+    ]
+    flat_inputs = np.concatenate([input.ravel() for input in input_vals], axis=0)
+    output_val = fn(flat_inputs)
+
+    assert np.allclose(output_val, sum([input.sum() for input in input_vals]) ** 2)
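
The second test shows the workflow these helpers are aimed at: build a graph on separate inputs, then rewire it onto a single flat vector with graph_replace. A condensed sketch of that pattern follows; as above it is a sketch, not part of the commit, and the names are illustrative.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph import graph_replace
    from pytensor.tensor.extra_ops import pack, unpack

    x = pt.tensor("x", shape=())
    y = pt.tensor("y", shape=(5,))
    loss = (x + y.sum()) ** 2

    # Pack the inputs, create a fresh 1-D input of the packed type, and
    # substitute its unpacked pieces for the original variables
    flat, shapes = pack(x, y)
    flat_input = flat.type()
    loss_on_flat = graph_replace(loss, dict(zip([x, y], unpack(flat_input, shapes))))

    fn = pytensor.function([flat_input], loss_on_flat)
    vals = np.arange(6, dtype=pytensor.config.floatX)  # x = vals[0], y = vals[1:]
    np.testing.assert_allclose(fn(vals), (vals[0] + vals[1:].sum()) ** 2)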
