-
Notifications
You must be signed in to change notification settings - Fork 139
Rewrite solves involving kron to eliminate kron #1559
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
d71df9f
to
0f76af9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR implements a rewrite optimization for solving linear systems involving Kronecker products. The goal is to transform expressions of the form solve(kron(A, B), x)
into an equivalent form that eliminates the Kronecker product computation, providing significant performance improvements.
Key changes:
- Added a new rewrite rule
rewrite_solve_kron_to_solve
that transforms Kronecker-based solves using mathematical identities - Comprehensive test coverage including correctness tests and benchmarks demonstrating substantial speedups
- Support for both batched and non-batched operations with limitations for certain matrix dimensions
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
pytensor/tensor/rewriting/linalg.py | Implements the core rewrite logic with mathematical transformation from Kronecker solve to nested solves |
tests/tensor/rewriting/test_linalg.py | Adds comprehensive test suite including correctness verification and performance benchmarks |
0f76af9
to
5be594a
Compare
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## main #1559 +/- ##
=======================================
Coverage 81.53% 81.54%
=======================================
Files 230 230
Lines 53066 53144 +78
Branches 9423 9445 +22
=======================================
+ Hits 43269 43336 +67
- Misses 7364 7370 +6
- Partials 2433 2438 +5
🚀 New features to boost your workflow:
|
return None | ||
|
||
m, n = x1.shape[-2], x2.shape[-2] | ||
batch_shapes = x1.shape[:-2] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
x1/x2 batch shapes could broadcast in blockwise
# If x1 and x2 have statically known core shapes, check that they are square. If not, the rewrite will be invalid. | ||
# We will proceed if they are unknown, but this makes the rewrite shape unsafe. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shape_unsafe
is when a rewrite can mask an originally invalid graph, but it / we aren't allowed to turn a previously valid graph into an invalid one. Is that what's happening here?
(*batch_shapes, -1, b_batch) | ||
) | ||
|
||
return [res] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing copy_stack_trace
# If shapes are static, it should always be applied | ||
A = pt.tensor("A", shape=(3, None, None)) | ||
B = pt.tensor("B", shape=(3, None, None)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Back to the previous comment, is the previous C
a valid graph? If so, we can't rewrite and break the graph if we don't know the core shapes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
C is valid, because C is square. The "problem" is that we can kron together two non-square matrices and end up with a square one (e.g. kron((4,3), (3,4)) -> (7, 7)
). So the rewrite is invalid in this case.
This is another case where we really really wish we had a tag for "square matrix", without having to commit to shapes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wiki seems to suggest Kron(A, B) is only invertible if both A and B are invertible, so you couldn't solve C in the first place if this wasn't the case?
Is that correct? In that case it's fine to have the rewrite when the shapes are unknown (perhaps add a comment?). Otherwise it's not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The theory looks right.
The only issue I guess is that currently, you won't get an error if you have an "invalid" graph like:
A = rng.normal(size=(4, 3))
B = rng.normal(size=(3, 4))
A_pt, B_pt = pt.dmatrices('A', 'B')
y_pt = pt.dvector('y')
C = pt.linalg.kron(A_pt, B_pt)
x = pt.linalg.solve(C, y_pt)
fn = pytensor.function([A_pt, B_pt, y_pt], x)
You get a warning about numerical instability, but it gives you some numbers. Obviously these numbers are just nonsense, but it doesn't error. After the rewrite, you will get a shape error, which might be very surprising for someone who isn't providing a valid graph in the first place?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solve of C doesn't raise for "singular matrix"?
if not A.owner or not ( | ||
isinstance(A.owner.op, KroneckerProduct) | ||
or isinstance(A.owner.op, Blockwise) | ||
and isinstance(A.owner.op.core_op, KroneckerProduct) | ||
): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably broke the parenthesis, but you get the idea. Negate the whole condition that is required and group the Blockwies + KroneckerProduct
if not A.owner or not ( | |
isinstance(A.owner.op, KroneckerProduct) | |
or isinstance(A.owner.op, Blockwise) | |
and isinstance(A.owner.op.core_op, KroneckerProduct) | |
): | |
if not (A.owner and ( | |
isinstance(A.owner.op, KroneckerProduct) | |
or (isinstance(A.owner.op, Blockwise) | |
and isinstance(A.owner.op.core_op, KroneckerProduct))) | |
): |
Description
Rewrite graphs of the form
solve(kron(A, B), x)
tosolve(A, solve(B, x.reshape).mT).mT.reshape
. This eliminates the kronecker product, and provides significant speedup.Important limitation is that it only covers the case when
b_ndim=1
, because the math underpinning the rewrite requires thatx
is a vector. This is still an important case, however, because it's what arises in the logp of a multivariate normal when the covariance matrix is kronecker.Also I hit what appears to be a numerical bug in the batch case when
assume_a = 'pos'
. There is disagreement, but only in the 2nd row of the outputs. No matter the batch size, it's always the 2nd batch that has a numerical problem -- all other batches agree. I've left in the failing test for now. We don't even vectorize kron by default, so if I can't figure it out I might just disable the rewrite for the Blockwise(Kron) case for now.Benchmarks follow, with:
Related Issue
Solve
involvingKron
#1557Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1559.org.readthedocs.build/en/1559/