Skip to content

Commit dd9243a

Browse files
committed
[MLIR][Linalg] Remove matmul_transpose variants
Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replace it with `matmul affine_maps [...]` whenever appropriate. Added specialization for transpose variants: No code changes needed, all handled by the new specialization classes. Fixing a bug in the previous matmul builder where the affine map was being ignored. Checking that the affine map is the correct one for the transpose variant to make sure classof works as expected. Also exposing a method to check the affine map in case one needs a check for `matmul && isTransposeA` to replace the old checks for `matmul_transpose_a` and others. Adding new create methods following #147168 Common API for all matmul variants Moved isExpectedAffineMaps -> isDefaultIndexingMaps Moved getAffineMaps -> getDefaultIndexingMaps Added isDefaultIndexingMaps to the core ops (ODS)
1 parent a9dacb1 commit dd9243a

File tree

23 files changed

+689
-902
lines changed

23 files changed

+689
-902
lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 190 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
145145
#define GET_OP_CLASSES
146146
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.h.inc"
147147

148-
namespace mlir {
149-
namespace linalg {
148+
namespace mlir::linalg {
150149

151150
/// Returns the outer shape in the packed domain before applying the
152151
/// transposition.
@@ -155,7 +154,194 @@ template <typename OpTy,
155154
std::is_same_v<OpTy, linalg::UnPackOp>>>
156155
SmallVector<int64_t> getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack);
157156

158-
} // namespace linalg
159-
} // namespace mlir
157+
/// Specialization of `linalg.matmul` op that has a transpose map on A
158+
class MatmulTransposeAOp : public MatmulOp {
159+
/// Create an affine map for a transpose-A matmul. Used only in the builders.
160+
static SmallVector<AffineMap> getDefaultIndexingMaps(OpBuilder &builder);
161+
162+
public:
163+
using MatmulOp::MatmulOp;
164+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
165+
166+
/// Build a transpose A matmul.
167+
static void build(OpBuilder &builder, OperationState &result,
168+
ValueRange inputs, ValueRange outputs,
169+
ArrayRef<NamedAttribute> attributes = {});
170+
171+
static MatmulTransposeAOp create(OpBuilder &builder, Location location,
172+
ValueRange inputs, ValueRange outputs,
173+
ArrayRef<NamedAttribute> attributes = {});
174+
175+
/// Build a transpose A matmul with a specific result type.
176+
static void build(OpBuilder &builder, OperationState &result,
177+
TypeRange resultTensorTypes, ValueRange inputs,
178+
ValueRange outputs,
179+
ArrayRef<NamedAttribute> attributes = {});
180+
181+
static MatmulTransposeAOp create(OpBuilder &builder, Location location,
182+
TypeRange resultTensorTypes,
183+
ValueRange inputs, ValueRange outputs,
184+
ArrayRef<NamedAttribute> attributes = {});
185+
186+
/// Build a transpose A matmul with a specific result type and a cast type.
187+
static void build(OpBuilder &builder, OperationState &result,
188+
TypeRange resultTensorTypes, ValueRange inputs,
189+
ValueRange outputs, Attribute cast,
190+
ArrayRef<NamedAttribute> attributes = {});
191+
192+
static MatmulTransposeAOp create(OpBuilder &builder, Location location,
193+
TypeRange resultTensorTypes,
194+
ValueRange inputs, ValueRange outputs,
195+
Attribute cast,
196+
ArrayRef<NamedAttribute> attributes = {});
197+
198+
/// Checks if the affine map is the expected one for this operation
199+
static bool isDefaultIndexingMaps(Attribute attr);
200+
201+
static bool classof(Operation *op);
202+
};
203+
204+
/// Specialization of `linalg.matmul` op that has a transpose map on B
205+
class MatmulTransposeBOp : public MatmulOp {
206+
/// Create an affine map for a transpose-B matmul. Used only in the builders.
207+
static SmallVector<AffineMap> getDefaultIndexingMaps(OpBuilder &builder);
208+
209+
public:
210+
using MatmulOp::MatmulOp;
211+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<MatmulOp>(); }
212+
213+
/// Build a transpose B matmul.
214+
static void build(OpBuilder &builder, OperationState &result,
215+
ValueRange inputs, ValueRange outputs,
216+
ArrayRef<NamedAttribute> attributes = {});
217+
218+
static MatmulTransposeBOp create(OpBuilder &builder, Location location,
219+
ValueRange inputs, ValueRange outputs,
220+
ArrayRef<NamedAttribute> attributes = {});
221+
222+
/// Build a transpose B matmul with a specific result type.
223+
static void build(OpBuilder &builder, OperationState &result,
224+
TypeRange resultTensorTypes, ValueRange inputs,
225+
ValueRange outputs,
226+
ArrayRef<NamedAttribute> attributes = {});
227+
228+
static MatmulTransposeBOp create(OpBuilder &builder, Location location,
229+
TypeRange resultTensorTypes,
230+
ValueRange inputs, ValueRange outputs,
231+
ArrayRef<NamedAttribute> attributes = {});
232+
233+
/// Build a transpose B matmul with a specific result type and a cast type.
234+
static void build(OpBuilder &builder, OperationState &result,
235+
TypeRange resultTensorTypes, ValueRange inputs,
236+
ValueRange outputs, Attribute cast,
237+
ArrayRef<NamedAttribute> attributes = {});
238+
239+
static MatmulTransposeBOp create(OpBuilder &builder, Location location,
240+
TypeRange resultTensorTypes,
241+
ValueRange inputs, ValueRange outputs,
242+
Attribute cast,
243+
ArrayRef<NamedAttribute> attributes = {});
244+
245+
/// Checks if the affine map is the expected one for this operation
246+
static bool isDefaultIndexingMaps(Attribute attr);
247+
248+
static bool classof(Operation *op);
249+
};
250+
251+
/// Specialization of `linalg.batch_matmul` op that has a transpose map on A
252+
class BatchMatmulTransposeAOp : public BatchMatmulOp {
253+
/// Create an affine map for a transpose-A batch_matmul. Used only in the
254+
/// builders.
255+
static SmallVector<AffineMap> getDefaultIndexingMaps(OpBuilder &builder);
256+
257+
public:
258+
using BatchMatmulOp::BatchMatmulOp;
259+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
260+
261+
/// Build a transpose A matmul.
262+
static void build(OpBuilder &builder, OperationState &result,
263+
ValueRange inputs, ValueRange outputs,
264+
ArrayRef<NamedAttribute> attributes = {});
265+
266+
static BatchMatmulTransposeAOp
267+
create(OpBuilder &builder, Location location, ValueRange inputs,
268+
ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
269+
270+
/// Build a transpose A matmul with a specific result type.
271+
static void build(OpBuilder &builder, OperationState &result,
272+
TypeRange resultTensorTypes, ValueRange inputs,
273+
ValueRange outputs,
274+
ArrayRef<NamedAttribute> attributes = {});
275+
276+
static BatchMatmulTransposeAOp
277+
create(OpBuilder &builder, Location location, TypeRange resultTensorTypes,
278+
ValueRange inputs, ValueRange outputs,
279+
ArrayRef<NamedAttribute> attributes = {});
280+
281+
/// Build a transpose A matmul with a specific result type and a cast type.
282+
static void build(OpBuilder &builder, OperationState &result,
283+
TypeRange resultTensorTypes, ValueRange inputs,
284+
ValueRange outputs, Attribute cast,
285+
ArrayRef<NamedAttribute> attributes = {});
286+
287+
static BatchMatmulTransposeAOp
288+
create(OpBuilder &builder, Location location, TypeRange resultTensorTypes,
289+
ValueRange inputs, ValueRange outputs, Attribute cast,
290+
ArrayRef<NamedAttribute> attributes = {});
291+
292+
/// Checks if the affine map is the expected one for this operation
293+
static bool isDefaultIndexingMaps(Attribute attr);
294+
295+
static bool classof(Operation *op);
296+
};
297+
298+
/// Specialization of `linalg.batch_matmul` op that has a transpose map on B
299+
class BatchMatmulTransposeBOp : public BatchMatmulOp {
300+
/// Create an affine map for a transpose-B batch_matmul. Used only in the
301+
/// builders.
302+
static SmallVector<AffineMap> getDefaultIndexingMaps(OpBuilder &builder);
303+
304+
public:
305+
using BatchMatmulOp::BatchMatmulOp;
306+
static ::mlir::TypeID resolveTypeID() { return TypeID::get<BatchMatmulOp>(); }
307+
308+
/// Build a transpose B matmul.
309+
static void build(OpBuilder &builder, OperationState &result,
310+
ValueRange inputs, ValueRange outputs,
311+
ArrayRef<NamedAttribute> attributes = {});
312+
313+
static BatchMatmulTransposeBOp
314+
create(OpBuilder &builder, Location location, ValueRange inputs,
315+
ValueRange outputs, ArrayRef<NamedAttribute> attributes = {});
316+
317+
/// Build a transpose B matmul with a specific result type.
318+
static void build(OpBuilder &builder, OperationState &result,
319+
TypeRange resultTensorTypes, ValueRange inputs,
320+
ValueRange outputs,
321+
ArrayRef<NamedAttribute> attributes = {});
322+
323+
static BatchMatmulTransposeBOp
324+
create(OpBuilder &builder, Location location, TypeRange resultTensorTypes,
325+
ValueRange inputs, ValueRange outputs,
326+
ArrayRef<NamedAttribute> attributes = {});
327+
328+
/// Build a transpose B matmul with a specific result type and a cast type.
329+
static void build(OpBuilder &builder, OperationState &result,
330+
TypeRange resultTensorTypes, ValueRange inputs,
331+
ValueRange outputs, Attribute cast,
332+
ArrayRef<NamedAttribute> attributes = {});
333+
334+
static BatchMatmulTransposeBOp
335+
create(OpBuilder &builder, Location location, TypeRange resultTensorTypes,
336+
ValueRange inputs, ValueRange outputs, Attribute cast,
337+
ArrayRef<NamedAttribute> attributes = {});
338+
339+
/// Checks if the affine map is the expected one for this operation
340+
static bool isDefaultIndexingMaps(Attribute attr);
341+
342+
static bool classof(Operation *op);
343+
};
344+
345+
} // namespace mlir::linalg
160346

161347
#endif // MLIR_DIALECT_LINALG_IR_LINALG_H

0 commit comments

Comments
 (0)