Skip to content

Commit efef821

Browse files
[Linalg] Add pattern to push down extract slice through generic
Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 351b38f commit efef821

File tree

4 files changed

+386
-0
lines changed

4 files changed

+386
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,6 +1914,11 @@ void populateDataLayoutPropagationPatterns(
19141914
RewritePatternSet &patterns,
19151915
const ControlPropagationFn &controlPackUnPackPropagation);
19161916

1917+
/// Patterns to bubble up or down extract slice across other operations.
1918+
void populateExtractSlicePropagationPatterns(
1919+
RewritePatternSet &patterns,
1920+
const ControlPropagationFn &controlPackUnPackPropagation);
1921+
19171922
/// Pattern to remove dead operands and results of `linalg.generic` operations.
19181923
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
19191924
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
910
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1011
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1112
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1213
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "mlir/Dialect/UB/IR/UBOps.h"
1315
#include "mlir/Dialect/Utils/IndexingUtils.h"
1416
#include "mlir/IR/Dominance.h"
1517
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,266 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
12361238
ControlPropagationFn controlFn;
12371239
};
12381240

1241+
// This struct contains information about extract_slice dims.
struct SliceDimInfo {
  // Slice offset along this dimension.
  OpFoldResult offset;
  // Extracted (slice) size along this dimension.
  OpFoldResult sliceSize;
  // Size of the original, unsliced source dimension.
  OpFoldResult outputSize;
};
1247+
1248+
/// Return the first input extract slice operand, if present, for the current
1249+
/// generic op.
1250+
static FailureOr<std::tuple<OpOperand *, unsigned>>
1251+
getSliceOperandAndIndex(GenericOp genericOp) {
1252+
OpOperand *sliceOperand = nullptr;
1253+
unsigned operandIndex;
1254+
for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
1255+
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
1256+
if (!extractOp)
1257+
continue;
1258+
sliceOperand = operand;
1259+
operandIndex = idx;
1260+
break;
1261+
}
1262+
if (!sliceOperand) {
1263+
return failure();
1264+
}
1265+
return std::make_tuple(sliceOperand, operandIndex);
1266+
}
1267+
1268+
/// Collect, for every loop dimension of `genericOp` that is actually sliced
/// by `producerSliceOp` (non-zero offset or size smaller than the source),
/// a SliceDimInfo describing that slice. Returns the map from loop dimension
/// position to SliceDimInfo plus a flag that is true when any sliced
/// dimension is a reduction dimension. Fails when a sliced dimension is
/// accessed through a non-AffineDimExpr in any indexing map, since the
/// slice cannot be propagated through such accesses.
static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
getNonZeroSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
                       tensor::ExtractSliceOp producerSliceOp) {
  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
  bool hasNonZeroReductionDimSlice = false;
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();
  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();

  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
      producerSliceOp.getSourceType().getShape(),
      [&](int64_t sz) -> OpFoldResult {
        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
      });

  for (auto [idx, expr] : llvm::enumerate(
           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
    // Dimensions with zero offset and full size are not really sliced.
    if (isConstantIntValue(offsets[idx], 0) &&
        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
      continue;
    }
    // A sliced dimension must be a plain loop dimension to be propagatable.
    if (!isa<AffineDimExpr>(expr)) {
      return failure();
    }
    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
    if (iterators[dimPos] == utils::IteratorType::reduction) {
      hasNonZeroReductionDimSlice = true;
    }
  }
  // Next check if the dims with non zero slice info are used as non
  // AffineDimExpr and if they are then bail-out.
  for (OpOperand &operand : genericOp->getOpOperands()) {
    if (operand == *sliceOperand) {
      continue;
    }
    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&operand);
    if (llvm::any_of(indexingMap.getResults(), [&](AffineExpr expr) {
          if (isa<AffineDimExpr>(expr)) {
            return false;
          }
          // Composite expression: reject if it touches any sliced dim.
          WalkResult status = expr.walk([&](AffineExpr subExpr) {
            if (auto dimExpr = dyn_cast<AffineDimExpr>(subExpr)) {
              if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
                return WalkResult::interrupt();
              }
            }
            return WalkResult::advance();
          });
          return status.wasInterrupted();
        })) {
      return failure();
    }
  }
  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
}
1329+
1330+
/// Push a tensor.extract_slice that defines one of `genericOp`'s inputs down
/// below the generic op: the other operands (and, when its value is used,
/// the init) are padded with a poison value so the generic op computes over
/// the unsliced iteration space, and the originally requested slice is then
/// extracted from the new result. Returns the new generic op together with
/// the replacement value for the original result.
static FailureOr<std::tuple<GenericOp, Value>>
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
                                       GenericOp genericOp,
                                       ControlPropagationFn controlFn) {
  if (genericOp.getNumResults() != 1)
    return failure();
  if (hasGatherSemantics(genericOp))
    return failure();
  // Collect the sliced input operand, if present.
  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
  if (failed(maybeSliceOperandAndIndex))
    return failure();
  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
  unsigned operandIndex = std::get<1>(*maybeSliceOperandAndIndex);

  if (!controlFn(sliceOperand))
    return failure();

  tensor::ExtractSliceOp producerSliceOp =
      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
  assert(producerSliceOp && "expect a valid ExtractSliceOp");

  // Rank-reducing slices are not supported.
  if (producerSliceOp.getSource().getType().getRank() !=
      producerSliceOp.getResult().getType().getRank()) {
    return failure();
  }

  // Only unit-stride slices can be rewritten as pad + extract_slice.
  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
  if (!areAllConstantIntValue(strides, 1))
    return failure();

  // Check if we can support the propagation of this extractSlice through the
  // generic op and, if so, collect the slice info for the sliced dims.
  auto maybeNonZeroSliceDimMap =
      getNonZeroSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
  if (failed(maybeNonZeroSliceDimMap)) {
    return failure();
  }
  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);

  Location loc = genericOp->getLoc();
  MLIRContext *ctx = genericOp.getContext();
  AffineExpr dim0, dim1;
  bindDims(rewriter.getContext(), dim0, dim1);
  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
  // Folded affine subtraction helper on OpFoldResults.
  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
                                                 {v1, v2});
  };
  // High pad of a sliced dim is outputSize - offset - sliceSize.
  auto highPad = [&](const SliceDimInfo &info) {
    return sub(sub(info.outputSize, info.offset), info.sliceSize);
  };

  SmallVector<Value> paddedInputs;
  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
    // The sliced operand itself is replaced by its unsliced source, unless a
    // reduction dim is sliced; in that case it is padded with poison like the
    // other operands so the reduction domain stays consistent.
    if (idx == operandIndex && !hasNonZeroReductionDimSlice) {
      paddedInputs.push_back(producerSliceOp.getSource());
      continue;
    }
    AffineMap indexingMap = genericOp.getMatchingIndexingMap(operand);
    SmallVector<OpFoldResult> operandLowPads(indexingMap.getNumResults(),
                                             getAsIndexOpFoldResult(ctx, 0));
    SmallVector<OpFoldResult> operandHighPads(indexingMap.getNumResults(),
                                              getAsIndexOpFoldResult(ctx, 0));
    for (auto [dimIdx, expr] : llvm::enumerate(indexingMap.getResults())) {
      if (!isa<AffineDimExpr>(expr)) {
        continue;
      }
      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
        SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
        operandLowPads[dimIdx] = sliceDimInfo.offset;
        operandHighPads[dimIdx] = highPad(sliceDimInfo);
      }
    }
    auto paddingValue = ub::PoisonOp::create(
        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
    auto paddedOperand = tensor::PadOp::create(
        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
        paddingValue, /*nofold=*/false);
    paddedInputs.push_back(paddedOperand);
  }
  AffineMap outputIndexingMap =
      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));

  auto outputShapeType =
      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
  SmallVector<OpFoldResult> outputShape = llvm::map_to_vector(
      outputShapeType.getShape(),
      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
  SmallVector<OpFoldResult> newSizes = outputShape;
  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
                                          getAsIndexOpFoldResult(ctx, 0));
  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
                                           getAsIndexOpFoldResult(ctx, 0));
  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
                                       getAsIndexOpFoldResult(ctx, 1));
  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
    if (!isa<AffineDimExpr>(expr)) {
      continue;
    }
    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
    if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
      outputLowPads[idx] = sliceDimInfo.offset;
      outputHighPads[idx] = highPad(sliceDimInfo);
      outputShape[idx] = sliceDimInfo.outputSize;
      newSizes[idx] = sliceDimInfo.sliceSize;
    }
  }
  Value newPadOutput;
  auto outputElType =
      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
  if (isGenericOutsNotUsed(genericOp)) {
    // The init is never read, so a fresh empty tensor of the unsliced shape
    // is sufficient.
    newPadOutput =
        tensor::EmptyOp::create(rewriter, loc, outputShape, outputElType);
  } else {
    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
    newPadOutput = tensor::PadOp::create(
        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
        outputHighPads, paddingValue, /*nofold=*/false);
  }

  auto newGenericOp = linalg::GenericOp::create(
      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());

  // Extract the originally requested slice from the padded result.
  auto extractOp = tensor::ExtractSliceOp::create(
      rewriter, loc,
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
      outputLowPads, newSizes, newStrides);
  Value extractRes = extractOp.getResult();

  return std::make_tuple(newGenericOp, extractRes);
}
1479+
1480+
/// Pattern that pushes a tensor.extract_slice feeding a linalg.generic down
/// below the generic op: the rewrite produces a new generic op on unsliced
/// operands and replaces the original result with a slice of the new result.
/// Propagation of each candidate is gated by the user-provided control
/// function.
class PushDownExtractSliceOpThroughGenericOp final
    : public OpRewritePattern<GenericOp> {
public:
  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
                                         ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    // Replace the generic op's single result with the new extract_slice.
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  // Callback deciding, per candidate operand, whether to propagate.
  ControlPropagationFn controlFn;
};
1500+
12391501
} // namespace
12401502

12411503
void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1509,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
12471509
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
12481510
patterns.getContext(), controlPackUnPackPropagation);
12491511
}
1512+
1513+
/// Populate `patterns` with patterns that propagate tensor.extract_slice
/// across linalg ops. `controlExtractSlicePropagation` lets callers veto
/// propagation for individual candidate operands.
void mlir::linalg::populateExtractSlicePropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlExtractSlicePropagation) {
  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
      patterns.getContext(), controlExtractSlicePropagation);
}

0 commit comments

Comments
 (0)