
Implement pack/unpack helpers #1578


Open

jessegrabowski wants to merge 1 commit into main

Conversation

jessegrabowski (Member) commented on Aug 10, 2025

Description

Adds pt.pack and pt.unpack helpers, roughly conforming to the einops functions of the same name.

These helpers are for situations where we have a ragged list of inputs that needs to be raveled into a single flat vector for some intermediate step. This comes up in places like optimization.

Example usage:

x = pt.tensor("x", shape=shapes[0])
y = pt.tensor("y", shape=shapes[1])
z = pt.tensor("z", shape=shapes[2])

flat_params, packed_shapes = pt.pack(x, y, z)

Unpack simply undoes the computation, although there's no rewrite to ensure pt.unpack(*pt.pack(*inputs)) is the identity function:

x, y, z = pt.unpack(flat_params, packed_shapes)
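
For intuition, here is a rough NumPy analogue of the intended semantics (not the implementation in this PR, just a sketch of what pack/unpack compute):

import numpy as np

def numpy_pack(*arrays):
    # Concatenate the raveled arrays into one flat vector, remembering the shapes
    shapes = [a.shape for a in arrays]
    flat = np.concatenate([a.ravel() for a in arrays])
    return flat, shapes

def numpy_unpack(flat, shapes):
    # Cut the flat vector back into segments and restore the original shapes
    sizes = [int(np.prod(s)) for s in shapes]
    splits = np.cumsum(sizes)[:-1]
    return [seg.reshape(s) for seg, s in zip(np.split(flat, splits), shapes)]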

The use-case I foresee is creating a single replacement input for a function of the variables we're packing, for example:

loss = (x + y.sum() + z.sum()) ** 2

flat_packed, packed_shapes = pt.pack(x, y, z)
new_input = flat_packed.type()
new_outputs = pt.unpack(new_input, packed_shapes)

loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
fn = pytensor.function([new_input], loss)
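
Purely for illustration, calling the compiled function would then look something like this (assuming shapes = [(2,), (3, 4), (5,)]; the concrete values here are made up):

import numpy as np

x_val = np.zeros((2,))
y_val = np.ones((3, 4))
z_val = np.full((5,), 2.0)

# The single flat input is just the raveled values of x, y and z concatenated
flat_val = np.concatenate([x_val.ravel(), y_val.ravel(), z_val.ravel()])
fn(flat_val)  # same result as evaluating loss at (x_val, y_val, z_val)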

Note that the final compiled function depends only on new_input; this works only because the shapes of the 3 packed variables were statically known. This leads to my design choices:

  1. I decided to work with the static shapes directly if they are available. This means that pack will eagerly return a list of integer shapes as packed_shapes when possible; otherwise the returned shapes will be symbolic. This is maybe an anti-pattern -- we might prefer a rewrite to handle it later, but it seemed easy enough to do eagerly (see the sketch after this list).
  2. I didn't add support for batch dims. This is left to the user to handle with pt.vectorize.
  3. The einops API has arguments to support packing/unpacking on arbitrary subsets of dimensions. I didn't do this, because I couldn't think of a use-case that a user couldn't cover themselves with DimShuffle and vectorize.
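
To make point 1 concrete, this is roughly how the two cases differ (illustrative only; the exact return types are what's under discussion in the review below):

import pytensor.tensor as pt

# All shapes static: packed_shapes can be returned eagerly as plain integer shapes
a = pt.tensor("a", shape=(2, 3))
b = pt.tensor("b", shape=(4,))
flat, packed_shapes = pt.pack(a, b)
# packed_shapes is something like [(2, 3), (4,)]

# An unknown dim: the helper has to fall back to symbolic shapes
# (roughly [a2.shape, b.shape]), and downstream graphs then also
# depend on those shape variables
a2 = pt.tensor("a2", shape=(None, 3))
flat2, packed_shapes2 = pt.pack(a2, b)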

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1578.org.readthedocs.build/en/1578/

jessegrabowski added the enhancement (New feature or request) label on Aug 10, 2025
jessegrabowski (Member Author) commented:

The pack -> type -> unpack -> replace pattern might be common enough to merit its own helper. PyMC has tools for doing this, for example RaveledArray and DictToArrayBijector, which could be replaced with the appropriate symbolic operations.
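
Something like the following, where the name pack_graph_replace is hypothetical and just sketches the pattern from the description:

import pytensor
import pytensor.tensor as pt

def pack_graph_replace(outputs, params):
    # Replace `params` in `outputs` with views into a single flat input
    flat, packed_shapes = pt.pack(*params)
    new_input = flat.type()
    new_params = pt.unpack(new_input, packed_shapes)
    new_outputs = pytensor.graph.graph_replace(outputs, dict(zip(params, new_params)))
    return new_input, new_outputs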

One other thing I forgot to mention: this will fail on inputs with a zero in their shape, since that breaks the prod(shape) used to compute the size of the flat output. Not sure what to do in that case.

codecov bot commented on Aug 10, 2025

Codecov Report

❌ Patch coverage is 78.94737% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.54%. Comparing base (f9a3234) to head (9ead211).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/extra_ops.py 78.94% 2 Missing and 2 partials ⚠️

❌ Your patch check has failed because the patch coverage (78.94%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1578      +/-   ##
==========================================
- Coverage   81.54%   81.54%   -0.01%     
==========================================
  Files         230      230              
  Lines       53136    53153      +17     
  Branches     9448     9451       +3     
==========================================
+ Hits        43329    43342      +13     
- Misses       7370     7372       +2     
- Partials     2437     2439       +2     
Files with missing lines Coverage Δ
pytensor/tensor/extra_ops.py 88.28% <78.94%> (-0.29%) ⬇️

ricardoV94 (Member) commented on Aug 10, 2025

  1. Better to return the same types in both cases; the static-shape-to-constant conversion is already introduced during rewrites.

2 and 3. I would really like to have these; it's what I needed for the batched_dot_to_core rewrites. This isn't a simple case of vectorize, because the dims I want to pack are on both the left and the right of other dims.

for shape in packed_shapes:
    size = prod(shape, no_zeros_in_input=True)
    end = start + size
    unpacked_tensors.append(
ricardoV94 (Member):
Take uses advanced indexing. Actually the best here is probably split. Join and split are the inverses of each other, and will be easier to rewrite away
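
For concreteness, a split-based unpack could look roughly like this (a sketch for the statically-shaped case, not what the PR currently does):

import numpy as np
import pytensor.tensor as pt

def unpack_via_split(flat, packed_shapes):
    # packed_shapes assumed here to be a list of tuples of ints
    sizes = [int(np.prod(shape)) for shape in packed_shapes]
    pieces = pt.split(flat, splits_size=sizes, n_splits=len(sizes), axis=0)
    return [piece.reshape(shape) for piece, shape in zip(pieces, packed_shapes)]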

start = 0
unpacked_tensors = []
for shape in packed_shapes:
    size = prod(shape, no_zeros_in_input=True)
ricardoV94 (Member):
Why no zeros in input? The shape doesn't show up in gradients if that's what you were worried about

jessegrabowski (Member Author):
JAX needs it as well iirc

ricardoV94 (Member):
I don't see why you can't have zeros in the shapes

jessegrabowski (Member Author):
ok ok ok i'll fix it

@@ -2074,6 +2074,73 @@ def concat_with_broadcast(tensor_list, axis=0):
    return join(axis, *bcast_tensor_inputs)


def pack(
ricardoV94 (Member):
Would be nice to have a docstring (doctested) example

jessegrabowski (Member Author):
Yeah, I will of course add docstrings. This was just to get a PR on the board and see your reaction to the 3 issues I raised. I didn't want to document everything before we decided on the final API.

ricardoV94 (Member) commented:

I am inclined to make this a core Op and not just a helper. It obviates most uses of reshape and is much easier to reason about: no need to worry about a pesky -1, or about whether the reshape shape comes from the original input shapes or not.

That would pretty much address #883

We could use OFG and/or specialize to reshape/split later. That also need not be done in this PR; it's an implementation detail as far as the user is concerned.
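
For reference, a minimal sketch of wrapping the packing graph in an OpFromGraph (statically-shaped case only; purely illustrative, none of this is in the PR):

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.tensor("x", shape=(2,))
y = pt.tensor("y", shape=(3, 4))

# The inner graph just joins the flattened inputs; the OFG presents it as a single node
flat = pt.join(0, pt.flatten(x), pt.flatten(y))
pack_op = OpFromGraph([x, y], [flat])

packed = pack_op(x, y)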

Labels: enhancement (New feature or request)
Projects: None yet

Successfully merging this pull request may close these issues:

Implement pack/unpack Ops

2 participants