Skip to content

Conversation

@copybara-service
Copy link

@copybara-service copybara-service bot commented Nov 11, 2025

[Pallas/Mosaic TPU] Allow non-leading and non-matching batch dimensions in dot_general.

The constraints on lhs_batch_dims and rhs_batch_dims for dot_general in Pallas/Mosaic on TPU are now relaxed. Batch dimensions do not have to be at the front of the shape, and the dimension indices used for batching on the LHS and RHS can be different (which requires to update the semantics of output_dim_order in tpu.dot_dimension_numbers).

The remaining gap compared to JAX is the lack of support for multiple batch dimensions.

@copybara-service copybara-service bot force-pushed the test_831015967 branch 6 times, most recently from ad20fe1 to 81cf5e5 Compare November 12, 2025 03:37
…ns in `dot_general`.

The constraints on `lhs_batch_dims` and `rhs_batch_dims` for `dot_general` in Pallas/Mosaic on TPU are now relaxed. Batch dimensions do not have to be at the front of the shape, and the dimension indices used for batching on the LHS and RHS can be different (which requires to update the semantics of `output_dim_order` in `tpu.dot_dimension_numbers`).

The remaining gap compared to JAX is the lack of support for multiple batch dimensions.

PiperOrigin-RevId: 831015967
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants