@@ -132,24 +132,24 @@ class CanonicalBuilder : public ImplicitLocOpBuilder {
132132 Operation *op_;
133133};
134134
135- // Ensures both lhs and rhs have contiguous non-contracting and contracting
136- // dimensions by inserting transposes if needed. Returns lhs, rhs, and new
137- // dimension numbers if a transpose was inserted, otherwise returns
138- // std::nullopt.
135+ // Ensures both lhs and rhs are in the form [batch_dims, non_contracting_dims,
136+ // contracting_dims] or [batch_dims, contracting_dims, non_contracting_dims] by
137+ // inserting transposes if needed. Returns lhs, rhs, and new dimension numbers
138+ // if a transpose was inserted, otherwise returns std::nullopt.
139139std::optional<std::tuple<TypedValue<VectorType>, TypedValue<VectorType>,
140140 DotDimensionNumbersAttr>>
141141ensure_matmul_contiguous_dims (
142142 CanonicalBuilder& builder, TypedValue<VectorType> lhs,
143143 TypedValue<VectorType> rhs,
144144 const DotDimensionNumbersAttr& dimension_numbers) {
145- // Returns a tuple of [new_operand, new_non_contracting_dims,
145+ // Returns a tuple of [new_operand, new_batch_dims, new_non_contracting_dims,
146146 // new_contracting_dims]. new_operand is nullptr if no transpose is inserted.
147147 auto maybe_insert_transpose =
148148 [&](TypedValue<VectorType> operand, ArrayRef<int64_t > batch_dims,
149149 ArrayRef<int64_t > non_contracting_dims,
150150 ArrayRef<int64_t > contracting_dims, bool is_lhs)
151151 -> std::tuple<TypedValue<VectorType>, SmallVector<int64_t >,
152- SmallVector<int64_t >> {
152+ SmallVector<int64_t >, SmallVector< int64_t > > {
153153 VectorType vty = operand.getType ();
154154 auto shape = vty.getShape ();
155155 auto rank = shape.size ();
@@ -170,7 +170,8 @@ ensure_matmul_contiguous_dims(
170170 contracting_dims.end ());
171171 // Already in [B..., NC..., C...].
172172 if (is_identity (perm_BNC)) {
173- return {nullptr , llvm::to_vector (non_contracting_dims),
173+ return {nullptr , llvm::to_vector (batch_dims),
174+ llvm::to_vector (non_contracting_dims),
174175 llvm::to_vector (contracting_dims)};
175176 }
176177
@@ -183,7 +184,8 @@ ensure_matmul_contiguous_dims(
183184 non_contracting_dims.end ());
184185 // Already in [B..., C..., NC...].
185186 if (is_identity (perm_BCN)) {
186- return {nullptr , llvm::to_vector (non_contracting_dims),
187+ return {nullptr , llvm::to_vector (batch_dims),
188+ llvm::to_vector (non_contracting_dims),
187189 llvm::to_vector (contracting_dims)};
188190 }
189191
@@ -246,18 +248,21 @@ ensure_matmul_contiguous_dims(
246248 };
247249
248250 // Map the dimension indices to the new dimension order.
251+ SmallVector<int64_t > new_b = map_dims (batch_dims);
249252 SmallVector<int64_t > new_c = map_dims (contracting_dims);
250253 SmallVector<int64_t > new_nc = map_dims (non_contracting_dims);
251254
252- return {new_operand, new_nc, new_c};
255+ return {new_operand, new_b, new_nc, new_c};
253256 };
254257
255- auto [new_lhs, new_lhs_non_contracting_dims, new_lhs_contracting_dims] =
258+ auto [new_lhs, new_lhs_batch_dims, new_lhs_non_contracting_dims,
259+ new_lhs_contracting_dims] =
256260 maybe_insert_transpose (lhs, dimension_numbers.getLhsBatchDims (),
257261 dimension_numbers.getLhsNonContractingDims (),
258262 dimension_numbers.getLhsContractingDims (),
259263 /* is_lhs=*/ true );
260- auto [new_rhs, new_rhs_non_contracting_dims, new_rhs_contracting_dims] =
264+ auto [new_rhs, new_rhs_batch_dims, new_rhs_non_contracting_dims,
265+ new_rhs_contracting_dims] =
261266 maybe_insert_transpose (rhs, dimension_numbers.getRhsBatchDims (),
262267 dimension_numbers.getRhsNonContractingDims (),
263268 dimension_numbers.getRhsContractingDims (),
@@ -267,13 +272,18 @@ ensure_matmul_contiguous_dims(
267272 }
268273
269274 SmallVector<int64_t > new_output_dim_order;
270- new_output_dim_order.reserve (2 * (dimension_numbers.getLhsBatchDims ().size () +
275+ new_output_dim_order.reserve (2 * (new_lhs_batch_dims.size () +
276+ new_rhs_batch_dims.size () +
271277 new_lhs_non_contracting_dims.size () +
272278 new_rhs_non_contracting_dims.size ()));
273- for (int64_t batch_dim : dimension_numbers. getLhsBatchDims () ) {
279+ for (int64_t batch_dim : new_lhs_batch_dims ) {
274280 new_output_dim_order.push_back (0 );
275281 new_output_dim_order.push_back (batch_dim);
276282 }
283+ for (int64_t batch_dim : new_rhs_batch_dims) {
284+ new_output_dim_order.push_back (1 );
285+ new_output_dim_order.push_back (batch_dim);
286+ }
277287 for (int64_t non_contracting_dim : new_lhs_non_contracting_dims) {
278288 new_output_dim_order.push_back (0 );
279289 new_output_dim_order.push_back (non_contracting_dim);
@@ -286,8 +296,7 @@ ensure_matmul_contiguous_dims(
286296 DotDimensionNumbersAttr new_dimension_numbers = DotDimensionNumbersAttr::get (
287297 builder.getContext (), new_lhs_contracting_dims, new_rhs_contracting_dims,
288298 new_lhs_non_contracting_dims, new_rhs_non_contracting_dims,
289- new_output_dim_order, dimension_numbers.getLhsBatchDims (),
290- dimension_numbers.getRhsBatchDims ());
299+ new_output_dim_order, new_lhs_batch_dims, new_rhs_batch_dims);
291300
292301 return std::make_tuple (new_lhs ? new_lhs : lhs, new_rhs ? new_rhs : rhs,
293302 new_dimension_numbers);
@@ -419,6 +428,10 @@ collapse_matmul_non_contracting_dims(
419428 new_output_dim_order.push_back (batch_dim);
420429 new_acc_shape.push_back (lhs.getType ().getDimSize (batch_dim));
421430 }
431+ for (int64_t batch_dim : dimension_numbers.getRhsBatchDims ()) {
432+ new_output_dim_order.push_back (1 );
433+ new_output_dim_order.push_back (batch_dim);
434+ }
422435 for (int64_t non_contracting_dim : new_lhs_non_contracting_dims) {
423436 new_output_dim_order.push_back (0 );
424437 new_output_dim_order.push_back (non_contracting_dim);
0 commit comments