25 changes: 4 additions & 21 deletions jaxlib/mosaic/dialect/tpu/tpu_ops.cc
@@ -1124,15 +1124,8 @@ LogicalResult MatmulOp::verify() {
const std::optional<int64_t> batch_dim_rhs =
rhs_batch_dims.empty() ? std::nullopt
: std::optional<int64_t>(rhs_batch_dims[0]);
if (batch_dim_lhs != batch_dim_rhs) {
emitOpError("Not Implemented: batch dims must be equal");
return failure();
}
if (batch_dim_lhs.has_value() && (batch_dim_lhs.value() != 0)) {
emitOpError("Not Implemented: batch dims pos must be 0");
return failure();
}
// Invariant above enforces only 1 batch dim atm, and that both are eq

// Invariant above enforces only 1 batch dim atm.
std::optional<int64_t> batch_size = std::nullopt;
if (batch_dim_lhs.has_value()) {
batch_size = lhs_ty.getShape()[batch_dim_lhs.value()];
@@ -1152,19 +1145,9 @@ LogicalResult MatmulOp::verify() {
"Illegal: output dim order must have an even number of elements.");
return failure();
}
if (batch_size.has_value()) {
if (output_dim_order[0] != 0 || output_dim_order[1] != 0) {
emitOpError(
"Not implemented: Output with batch size must be the lhs 0 idx for "
"now.");
return failure();
}
}

// Invariants above enforce a single batch idx for now, and that it is in
// position 0. Future extensions to this will be to:
// 1. Support multiple batch dims
// 2. Support batch dims in any position in the output dim order
// Invariants above enforce a single batch idx for now. Future extension to
// this will be to support multiple batch dims.

// Verify that the output dim order is always in the form of [0, batch_dims,
// 0, lhs_non_contracting_dims, 1, rhs_non_contracting_dims].
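For context, the output dim order that the verifier checks (and that the canonicalization pass below rebuilds) is a flat list of interleaved (operand index, dim index) pairs, where operand 0 is the lhs and 1 is the rhs. A minimal Python sketch of that encoding, following the format described in the verifier comment; the helper name is purely illustrative:

```python
def build_output_dim_order(lhs_batch_dims, lhs_non_contracting_dims,
                           rhs_non_contracting_dims):
  """Builds [0, b, ..., 0, lhs_nc, ..., 1, rhs_nc, ...] as interleaved
  (operand index, dim index) pairs; batch dims are taken from the lhs."""
  order = []
  for d in lhs_batch_dims:
    order += [0, d]
  for d in lhs_non_contracting_dims:
    order += [0, d]
  for d in rhs_non_contracting_dims:
    order += [1, d]
  return order

# A single batch dim at lhs position 1, lhs non-contracting dim 0,
# rhs non-contracting dim 1 encodes as:
print(build_output_dim_order([1], [0], [1]))  # [0, 1, 0, 0, 1, 1]
```

With the checks removed above, the batch dim no longer has to sit at position 0 of either operand; the verifier still requires at most one batch dim.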
34 changes: 19 additions & 15 deletions jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc
@@ -132,24 +132,24 @@ class CanonicalBuilder : public ImplicitLocOpBuilder {
Operation *op_;
};

// Ensures both lhs and rhs have contiguous non-contracting and contracting
// dimensions by inserting transposes if needed. Returns lhs, rhs, and new
// dimension numbers if a transpose was inserted, otherwise returns
// std::nullopt.
// Ensures both lhs and rhs are in form of [batch_dims, non_contracting_dims,
// contracting_dims] or [batch_dims, contracting_dims, non_contracting_dims] by
// inserting transposes if needed. Returns lhs, rhs, and new dimension numbers
// if a transpose was inserted, otherwise returns std::nullopt.
std::optional<std::tuple<TypedValue<VectorType>, TypedValue<VectorType>,
DotDimensionNumbersAttr>>
ensure_matmul_contiguous_dims(
CanonicalBuilder& builder, TypedValue<VectorType> lhs,
TypedValue<VectorType> rhs,
const DotDimensionNumbersAttr& dimension_numbers) {
// Returns a tuple of [new_operand, new_non_contracting_dims,
// Returns a tuple of [new_operand, new_batch_dims, new_non_contracting_dims,
// new_contracting_dims]. new_operand is nullptr if no transpose is inserted.
auto maybe_insert_transpose =
[&](TypedValue<VectorType> operand, ArrayRef<int64_t> batch_dims,
ArrayRef<int64_t> non_contracting_dims,
ArrayRef<int64_t> contracting_dims, bool is_lhs)
-> std::tuple<TypedValue<VectorType>, SmallVector<int64_t>,
SmallVector<int64_t>> {
SmallVector<int64_t>, SmallVector<int64_t>> {
VectorType vty = operand.getType();
auto shape = vty.getShape();
auto rank = shape.size();
@@ -170,7 +170,8 @@ ensure_matmul_contiguous_dims(
contracting_dims.end());
// Already in [B..., NC..., C...].
if (is_identity(perm_BNC)) {
return {nullptr, llvm::to_vector(non_contracting_dims),
return {nullptr, llvm::to_vector(batch_dims),
llvm::to_vector(non_contracting_dims),
llvm::to_vector(contracting_dims)};
}

@@ -183,7 +184,8 @@ ensure_matmul_contiguous_dims(
non_contracting_dims.end());
// Already in [B..., C..., NC...].
if (is_identity(perm_BCN)) {
return {nullptr, llvm::to_vector(non_contracting_dims),
return {nullptr, llvm::to_vector(batch_dims),
llvm::to_vector(non_contracting_dims),
llvm::to_vector(contracting_dims)};
}

@@ -246,18 +248,21 @@ ensure_matmul_contiguous_dims(
};

// Map the dimension indices to the new dimension order.
SmallVector<int64_t> new_b = map_dims(batch_dims);
SmallVector<int64_t> new_c = map_dims(contracting_dims);
SmallVector<int64_t> new_nc = map_dims(non_contracting_dims);

return {new_operand, new_nc, new_c};
return {new_operand, new_b, new_nc, new_c};
};

auto [new_lhs, new_lhs_non_contracting_dims, new_lhs_contracting_dims] =
auto [new_lhs, new_lhs_batch_dims, new_lhs_non_contracting_dims,
new_lhs_contracting_dims] =
maybe_insert_transpose(lhs, dimension_numbers.getLhsBatchDims(),
dimension_numbers.getLhsNonContractingDims(),
dimension_numbers.getLhsContractingDims(),
/*is_lhs=*/true);
auto [new_rhs, new_rhs_non_contracting_dims, new_rhs_contracting_dims] =
auto [new_rhs, new_rhs_batch_dims, new_rhs_non_contracting_dims,
new_rhs_contracting_dims] =
maybe_insert_transpose(rhs, dimension_numbers.getRhsBatchDims(),
dimension_numbers.getRhsNonContractingDims(),
dimension_numbers.getRhsContractingDims(),
@@ -267,10 +272,10 @@
}

SmallVector<int64_t> new_output_dim_order;
new_output_dim_order.reserve(2 * (dimension_numbers.getLhsBatchDims().size() +
new_output_dim_order.reserve(2 * (new_lhs_batch_dims.size() +
new_lhs_non_contracting_dims.size() +
new_rhs_non_contracting_dims.size()));
for (int64_t batch_dim : dimension_numbers.getLhsBatchDims()) {
for (int64_t batch_dim : new_lhs_batch_dims) {
new_output_dim_order.push_back(0);
new_output_dim_order.push_back(batch_dim);
}
@@ -286,8 +291,7 @@ ensure_matmul_contiguous_dims(
DotDimensionNumbersAttr new_dimension_numbers = DotDimensionNumbersAttr::get(
builder.getContext(), new_lhs_contracting_dims, new_rhs_contracting_dims,
new_lhs_non_contracting_dims, new_rhs_non_contracting_dims,
new_output_dim_order, dimension_numbers.getLhsBatchDims(),
dimension_numbers.getRhsBatchDims());
new_output_dim_order, new_lhs_batch_dims, new_rhs_batch_dims);

return std::make_tuple(new_lhs ? new_lhs : lhs, new_rhs ? new_rhs : rhs,
new_dimension_numbers);
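The shape of this change is easier to see outside MLIR: the pass classifies each operand dim as batch, non-contracting, or contracting, transposes the operand into a contiguous layout when it is not already in one of the two accepted orders, and then remaps every dim index into the new order, which is why the batch dims now have to flow through `map_dims` as well. A minimal JAX sketch of that idea, covering only the [B..., NC..., C...] target (the real pass also accepts [B..., C..., NC...]); the helper name and signature are illustrative only:

```python
import jax.numpy as jnp

def ensure_contiguous(operand, batch_dims, non_contracting_dims,
                      contracting_dims):
  """Transposes `operand` into [batch..., non_contracting..., contracting...]
  if needed and remaps every dim index (including batch dims) into the new
  order. Returns (operand, batch, non_contracting, contracting)."""
  perm = [*batch_dims, *non_contracting_dims, *contracting_dims]
  if perm == list(range(operand.ndim)):
    # Already contiguous; nothing to do.
    return (operand, list(batch_dims), list(non_contracting_dims),
            list(contracting_dims))
  operand = jnp.transpose(operand, perm)
  new_pos = {old: new for new, old in enumerate(perm)}  # old dim -> new dim
  remap = lambda dims: [new_pos[d] for d in dims]
  return (operand, remap(batch_dims), remap(non_contracting_dims),
          remap(contracting_dims))

x = jnp.zeros((3, 4, 128))
# Batch dim 1, non-contracting dim 0, contracting dim 2: the operand is
# transposed to (4, 3, 128) and the batch dim is remapped from 1 to 0.
xt, b, nc, c = ensure_contiguous(x, [1], [0], [2])
print(xt.shape, b, nc, c)  # (4, 3, 128) [0] [1] [2]
```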
39 changes: 39 additions & 0 deletions tests/pallas/ops_test.py
@@ -1498,6 +1498,45 @@ def kernel(x_ref, y_ref, out_ref):
expected,
)

@parameterized.product(
shapes_and_dims_numbers=(
((3, 4, 128), (4, 2, 128), (((2,), (2,)), ((1,), (0,)))),
((3, 4, 128), (2, 4, 128), (((2,), (2,)), ((1,), (1,)))),
((3, 4, 256), (2, 3, 256), (((2,), (2,)), ((0,), (1,)))),
((4, 3, 2, 32), (2, 128, 32, 2), (((3,), (2,)), ((2,), (3,)))),
),
)
def test_dot_general_non_front_batch_dims(self, shapes_and_dims_numbers):
if jtu.test_device_matches(["gpu"]):
self.skipTest("TPU only test")

if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least(
2025, 11, 21
):
self.skipTest("Requires libtpu built after 2025-11-21")

x_shape, y_shape, dims_numbers = shapes_and_dims_numbers

k1, k2 = random.split(jax.random.key(0))
x = jax.random.normal(k1, x_shape, dtype=jnp.float32)
y = jax.random.normal(k2, y_shape, dtype=jnp.float32)

# Just infer shape from jax.
expected = jax.lax.dot_general(x, y, dimension_numbers=dims_numbers)

@functools.partial(
self.pallas_call,
out_shape=jax.ShapeDtypeStruct(expected.shape, jnp.float32),
)
def kernel(x_ref, y_ref, out_ref):
out_ref[...] = jax.lax.dot_general(
x_ref[...],
y_ref[...],
dimension_numbers=dims_numbers,
)

np.testing.assert_allclose(kernel(x, y), expected, atol=1e-5, rtol=1e-5)

@parameterized.product(
batch_size=(None, 1, 2),
# dims_numbers is without batch dims
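As a quick sanity check on what the new test exercises: in its first parameterization the batch dim sits at position 1 of the lhs and position 0 of the rhs (neither operand has it in front), and standard `dot_general` semantics still put it first in the output. A small snippet, runnable outside Pallas, that confirms the expected output shape:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((3, 4, 128), jnp.float32)
y = jnp.zeros((4, 2, 128), jnp.float32)
# Contracting dims: lhs 2 / rhs 2; batch dims: lhs 1 / rhs 0.
dims = (((2,), (2,)), ((1,), (0,)))
out = jax.eval_shape(lambda a, b: jax.lax.dot_general(a, b, dims), x, y)
print(out.shape)  # (4, 3, 2): batch, then lhs non-contracting, then rhs
```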