Rewrite solves involving kron to eliminate kron #1559

Open: wants to merge 2 commits into base: main

Member

@jessegrabowski jessegrabowski commented Jul 28, 2025

Description

Rewrite graphs of the form solve(kron(A, B), x) to solve(A, solve(B, x.reshape).mT).mT.reshape. This eliminates the Kronecker product and provides a significant speedup.

An important limitation is that it only covers the case b_ndim=1, because the math underpinning the rewrite requires x to be a vector. This is still an important case, however, because it is what arises in the logp of a multivariate normal when the covariance matrix is a Kronecker product.
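As a hedged illustration of the identity behind the rewrite (a NumPy sketch, not the PR's implementation), the snippet below checks the vector case numerically. It assumes row-major reshaping, under which kron(A, B) @ Y.ravel() equals (A @ Y @ B.T).ravel(). The names m, n, X, Z, and Y are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 3, 4

# Shift the diagonal so both factors are comfortably invertible (assumption for the demo).
A = rng.normal(size=(m, m)) + m * np.eye(m)
B = rng.normal(size=(n, n)) + n * np.eye(n)
x = rng.normal(size=(m * n,))

# Direct path: materialise the (m*n, m*n) Kronecker product and solve.
y_direct = np.linalg.solve(np.kron(A, B), x)

# Kron-free path: solve A @ Y @ B.T = X with two small solves.
X = x.reshape(m, n)
Z = np.linalg.solve(B, X.T).T   # Z = X @ inv(B).T
Y = np.linalg.solve(A, Z)       # Y = inv(A) @ X @ inv(B).T
y_fast = Y.ravel()

print(np.allclose(y_direct, y_fast))
```

The direct path factorises an (m*n, m*n) matrix, while the second path only factorises the (m, m) and (n, n) factors, which is consistent with the benchmark gap growing with matrix size.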

Also I hit what appears to be a numerical bug in the batch case when assume_a = 'pos'. There is disagreement, but only in the 2nd row of the outputs. No matter the batch size, it's always the 2nd batch that has a numerical problem -- all other batches agree. I've left in the failing test for now. We don't even vectorize kron by default, so if I can't figure it out I might just disable the rewrite for the Blockwise(Kron) case for now.

Benchmarks follow, with:

  • small: A, B are (10, 10)
  • medium: A, B are (50, 50)
  • large: A, B are (100, 100)
-----------------------------------------------------------------------------------------
Name (time in us)                                                            Min                       Max                      Mean                 StdDev                    Median                    IQR            Outliers          OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_rewrite_solve_kron_to_solve_benchmark[no_rewrite-small]             17.2500 (1.0)             34.6250 (1.0)             18.5068 (1.0)           2.3577 (1.0)             17.6670 (1.0)           0.8755 (1.0)          8;10  54,034.2951 (1.0)          93           1
test_rewrite_solve_kron_to_solve_benchmark[rewrite-small]                19.2910 (1.12)            98.7500 (2.85)            21.9831 (1.19)          4.0015 (1.70)            20.9160 (1.18)          3.6250 (4.14)       135;35  45,489.5626 (0.84)       3261           1

test_rewrite_solve_kron_to_solve_benchmark[no_rewrite-medium]        93,532.8330 (>1000.0)     96,359.5420 (>1000.0)     94,835.3042 (>1000.0)     857.3874 (363.65)      94,672.9585 (>1000.0)   1,327.0000 (>1000.0)       3;0      10.5446 (0.00)         10           1
test_rewrite_solve_kron_to_solve_benchmark[rewrite-medium]               66.1660 (3.84)           288.5420 (8.33)            74.0905 (4.00)          8.2418 (3.50)            72.8750 (4.12)          5.9580 (6.81)      405;317  13,497.0108 (0.25)       7247           1

test_rewrite_solve_kron_to_solve_benchmark[no_rewrite-large]      3,250,903.0000 (>1000.0)  3,333,840.0830 (>1000.0)  3,300,615.3582 (>1000.0)  31,539.6476 (>1000.0)  3,302,135.1670 (>1000.0)  38,780.1145 (>1000.0)       2;0       0.3030 (0.00)          5           1
test_rewrite_solve_kron_to_solve_benchmark[rewrite-large]               183.1670 (10.62)          357.8750 (10.34)          196.7968 (10.63)        11.5470 (4.90)           194.6250 (11.02)         8.7920 (10.04)     401;206   5,081.3837 (0.09)       3442           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
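To make the table easier to read, here is a small sketch that turns the mean times above into speedup ratios. The dictionary layout and variable names are illustrative. The values are copied from the Mean column.

```python
# Mean times in microseconds, copied from the Mean column of the benchmark table.
mean_us = {
    "small": {"no_rewrite": 18.5068, "rewrite": 21.9831},
    "medium": {"no_rewrite": 94_835.3042, "rewrite": 74.0905},
    "large": {"no_rewrite": 3_300_615.3582, "rewrite": 196.7968},
}

# Ratio > 1 means the rewritten graph is faster than the original.
speedup = {size: t["no_rewrite"] / t["rewrite"] for size, t in mean_us.items()}

for size, ratio in speedup.items():
    print(f"{size}: {ratio:.2f}x")
```

On these numbers the rewrite is slightly slower for the small case and faster by roughly three and four orders of magnitude for the medium and large cases, respectively.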

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1559.org.readthedocs.build/en/1559/

@jessegrabowski jessegrabowski requested review from ricardoV94 and Copilot and removed request for ricardoV94 July 28, 2025 23:59
Copilot

This comment was marked as outdated.

@Copilot Copilot AI left a comment

Pull Request Overview

This PR implements a rewrite optimization for solving linear systems involving Kronecker products. The goal is to transform expressions of the form solve(kron(A, B), x) into an equivalent form that eliminates the Kronecker product computation, providing significant performance improvements.

Key changes:

  • Added a new rewrite rule rewrite_solve_kron_to_solve that transforms Kronecker-based solves using mathematical identities
  • Comprehensive test coverage including correctness tests and benchmarks demonstrating substantial speedups
  • Support for both batched and non-batched operations with limitations for certain matrix dimensions

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
pytensor/tensor/rewriting/linalg.py Implements the core rewrite logic with mathematical transformation from Kronecker solve to nested solves
tests/tensor/rewriting/test_linalg.py Adds comprehensive test suite including correctness verification and performance benchmarks

codecov bot commented Aug 3, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 81.54%. Comparing base (892a8f0) to head (b06b0c7).
⚠️ Report is 7 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1559   +/-   ##
=======================================
  Coverage   81.53%   81.54%           
=======================================
  Files         230      230           
  Lines       53066    53144   +78     
  Branches     9423     9445   +22     
=======================================
+ Hits        43269    43336   +67     
- Misses       7364     7370    +6     
- Partials     2433     2438    +5     
Files with missing lines Coverage Δ
pytensor/tensor/rewriting/linalg.py 92.56% <100.00%> (+0.50%) ⬆️

... and 7 files with indirect coverage changes

return None

m, n = x1.shape[-2], x2.shape[-2]
batch_shapes = x1.shape[:-2]
Member

@ricardoV94 ricardoV94 Aug 10, 2025

x1/x2 batch shapes could broadcast in blockwise
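A minimal sketch of the concern, using NumPy shape arithmetic rather than PyTensor symbolic shapes. The shapes x1_shape and x2_shape are hypothetical. The point is only that the batch shape of the result is the broadcast of both operands' batch dimensions, not the batch dimensions of x1 alone.

```python
import numpy as np

# Hypothetical static shapes: batch dims followed by the square core dims.
x1_shape = (3, 1, 4, 4)
x2_shape = (1, 5, 2, 2)

# Taking only x1's batch dims misses the broadcast against x2.
naive_batch = x1_shape[:-2]

# Broadcasting both operands' batch dims gives the batch shape of the output.
broadcast_batch = np.broadcast_shapes(x1_shape[:-2], x2_shape[:-2])

print(naive_batch, broadcast_batch)
```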

Comment on lines +899 to +900
# If x1 and x2 have statically known core shapes, check that they are square. If not, the rewrite will be invalid.
# We will proceed if they are unknown, but this makes the rewrite shape unsafe.
Member

@ricardoV94 ricardoV94 Aug 10, 2025

shape_unsafe is when a rewrite can mask an originally invalid graph, but it / we aren't allowed to turn a previously valid graph into an invalid one. Is that what's happening here?

(*batch_shapes, -1, b_batch)
)

return [res]
Member

missing copy_stack_trace

Comment on lines +939 to +941
# If shapes are static, it should always be applied
A = pt.tensor("A", shape=(3, None, None))
B = pt.tensor("B", shape=(3, None, None))
Member

Back to the previous comment, is the previous C a valid graph? If so, we can't rewrite and break the graph if we don't know the core shapes.

Member Author

C is valid, because C is square. The "problem" is that we can kron together two non-square matrices and end up with a square one (e.g. kron((4,3), (3,4)) -> (12, 12)). So the rewrite is invalid in this case.

This is another case where we really really wish we had a tag for "square matrix", without having to commit to shapes.

Member

Wiki seems to suggest Kron(A, B) is only invertible if both A and B are invertible, so you couldn't solve C in the first place if this wasn't the case?

Is that correct? In that case it's fine to have the rewrite when the shapes are unknown (perhaps add a comment?). Otherwise it's not.
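The invertibility argument can be checked numerically. The sketch below (NumPy, with illustrative random inputs) relies on the standard fact that rank(kron(A, B)) = rank(A) * rank(B), so a square Kronecker product built from non-square factors cannot have full rank.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))
B = rng.normal(size=(3, 4))

# The product is square even though neither factor is.
C = np.kron(A, B)

rank_A = np.linalg.matrix_rank(A)
rank_B = np.linalg.matrix_rank(B)
rank_C = np.linalg.matrix_rank(C)

print(C.shape, rank_A, rank_B, rank_C)
```

Each factor has rank at most 3 here, so C has rank at most 9 and is singular as a (12, 12) matrix.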

Member Author

The theory looks right.

The only issue I guess is that currently, you won't get an error if you have an "invalid" graph like:

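import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng()
A = rng.normal(size=(4, 3))
B = rng.normal(size=(3, 4))

A_pt, B_pt = pt.dmatrices('A', 'B')
y_pt = pt.dvector('y')
C = pt.linalg.kron(A_pt, B_pt)  # (12, 12): square even though A and B are not
x = pt.linalg.solve(C, y_pt)

fn = pytensor.function([A_pt, B_pt, y_pt], x)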

You get a warning about numerical instability, but it gives you some numbers. Obviously these numbers are just nonsense, but it doesn't error. After the rewrite, you will get a shape error, which might be very surprising for someone who isn't providing a valid graph in the first place?
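A NumPy analogue of this behaviour change, offered as a sketch rather than a reproduction of what PyTensor does: the nested, kron-free solve rejects the non-square factor on shape grounds, while the direct solve on the (12, 12) product at least has consistent shapes. Whether the direct solve raises or returns meaningless numbers depends on the backend, so the snippet only records the outcome.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))
B = rng.normal(size=(3, 4))
y = rng.normal(size=(12,))

# Direct path: shapes are consistent because kron(A, B) is (12, 12).
try:
    np.linalg.solve(np.kron(A, B), y)
    direct_raised = False
except np.linalg.LinAlgError:
    direct_raised = True

# Kron-free path: the first inner solve already needs B to be square.
Y = y.reshape(A.shape[0], B.shape[0])  # (4, 3)
try:
    np.linalg.solve(B, Y.T)
    nested_raised = False
except np.linalg.LinAlgError:
    nested_raised = True

print(direct_raised, nested_raised)
```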

Member

Solve of C doesn't raise for "singular matrix"?

Comment on lines +890 to +894
if not A.owner or not (
    isinstance(A.owner.op, KroneckerProduct)
    or isinstance(A.owner.op, Blockwise)
    and isinstance(A.owner.op.core_op, KroneckerProduct)
):
Member

Probably broke the parentheses, but you get the idea. Negate the whole required condition and group the Blockwise + KroneckerProduct check.

Suggested change
- if not A.owner or not (
-     isinstance(A.owner.op, KroneckerProduct)
-     or isinstance(A.owner.op, Blockwise)
-     and isinstance(A.owner.op.core_op, KroneckerProduct)
- ):
+ if not (A.owner and (
+     isinstance(A.owner.op, KroneckerProduct)
+     or (isinstance(A.owner.op, Blockwise)
+     and isinstance(A.owner.op.core_op, KroneckerProduct)))
+ ):
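As a quick sanity check on the two forms of the condition, the sketch below replaces each sub-expression with a plain boolean (an assumption: the real checks are the isinstance calls and the truthiness of A.owner) and compares the original and suggested forms over every combination. The function and parameter names are hypothetical.

```python
from itertools import product


def original(owner, is_kron, is_blockwise, core_is_kron):
    # Relies on `and` binding tighter than `or`.
    return not owner or not (is_kron or is_blockwise and core_is_kron)


def suggested(owner, is_kron, is_blockwise, core_is_kron):
    # Same logic with the grouping made explicit and a single negation.
    return not (owner and (is_kron or (is_blockwise and core_is_kron)))


# Collect every input combination on which the two forms disagree.
mismatches = [
    flags
    for flags in product([False, True], repeat=4)
    if original(*flags) != suggested(*flags)
]

print(mismatches)
```

Under this stand-in, any difference between the two forms would be a bug in the suggestion. An empty list means the change affects readability only.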

Labels
enhancement (New feature or request), graph rewriting, linalg (Linear algebra), performance
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Rewrite Solve involving Kron
2 participants