6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
9
10
#include " mlir/Dialect/Linalg/IR/Linalg.h"
10
11
#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
11
12
#include " mlir/Dialect/Linalg/Utils/Utils.h"
12
13
#include " mlir/Dialect/Tensor/IR/Tensor.h"
14
+ #include " mlir/Dialect/UB/IR/UBOps.h"
13
15
#include " mlir/Dialect/Utils/IndexingUtils.h"
14
16
#include " mlir/IR/Dominance.h"
15
17
#include " llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,266 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1236
1238
ControlPropagationFn controlFn;
1237
1239
};
1238
1240
1241
+ // This struct contains infomation about extract_slice dims.
1242
+ struct SliceDimInfo {
1243
+ OpFoldResult offset;
1244
+ OpFoldResult sliceSize;
1245
+ OpFoldResult outputSize;
1246
+ };
1247
+
1248
+ // / Return the first input extract slice operand, if present, for the current
1249
+ // / generic op.
1250
+ static FailureOr<std::tuple<OpOperand *, unsigned >>
1251
+ getSliceOperandAndIndex (GenericOp genericOp) {
1252
+ OpOperand *sliceOperand = nullptr ;
1253
+ unsigned operandIndex;
1254
+ for (auto [idx, operand] : llvm::enumerate (genericOp.getDpsInputOperands ())) {
1255
+ auto extractOp = operand->get ().getDefiningOp <tensor::ExtractSliceOp>();
1256
+ if (!extractOp)
1257
+ continue ;
1258
+ sliceOperand = operand;
1259
+ operandIndex = idx;
1260
+ break ;
1261
+ }
1262
+ if (!sliceOperand) {
1263
+ return failure ();
1264
+ }
1265
+ return std::make_tuple (sliceOperand, operandIndex);
1266
+ }
1267
+
1268
+ static FailureOr<std::tuple<llvm::DenseMap<int64_t , SliceDimInfo>, bool >>
1269
+ getNonZeroSliceDimInfo (GenericOp genericOp, OpOperand *sliceOperand,
1270
+ tensor::ExtractSliceOp producerSliceOp) {
1271
+ llvm::DenseMap<int64_t , SliceDimInfo> nonZeroSliceDimMap;
1272
+ bool hasNonZeroReductionDimSlice = false ;
1273
+ SmallVector<utils::IteratorType> iterators =
1274
+ genericOp.getIteratorTypesArray ();
1275
+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets ();
1276
+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes ();
1277
+
1278
+ SmallVector<OpFoldResult> shape = llvm::map_to_vector (
1279
+ producerSliceOp.getSourceType ().getShape (),
1280
+ [&](int64_t sz) -> OpFoldResult {
1281
+ return getAsIndexOpFoldResult (genericOp.getContext (), sz);
1282
+ });
1283
+
1284
+ for (auto [idx, expr] : llvm::enumerate (
1285
+ genericOp.getMatchingIndexingMap (sliceOperand).getResults ())) {
1286
+ if (isConstantIntValue (offsets[idx], 0 ) &&
1287
+ isEqualConstantIntOrValue (sizes[idx], shape[idx])) {
1288
+ continue ;
1289
+ }
1290
+ if (!isa<AffineDimExpr>(expr)) {
1291
+ return failure ();
1292
+ }
1293
+ SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
1294
+ int64_t dimPos = cast<AffineDimExpr>(expr).getPosition ();
1295
+ nonZeroSliceDimMap[dimPos] = sliceDimInfo;
1296
+ if (iterators[dimPos] == utils::IteratorType::reduction) {
1297
+ hasNonZeroReductionDimSlice = true ;
1298
+ }
1299
+ }
1300
+ // Next check if the dims with non zero slice info are used as non
1301
+ // AffineDimExpr and if they are then bail-out.
1302
+ for (OpOperand &operand : genericOp->getOpOperands ()) {
1303
+ if (operand == *sliceOperand) {
1304
+ continue ;
1305
+ }
1306
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap (&operand);
1307
+ if (llvm::any_of (IndexingMap.getResults (), [&](AffineExpr expr) {
1308
+ if (isa<AffineDimExpr>(expr)) {
1309
+ return false ;
1310
+ }
1311
+ WalkResult status = expr.walk ([&](AffineExpr expr) {
1312
+ if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1313
+ if (nonZeroSliceDimMap.count (dimExpr.getPosition ()) != 0 ) {
1314
+ return WalkResult::interrupt ();
1315
+ }
1316
+ }
1317
+ return WalkResult::advance ();
1318
+ });
1319
+ if (status.wasInterrupted ()) {
1320
+ return true ;
1321
+ }
1322
+ return false ;
1323
+ })) {
1324
+ return failure ();
1325
+ }
1326
+ }
1327
+ return std::make_tuple (nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
1328
+ }
1329
+
1330
+ static FailureOr<std::tuple<GenericOp, Value>>
1331
+ pushDownExtractSliceOpThroughGenericOp (RewriterBase &rewriter,
1332
+ GenericOp genericOp,
1333
+ ControlPropagationFn controlFn) {
1334
+ if (genericOp.getNumResults () != 1 )
1335
+ return failure ();
1336
+ if (hasGatherSemantics (genericOp))
1337
+ return failure ();
1338
+ // Collect the unPacked operand, if present.
1339
+ auto maybeSliceOperandAndIndex = getSliceOperandAndIndex (genericOp);
1340
+ if (failed (maybeSliceOperandAndIndex))
1341
+ return failure ();
1342
+ OpOperand *sliceOperand = std::get<0 >(*maybeSliceOperandAndIndex);
1343
+ unsigned OperandIndex = std::get<1 >(*maybeSliceOperandAndIndex);
1344
+
1345
+ if (!controlFn (sliceOperand))
1346
+ return failure ();
1347
+
1348
+ tensor::ExtractSliceOp producerSliceOp =
1349
+ sliceOperand->get ().getDefiningOp <tensor::ExtractSliceOp>();
1350
+ assert (producerSliceOp && " expect a valid UnPackOp" );
1351
+
1352
+ if (producerSliceOp.getSource ().getType ().getRank () !=
1353
+ producerSliceOp.getResult ().getType ().getRank ()) {
1354
+ return failure ();
1355
+ }
1356
+
1357
+ SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides ();
1358
+ if (!areAllConstantIntValue (strides, 1 ))
1359
+ return failure ();
1360
+
1361
+ SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets ();
1362
+ SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes ();
1363
+
1364
+ // check if we can support the propagation of this extractSlice
1365
+ // through the generic op and if so return the dimensions that
1366
+
1367
+ auto maybeNonZeroSliceDimMap =
1368
+ getNonZeroSliceDimInfo (genericOp, sliceOperand, producerSliceOp);
1369
+
1370
+ if (failed (maybeNonZeroSliceDimMap)) {
1371
+ return failure ();
1372
+ }
1373
+
1374
+ auto nonZeroSliceDimMap = std::get<0 >(*maybeNonZeroSliceDimMap);
1375
+ bool hasNonZeroReductionDimSlice = std::get<1 >(*maybeNonZeroSliceDimMap);
1376
+
1377
+ // Store the padding information as (dimPos, lowPad, highPad, PaddedShape).
1378
+ Location loc = genericOp->getLoc ();
1379
+ AffineExpr dim0, dim1;
1380
+ bindDims (rewriter.getContext (), dim0, dim1);
1381
+ auto subMap = AffineMap::get (2 , 0 , {dim0 - dim1});
1382
+ auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
1383
+ return affine::makeComposedFoldedAffineApply (rewriter, loc, subMap,
1384
+ {v1, v2});
1385
+ };
1386
+
1387
+ MLIRContext *ctx = genericOp.getContext ();
1388
+ SmallVector<Value> paddedInputs;
1389
+ for (auto [idx, operand] : llvm::enumerate (genericOp.getDpsInputOperands ())) {
1390
+ if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
1391
+ paddedInputs.push_back (producerSliceOp.getSource ());
1392
+ continue ;
1393
+ }
1394
+ AffineMap IndexingMap = genericOp.getMatchingIndexingMap (operand);
1395
+ SmallVector<OpFoldResult> operandLowPads (IndexingMap.getNumResults (),
1396
+ getAsIndexOpFoldResult (ctx, 0 ));
1397
+ SmallVector<OpFoldResult> operandHighPads (IndexingMap.getNumResults (),
1398
+ getAsIndexOpFoldResult (ctx, 0 ));
1399
+ for (auto [idx, expr] : llvm::enumerate (IndexingMap.getResults ())) {
1400
+ if (!isa<AffineDimExpr>(expr)) {
1401
+ continue ;
1402
+ }
1403
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1404
+ if (nonZeroSliceDimMap.contains (dimExpr.getPosition ())) {
1405
+ SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition ()];
1406
+ operandLowPads[idx] = sliceDimInfo.offset ;
1407
+ operandHighPads[idx] =
1408
+ sub (sub (sliceDimInfo.outputSize , sliceDimInfo.offset ),
1409
+ sliceDimInfo.sliceSize );
1410
+ }
1411
+ }
1412
+ auto paddingValue = ub::PoisonOp::create (
1413
+ rewriter, loc, getElementTypeOrSelf (operand->get ().getType ()));
1414
+ auto paddedOperand = tensor::PadOp::create (
1415
+ rewriter, loc, Type (), operand->get (), operandLowPads, operandHighPads,
1416
+ paddingValue, /* nofold=*/ false );
1417
+ paddedInputs.push_back (paddedOperand);
1418
+ }
1419
+ AffineMap outputIndexingMap =
1420
+ genericOp.getMatchingIndexingMap (genericOp.getDpsInitOperand (0 ));
1421
+
1422
+ auto outputShapeType =
1423
+ llvm::cast<ShapedType>(genericOp.getDpsInitOperand (0 )->get ().getType ());
1424
+ SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector (
1425
+ outputShapeType.getShape (),
1426
+ [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr (sz); });
1427
+ SmallVector<OpFoldResult> newSizes = OutputShape;
1428
+ SmallVector<OpFoldResult> outputLowPads (outputIndexingMap.getNumResults (),
1429
+ getAsIndexOpFoldResult (ctx, 0 ));
1430
+ SmallVector<OpFoldResult> outputHighPads (outputIndexingMap.getNumResults (),
1431
+ getAsIndexOpFoldResult (ctx, 0 ));
1432
+ SmallVector<OpFoldResult> newStrides (outputIndexingMap.getNumResults (),
1433
+ getAsIndexOpFoldResult (ctx, 1 ));
1434
+ for (auto [idx, expr] : llvm::enumerate (outputIndexingMap.getResults ())) {
1435
+ if (!isa<AffineDimExpr>(expr)) {
1436
+ continue ;
1437
+ }
1438
+ AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1439
+ if (nonZeroSliceDimMap.contains (dimExpr.getPosition ())) {
1440
+ SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition ()];
1441
+ outputLowPads[idx] = sliceDimInfo.offset ;
1442
+ outputHighPads[idx] =
1443
+ sub (sub (sliceDimInfo.outputSize , sliceDimInfo.offset ),
1444
+ sliceDimInfo.sliceSize );
1445
+ OutputShape[idx] = sliceDimInfo.outputSize ;
1446
+ newSizes[idx] = sliceDimInfo.sliceSize ;
1447
+ }
1448
+ }
1449
+ Value newPadOutput;
1450
+ auto outputElType =
1451
+ getElementTypeOrSelf (genericOp.getDpsInits ()[0 ].getType ());
1452
+ if (isGenericOutsNotUsed (genericOp)) {
1453
+ newPadOutput =
1454
+ tensor::EmptyOp::create (rewriter, loc, OutputShape, outputElType);
1455
+
1456
+ } else {
1457
+
1458
+ auto paddingValue = ub::PoisonOp::create (rewriter, loc, outputElType);
1459
+ newPadOutput = tensor::PadOp::create (
1460
+ rewriter, loc, Type (), genericOp.getDpsInits ()[0 ], outputLowPads,
1461
+ outputHighPads, paddingValue, /* nofold=*/ false );
1462
+ }
1463
+
1464
+ auto newGenericOp = linalg::GenericOp::create (
1465
+ rewriter, loc, newPadOutput.getType (), paddedInputs, {newPadOutput},
1466
+ genericOp.getIndexingMapsArray (), genericOp.getIteratorTypesArray (),
1467
+ /* bodyBuild=*/ nullptr , linalg::getPrunedAttributeList (genericOp));
1468
+ rewriter.cloneRegionBefore (genericOp.getRegion (), newGenericOp.getRegion (),
1469
+ newGenericOp.getRegion ().begin ());
1470
+
1471
+ auto extractOp = tensor::ExtractSliceOp::create (
1472
+ rewriter, loc,
1473
+ newGenericOp.getTiedOpResult (newGenericOp.getDpsInitOperand (0 )),
1474
+ outputLowPads, newSizes, newStrides);
1475
+ Value extractRes = extractOp.getResult ();
1476
+
1477
+ return std::make_tuple (newGenericOp, extractRes);
1478
+ }
1479
+
1480
+ class PushDownExtractSliceOpThroughGenericOp final
1481
+ : public OpRewritePattern<GenericOp> {
1482
+ public:
1483
+ PushDownExtractSliceOpThroughGenericOp (MLIRContext *context,
1484
+ ControlPropagationFn fun)
1485
+ : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1486
+
1487
+ LogicalResult matchAndRewrite (GenericOp genericOp,
1488
+ PatternRewriter &rewriter) const override {
1489
+ auto genericAndRepl =
1490
+ pushDownExtractSliceOpThroughGenericOp (rewriter, genericOp, controlFn);
1491
+ if (failed (genericAndRepl))
1492
+ return failure ();
1493
+ rewriter.replaceOp (genericOp, std::get<1 >(*genericAndRepl));
1494
+ return success ();
1495
+ }
1496
+
1497
+ private:
1498
+ ControlPropagationFn controlFn;
1499
+ };
1500
+
1239
1501
} // namespace
1240
1502
1241
1503
void mlir::linalg::populateDataLayoutPropagationPatterns (
@@ -1247,3 +1509,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
1247
1509
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1248
1510
patterns.getContext (), controlPackUnPackPropagation);
1249
1511
}
1512
+
1513
+ void mlir::linalg::populateExtractSlicePropagationPatterns (
1514
+ RewritePatternSet &patterns,
1515
+ const ControlPropagationFn &controlPackUnPackPropagation) {
1516
+ patterns.insert <PushDownExtractSliceOpThroughGenericOp>(
1517
+ patterns.getContext (), controlPackUnPackPropagation);
1518
+ }
0 commit comments