 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+/// This struct contains information about extract_slice dims.
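+/// For example (values assumed for illustration): a slice of size 32 taken at
+/// offset 16 from a dimension of size 128 is recorded as
+/// SliceDimInfo{/*offset=*/16, /*sliceSize=*/32, /*outputSize=*/128}.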
+struct SliceDimInfo {
+  OpFoldResult offset;
+  OpFoldResult sliceSize;
+  OpFoldResult outputSize;
+};
+
+/// Return the first input extract_slice operand, if present, for the current
+/// generic op.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  OpOperand *sliceOperand = nullptr;
+  unsigned operandIndex;
+  for (auto [idx, operand] :
+       llvm::enumerate(genericOp.getDpsInputOperands())) {
+    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      continue;
+    sliceOperand = operand;
+    operandIndex = idx;
+    break;
+  }
+  if (!sliceOperand) {
+    return failure();
+  }
+  return std::make_tuple(sliceOperand, operandIndex);
+}
+
+/// Return a map of dims that have non-full slices on them so that other
+/// operands can use this information. Also return a bool indicating whether a
+/// reduction dim has a non-full slice, as that can be used to fold the
+/// original extract_slice.
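+/// For instance (shapes assumed): with indexing map (d0, d1) -> (d0, d1) and
+/// a slice [16, 0] [32, 64] [1, 1] of a tensor<128x64xf32> source, only d0
+/// gets an entry, {16, 32, 128}; d1 is a full slice. With both iterators
+/// parallel, the returned bool is false.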
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+  bool hasNonZeroReductionDimSlice = false;
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t sz) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+      });
+
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    // Full slices (offset 0, size equal to the source dim) need no handling.
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    // Only plain dim accesses on the sliced operand are supported.
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+    if (iterators[dimPos] == utils::IteratorType::reduction) {
+      hasNonZeroReductionDimSlice = true;
+    }
+  }
+  // Next, check if the dims with non-full slice info are used in non
+  // AffineDimExpr accesses by other operands; if they are, bail out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (&operand == sliceOperand) {
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(indexingMap.getResults(), [&](AffineExpr expr) {
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          return status.wasInterrupted();
+        })) {
+      return failure();
+    }
+  }
+  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+}
+
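+/// Push an input tensor.extract_slice producer down through a linalg.generic:
+/// the generic is rewritten to operate on the unsliced source, every other
+/// operand is padded with ub.poison to the matching full size, and the slice
+/// is re-applied to the result. A sketch on assumed 1-D IR (illustrative, not
+/// from an actual test):
+///
+///   %s = tensor.extract_slice %src[%off] [%sz] [1]
+///       : tensor<128xf32> to tensor<?xf32>
+///   %r = linalg.generic ... ins(%s) outs(%init : tensor<?xf32>)
+///
+/// becomes
+///
+///   %ip = tensor.pad %init low[%off] high[...] { ub.poison }
+///   %g = linalg.generic ... ins(%src) outs(%ip : tensor<128xf32>)
+///   %r = tensor.extract_slice %g[%off] [%sz] [1]
+///       : tensor<128xf32> to tensor<?xf32>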
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Collect the sliced operand, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned operandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  // Check if we can support the propagation of this extract_slice through
+  // the generic op and, if so, get the dimensions that carry non-full
+  // slices.
+  auto maybeNonZeroSliceDimMap =
+      getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+  if (failed(maybeNonZeroSliceDimMap)) {
+    return failure();
+  }
+
+  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  // Helper computing `v1 - v2` on OpFoldResults; pads are derived as
+  // lowPad = offset and highPad = (outputSize - offset) - sliceSize.
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+
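+  // Worked example of the pad computation below (numbers assumed): for a dim
+  // with outputSize = 128, offset = 16, sliceSize = 32, a padded operand gets
+  // lowPad = 16 and highPad = (128 - 16) - 32 = 80 on that dim.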
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] :
+       llvm::enumerate(genericOp.getDpsInputOperands())) {
+    // The sliced operand itself is replaced by the unsliced source, unless a
+    // reduction dim is sliced, in which case it is padded like the others.
+    if (idx == operandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(indexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(indexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(indexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+        SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+        operandLowPads[idx] = sliceDimInfo.offset;
+        operandHighPads[idx] =
+            sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+                sliceDimInfo.sliceSize);
+      }
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads,
+        operandHighPads, paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  SmallVector<OpFoldResult> outputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = outputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+    if (!isa<AffineDimExpr>(expr)) {
+      continue;
+    }
+    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+    if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+      outputLowPads[idx] = sliceDimInfo.offset;
+      outputHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
+      outputShape[idx] = sliceDimInfo.outputSize;
+      newSizes[idx] = sliceDimInfo.sliceSize;
+    }
+  }
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, outputShape, outputElType);
+  } else {
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
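+/// Pattern wrapper around pushDownExtractSliceOpThroughGenericOp above; it
+/// replaces the matched generic op with the new extract_slice result when the
+/// control function allows the propagation.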
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1512,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
       PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
       patterns.getContext(), controlPackUnPackPropagation);
 }
+
+void mlir::linalg::populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+}
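+
+// Usage sketch (illustrative; assumes a greedy rewrite driver and that every
+// propagation is allowed):
+//   RewritePatternSet patterns(ctx);
+//   linalg::populateExtractSlicePropagationPatterns(
+//       patterns, [](OpOperand *) { return true; });
+//   (void)applyPatternsGreedily(rootOp, std::move(patterns));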