
Commit 81cf5e5

yueshengys authored and Google-ML-Automation committed
[Pallas/Mosaic TPU] Allow non-leading and non-matching batch dimensions in dot_general.
The constraints on `lhs_batch_dims` and `rhs_batch_dims` for `dot_general` in Pallas/Mosaic on TPU are now relaxed. Batch dimensions no longer have to be at the front of the shape, and the dimension indices used for batching on the LHS and RHS may differ (which requires updating the semantics of `output_dim_order` in `tpu.dot_dimension_numbers`). The remaining gap compared to JAX is the lack of support for multiple batch dimensions.

PiperOrigin-RevId: 831015967
1 parent 0128bbd commit 81cf5e5
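As a quick illustration of what the relaxed constraint allows, here is a minimal Pallas TPU kernel sketch built around the first shape/dimension-number combination from the new test added in this commit (batch dim at index 1 of the LHS but index 0 of the RHS). The kernel and variable names are illustrative only, and running it requires a libtpu built after 2025-11-18.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)): the batch
# dimension is axis 1 of the LHS and axis 0 of the RHS, and the contracting
# dimension is the last axis of both operands.
dims = (((2,), (2,)), ((1,), (0,)))

def matmul_kernel(x_ref, y_ref, out_ref):
  out_ref[...] = jax.lax.dot_general(x_ref[...], y_ref[...],
                                     dimension_numbers=dims)

x = jnp.ones((3, 4, 128), jnp.float32)
y = jnp.ones((4, 2, 128), jnp.float32)
# Let JAX infer the output shape: batch, then lhs/rhs non-contracting -> (4, 3, 2).
expected = jax.lax.dot_general(x, y, dimension_numbers=dims)

out = pl.pallas_call(
    matmul_kernel,
    out_shape=jax.ShapeDtypeStruct(expected.shape, jnp.float32),
)(x, y)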

File tree

4 files changed: 92 additions, 44 deletions

jax/_src/pallas/mosaic/lowering.py
jaxlib/mosaic/dialect/tpu/tpu_ops.cc
jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
tests/pallas/ops_test.py

jax/_src/pallas/mosaic/lowering.py

Lines changed: 9 additions & 2 deletions
@@ -2104,7 +2104,9 @@ def _proxy_fun(val, *, shape, broadcast_dimensions):
   return vector.broadcast(out_type, val)
 
 
-def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
+def jax_dot_dims_to_tpu_dot_dot_dims(
+    ctx, dimension_numbers, lhs_shape, rhs_shape
+):
   """Converts a jax dot dimension numbers to a tpu dot dimension numbers.
 
   Jax dot dimension numbers are given as a tuple of tuples of sequences of ints
@@ -2142,6 +2144,11 @@ def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
     output_dim_order.append(0)
     output_dim_order.append(lhs_dim_map[dim])
 
+  if not (ctx.is_cloud_tpu_older_than(2025, 11, 18) and ctx.forward_compatible):
+    for dim in rhs_batch_dims:
+      output_dim_order.append(1)
+      output_dim_order.append(rhs_dim_map[dim])
+
   for dim in lhs_non_contracting_dims:
     output_dim_order.append(0)
     output_dim_order.append(lhs_dim_map[dim])
@@ -2267,7 +2274,7 @@ def _dot_general_lowering_rule(
     return vector.shape_cast(out_type, red)
 
   tpu_dot_dims = jax_dot_dims_to_tpu_dot_dot_dims(
-      dimension_numbers, lhs_aval.shape, rhs_aval.shape
+      ctx, dimension_numbers, lhs_aval.shape, rhs_aval.shape
   )
 
   if precision is not None:
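Outside the forward-compatibility gate, the ordering built above can be modelled in a few lines of plain Python. This is only a sketch of the pair layout described in the comments (operand index followed by operand dim, ignoring the internal lhs_dim_map/rhs_dim_map remapping), not the actual jax_dot_dims_to_tpu_dot_dot_dims implementation.

def output_dim_order(lhs_batch, rhs_batch, lhs_nc, rhs_nc):
  """Flat [operand, dim] pairs: lhs batch dims, rhs batch dims, then the
  non-contracting dims of the lhs and of the rhs."""
  order = []
  for d in lhs_batch:
    order += [0, d]
  for d in rhs_batch:
    order += [1, d]
  for d in lhs_nc:
    order += [0, d]
  for d in rhs_nc:
    order += [1, d]
  return order

# Last new test case: lhs (4, 3, 2, 32), rhs (2, 128, 32, 2), batch dims
# (2, 3), contracting dims (3, 2), so the non-contracting dims are (0, 1)
# on both operands.
assert output_dim_order([2], [3], [0, 1], [0, 1]) == [
    0, 2, 1, 3, 0, 0, 0, 1, 1, 0, 1, 1]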

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 16 additions & 27 deletions
@@ -1114,15 +1114,8 @@ LogicalResult MatmulOp::verify() {
   const std::optional<int64_t> batch_dim_rhs =
       rhs_batch_dims.empty() ? std::nullopt
                              : std::optional<int64_t>(rhs_batch_dims[0]);
-  if (batch_dim_lhs != batch_dim_rhs) {
-    emitOpError("Not Implemented: batch dims must be equal");
-    return failure();
-  }
-  if (batch_dim_lhs.has_value() && (batch_dim_lhs.value() != 0)) {
-    emitOpError("Not Implemented: batch dims pos must be 0");
-    return failure();
-  }
-  // Invariant above enforces only 1 batch dim atm, and that both are eq
+
+  // Invariant above enforces only 1 batch dim atm.
   std::optional<int64_t> batch_size = std::nullopt;
   if (batch_dim_lhs.has_value()) {
     batch_size = lhs_ty.getShape()[batch_dim_lhs.value()];
@@ -1142,30 +1135,26 @@ LogicalResult MatmulOp::verify() {
         "Illegal: output dim order must have an even number of elements.");
     return failure();
   }
-  if (batch_size.has_value()) {
-    if (output_dim_order[0] != 0 || output_dim_order[1] != 0) {
-      emitOpError(
-          "Not implemented: Output with batch size must be the lhs 0 idx for "
-          "now.");
-      return failure();
-    }
-  }
 
-  // Invariants above enforce a single batch idx for now, and that it is in
-  // position 0. Future extensions to this will be to:
-  // 1. Support multiple batch dims
-  // 2. Support batch dims in any position in the output dim order
+  // Invariants above enforce a single batch idx for now. Future extension to
+  // this will be to support multiple batch dims.
 
-  // Verify that the output dim order is always in the form of [0, batch_dims,
-  // 0, lhs_non_contracting_dims, 1, rhs_non_contracting_dims].
+  // Verify that the output dim order is always in the form of [0,
+  // lhs_batch_dims, 1, rhs_batch_dims, 0, lhs_non_contracting_dims, 1,
+  // rhs_non_contracting_dims].
   llvm::SmallVector<int64_t> expected_output_dim_order;
-  expected_output_dim_order.reserve(2 * (lhs_batch_dims.size() +
-                                         lhs_non_contracting_dims.size() +
-                                         rhs_non_contracting_dims.size()));
+  expected_output_dim_order.reserve(
+      2 *
+      (lhs_batch_dims.size() + rhs_batch_dims.size() +
+       lhs_non_contracting_dims.size() + rhs_non_contracting_dims.size()));
   for (int64_t dim : lhs_batch_dims) {
     expected_output_dim_order.push_back(0);
     expected_output_dim_order.push_back(dim);
   }
+  for (int64_t dim : rhs_batch_dims) {
+    expected_output_dim_order.push_back(1);
+    expected_output_dim_order.push_back(dim);
+  }
   for (int64_t dim : lhs_non_contracting_dims) {
     expected_output_dim_order.push_back(0);
     expected_output_dim_order.push_back(dim);
@@ -1177,7 +1166,7 @@ LogicalResult MatmulOp::verify() {
   if (!absl::c_equal(output_dim_order, expected_output_dim_order)) {
     emitOpError(
         "Illegal: output dim order must be in the form of [0, "
-        "batch_dims, 0, lhs_non_contracting_dims, 1, "
+        "lhs_batch_dims, 1, rhs_batch_dims, 0, lhs_non_contracting_dims, 1, "
        "rhs_non_contracting_dims]");
    return failure();
  }
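To make the verifier's expected form concrete, this is the output dim order it would accept for the first new test case (lhs (3, 4, 128), rhs (4, 2, 128), batch dims (1, 0), contracting dims (2, 2)). The shape of the pairs comes from the comment above; the concrete numbers are just a worked example.

# [0, lhs_batch_dims, 1, rhs_batch_dims, 0, lhs_non_contracting_dims,
#  1, rhs_non_contracting_dims], written as flat (operand, dim) pairs:
expected_output_dim_order = [
    0, 1,  # lhs batch dim 1 (size 4)
    1, 0,  # rhs batch dim 0 (size 4)
    0, 0,  # lhs non-contracting dim 0 (size 3)
    1, 1,  # rhs non-contracting dim 1 (size 2)
]
# The matching dot_general output has shape (4, 3, 2).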

jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 28 additions & 15 deletions
@@ -132,24 +132,24 @@ class CanonicalBuilder : public ImplicitLocOpBuilder {
   Operation *op_;
 };
 
-// Ensures both lhs and rhs have contiguous non-contracting and contracting
-// dimensions by inserting transposes if needed. Returns lhs, rhs, and new
-// dimension numbers if a transpose was inserted, otherwise returns
-// std::nullopt.
+// Ensures both lhs and rhs are in form of [batch_dims, non_contracting_dims,
+// contracting_dims] or [batch_dims, contracting_dims, non_contracting_dims] by
+// inserting transposes if needed. Returns lhs, rhs, and new dimension numbers
+// if a transpose was inserted, otherwise returns std::nullopt.
 std::optional<std::tuple<TypedValue<VectorType>, TypedValue<VectorType>,
                          DotDimensionNumbersAttr>>
 ensure_matmul_contiguous_dims(
     CanonicalBuilder& builder, TypedValue<VectorType> lhs,
     TypedValue<VectorType> rhs,
     const DotDimensionNumbersAttr& dimension_numbers) {
-  // Returns a tuple of [new_operand, new_non_contracting_dims,
+  // Returns a tuple of [new_operand, new_batch_dims, new_non_contracting_dims,
   // new_contracting_dims]. new_operand is nullptr if no transpose is inserted.
   auto maybe_insert_transpose =
       [&](TypedValue<VectorType> operand, ArrayRef<int64_t> batch_dims,
           ArrayRef<int64_t> non_contracting_dims,
           ArrayRef<int64_t> contracting_dims, bool is_lhs)
       -> std::tuple<TypedValue<VectorType>, SmallVector<int64_t>,
-                    SmallVector<int64_t>> {
+                    SmallVector<int64_t>, SmallVector<int64_t>> {
     VectorType vty = operand.getType();
     auto shape = vty.getShape();
     auto rank = shape.size();
@@ -170,7 +170,8 @@ ensure_matmul_contiguous_dims(
                     contracting_dims.end());
     // Already in [B..., NC..., C...].
     if (is_identity(perm_BNC)) {
-      return {nullptr, llvm::to_vector(non_contracting_dims),
+      return {nullptr, llvm::to_vector(batch_dims),
+              llvm::to_vector(non_contracting_dims),
              llvm::to_vector(contracting_dims)};
    }
 
@@ -183,7 +184,8 @@ ensure_matmul_contiguous_dims(
                     non_contracting_dims.end());
     // Already in [B..., C..., NC...].
     if (is_identity(perm_BCN)) {
-      return {nullptr, llvm::to_vector(non_contracting_dims),
+      return {nullptr, llvm::to_vector(batch_dims),
+              llvm::to_vector(non_contracting_dims),
              llvm::to_vector(contracting_dims)};
    }
 
@@ -246,18 +248,21 @@ ensure_matmul_contiguous_dims(
     };
 
     // Map the dimension indices to the new dimension order.
+    SmallVector<int64_t> new_b = map_dims(batch_dims);
     SmallVector<int64_t> new_c = map_dims(contracting_dims);
     SmallVector<int64_t> new_nc = map_dims(non_contracting_dims);
 
-    return {new_operand, new_nc, new_c};
+    return {new_operand, new_b, new_nc, new_c};
   };
 
-  auto [new_lhs, new_lhs_non_contracting_dims, new_lhs_contracting_dims] =
+  auto [new_lhs, new_lhs_batch_dims, new_lhs_non_contracting_dims,
+        new_lhs_contracting_dims] =
       maybe_insert_transpose(lhs, dimension_numbers.getLhsBatchDims(),
                              dimension_numbers.getLhsNonContractingDims(),
                              dimension_numbers.getLhsContractingDims(),
                              /*is_lhs=*/true);
-  auto [new_rhs, new_rhs_non_contracting_dims, new_rhs_contracting_dims] =
+  auto [new_rhs, new_rhs_batch_dims, new_rhs_non_contracting_dims,
+        new_rhs_contracting_dims] =
      maybe_insert_transpose(rhs, dimension_numbers.getRhsBatchDims(),
                             dimension_numbers.getRhsNonContractingDims(),
                             dimension_numbers.getRhsContractingDims(),
@@ -267,13 +272,18 @@ ensure_matmul_contiguous_dims(
   }
 
   SmallVector<int64_t> new_output_dim_order;
-  new_output_dim_order.reserve(2 * (dimension_numbers.getLhsBatchDims().size() +
+  new_output_dim_order.reserve(2 * (new_lhs_batch_dims.size() +
+                                    new_rhs_batch_dims.size() +
                                     new_lhs_non_contracting_dims.size() +
                                     new_rhs_non_contracting_dims.size()));
-  for (int64_t batch_dim : dimension_numbers.getLhsBatchDims()) {
+  for (int64_t batch_dim : new_lhs_batch_dims) {
     new_output_dim_order.push_back(0);
     new_output_dim_order.push_back(batch_dim);
   }
+  for (int64_t batch_dim : new_rhs_batch_dims) {
+    new_output_dim_order.push_back(1);
+    new_output_dim_order.push_back(batch_dim);
+  }
   for (int64_t non_contracting_dim : new_lhs_non_contracting_dims) {
     new_output_dim_order.push_back(0);
     new_output_dim_order.push_back(non_contracting_dim);
@@ -286,8 +296,7 @@ ensure_matmul_contiguous_dims(
   DotDimensionNumbersAttr new_dimension_numbers = DotDimensionNumbersAttr::get(
       builder.getContext(), new_lhs_contracting_dims, new_rhs_contracting_dims,
       new_lhs_non_contracting_dims, new_rhs_non_contracting_dims,
-      new_output_dim_order, dimension_numbers.getLhsBatchDims(),
-      dimension_numbers.getRhsBatchDims());
+      new_output_dim_order, new_lhs_batch_dims, new_rhs_batch_dims);
 
   return std::make_tuple(new_lhs ? new_lhs : lhs, new_rhs ? new_rhs : rhs,
                          new_dimension_numbers);
@@ -419,6 +428,10 @@ collapse_matmul_non_contracting_dims(
     new_output_dim_order.push_back(batch_dim);
     new_acc_shape.push_back(lhs.getType().getDimSize(batch_dim));
   }
+  for (int64_t batch_dim : dimension_numbers.getRhsBatchDims()) {
+    new_output_dim_order.push_back(1);
+    new_output_dim_order.push_back(batch_dim);
+  }
   for (int64_t non_contracting_dim : new_lhs_non_contracting_dims) {
     new_output_dim_order.push_back(0);
     new_output_dim_order.push_back(non_contracting_dim);
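The canonicalization above only inserts a transpose when an operand's dimensions are not already grouped as [batch..., non-contracting..., contracting...] (or with the last two groups swapped), and then remaps the dim indices into the transposed order. A rough NumPy sketch of that grouping step, ignoring the MLIR plumbing and the is_lhs-dependent choice between the two target layouts, might look like this:

import numpy as np

def group_dims(operand, batch, nc, c):
  """Transposes operand into [batch..., nc..., c...] order and returns it
  together with the remapped dim indices (a simplified stand-in for
  maybe_insert_transpose)."""
  perm = list(batch) + list(nc) + list(c)
  if perm == list(range(operand.ndim)):
    # Already contiguous; no transpose needed.
    return operand, list(batch), list(nc), list(c)
  new_pos = {old: new for new, old in enumerate(perm)}
  transposed = np.transpose(operand, perm)
  return (transposed,
          [new_pos[d] for d in batch],
          [new_pos[d] for d in nc],
          [new_pos[d] for d in c])

# rhs of the last new test case: shape (2, 128, 32, 2) with batch dim 3,
# non-contracting dims (0, 1) and contracting dim 2.
rhs = np.zeros((2, 128, 32, 2))
new_rhs, new_b, new_nc, new_c = group_dims(rhs, [3], [0, 1], [2])
assert new_rhs.shape == (2, 2, 128, 32)
assert (new_b, new_nc, new_c) == ([0], [1, 2], [3])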

tests/pallas/ops_test.py

Lines changed: 39 additions & 0 deletions
@@ -1498,6 +1498,45 @@ def kernel(x_ref, y_ref, out_ref):
         expected,
     )
 
+  @parameterized.product(
+      shapes_and_dims_numbers=(
+          ((3, 4, 128), (4, 2, 128), (((2,), (2,)), ((1,), (0,)))),
+          ((3, 4, 128), (2, 4, 128), (((2,), (2,)), ((1,), (1,)))),
+          ((3, 4, 256), (2, 3, 256), (((2,), (2,)), ((0,), (1,)))),
+          ((4, 3, 2, 32), (2, 128, 32, 2), (((3,), (2,)), ((2,), (3,)))),
+      ),
+  )
+  def test_dot_general_non_front_batch_dims(self, shapes_and_dims_numbers):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("TPU only test")
+
+    if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least(
+        2025, 11, 18
+    ):
+      self.skipTest("Requires libtpu built after 2025-11-18")
+
+    x_shape, y_shape, dims_numbers = shapes_and_dims_numbers
+
+    k1, k2 = random.split(jax.random.key(0))
+    x = jax.random.normal(k1, x_shape, dtype=jnp.float32)
+    y = jax.random.normal(k2, y_shape, dtype=jnp.float32)
+
+    # Just infer shape from jax.
+    expected = jax.lax.dot_general(x, y, dimension_numbers=dims_numbers)
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct(expected.shape, jnp.float32),
+    )
+    def kernel(x_ref, y_ref, out_ref):
+      out_ref[...] = jax.lax.dot_general(
+          x_ref[...],
+          y_ref[...],
+          dimension_numbers=dims_numbers,
+      )
+
+    np.testing.assert_allclose(kernel(x, y), expected, atol=1e-5, rtol=1e-5)
+
   @parameterized.product(
       batch_size=(None, 1, 2),
       # dims_numbers is without batch dims
