Commit 7b9d96e
[Linalg] Add pattern to push down extract slice through generic
Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 7f27482 commit 7b9d96e
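The new pattern moves a tensor.extract_slice producer from an input of a linalg.generic past the generic, so the op computes on the full source tensor and the slice is re-applied to its result. As a rough, hand-written sketch (not taken from the commit or its tests; names and shapes are illustrative), an elementwise generic whose init is unused in the body would be rewritten roughly like this:

// Before: the generic consumes a slice of %src.
%slice = tensor.extract_slice %src[2, 0] [4, 8] [1, 1]
    : tensor<16x32xf32> to tensor<4x8xf32>
%res = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%slice : tensor<4x8xf32>) outs(%init : tensor<4x8xf32>) {
^bb0(%in: f32, %out: f32):
  %sq = arith.mulf %in, %in : f32
  linalg.yield %sq : f32
} -> tensor<4x8xf32>

// After: the generic runs on the full source; since the init is unused in the
// body it is replaced by a full-sized tensor.empty, and the slice is taken
// from the result at the original offsets.
%empty = tensor.empty() : tensor<16x32xf32>
%full = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%src : tensor<16x32xf32>) outs(%empty : tensor<16x32xf32>) {
^bb0(%in: f32, %out: f32):
  %sq = arith.mulf %in, %in : f32
  linalg.yield %sq : f32
} -> tensor<16x32xf32>
%res2 = tensor.extract_slice %full[2, 0] [4, 8] [1, 1]
    : tensor<16x32xf32> to tensor<4x8xf32>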

4 files changed: 389 additions, 0 deletions

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 5 additions & 0 deletions
@@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation);
 
+/// Patterns to bubble up or down extract slice across other operations.
+void populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation);
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 272 additions & 0 deletions
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+// This struct contains information about extract_slice dims.
+struct SliceDimInfo {
+  OpFoldResult offset;
+  OpFoldResult sliceSize;
+  OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  OpOperand *sliceOperand = nullptr;
+  unsigned operandIndex;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      continue;
+    sliceOperand = operand;
+    operandIndex = idx;
+    break;
+  }
+  if (!sliceOperand) {
+    return failure();
+  }
+  return std::make_tuple(sliceOperand, operandIndex);
+}
+
+// Return a map of dims that have non-full slices on them so that other
+// operands can use this information. Also return a bool indicating whether a
+// reduction dim has a non-full slice, as that can be used to fold the original
+// extract slice.
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+  bool hasNonZeroReductionDimSlice = false;
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t sz) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+      });
+
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+    if (iterators[dimPos] == utils::IteratorType::reduction) {
+      hasNonZeroReductionDimSlice = true;
+    }
+  }
+  // Next check if the dims with non-zero slice info are used as non
+  // AffineDimExpr and if they are then bail out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (operand == *sliceOperand) {
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          if (status.wasInterrupted()) {
+            return true;
+          }
+          return false;
+        })) {
+      return failure();
+    }
+  }
+  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Collect the sliced input operand, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  // Check if we can support the propagation of this extractSlice through the
+  // generic op and, if so, get the dimensions that have non-full slices.
+  auto maybeNonZeroSliceDimMap =
+      getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+
+  if (failed(maybeNonZeroSliceDimMap)) {
+    return failure();
+  }
+
+  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  // Pad the remaining operands so that they span the full iteration space of
+  // the slice source.
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+        SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+        operandLowPads[idx] = sliceDimInfo.offset;
+        operandHighPads[idx] =
+            sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+                sliceDimInfo.sliceSize);
+      }
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+        paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = OutputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+    if (!isa<AffineDimExpr>(expr)) {
+      continue;
+    }
+    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+    if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+      outputLowPads[idx] = sliceDimInfo.offset;
+      outputHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
+      OutputShape[idx] = sliceDimInfo.outputSize;
+      newSizes[idx] = sliceDimInfo.sliceSize;
+    }
+  }
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
+  } else {
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1512,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
       PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
       patterns.getContext(), controlPackUnPackPropagation);
 }
+
+void mlir::linalg::populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+}
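For the multi-operand case, a hand-written sketch (again not from the commit's tests; names and shapes are illustrative): any other operand that indexes a non-fully-sliced dimension is padded with ub.poison so it spans the full iteration space, with low pad equal to the slice offset and high pad equal to size - offset - sliceSize, before the extract_slice is moved below the new generic:

// Before: only %a comes from an extract_slice; %b is already slice-sized.
%a = tensor.extract_slice %src[2, 0] [4, 8] [1, 1]
    : tensor<16x8xf32> to tensor<4x8xf32>
%r = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
    outs(%init : tensor<4x8xf32>) {
^bb0(%x: f32, %y: f32, %out: f32):
  %s = arith.addf %x, %y : f32
  linalg.yield %s : f32
} -> tensor<4x8xf32>

// After: %b is padded with poison along the sliced dimension (low pad 2,
// high pad 16 - 2 - 4 = 10), the generic runs at full size, and the slice is
// re-applied to its result.
%poison = ub.poison : f32
%b_pad = tensor.pad %b low[2, 0] high[10, 0] {
^bb0(%i: index, %j: index):
  tensor.yield %poison : f32
} : tensor<4x8xf32> to tensor<16x8xf32>
%empty = tensor.empty() : tensor<16x8xf32>
%r_full = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%src, %b_pad : tensor<16x8xf32>, tensor<16x8xf32>)
    outs(%empty : tensor<16x8xf32>) {
^bb0(%x: f32, %y: f32, %out: f32):
  %s = arith.addf %x, %y : f32
  linalg.yield %s : f32
} -> tensor<16x8xf32>
%r2 = tensor.extract_slice %r_full[2, 0] [4, 8] [1, 1]
    : tensor<16x8xf32> to tensor<4x8xf32>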
