
[Linalg] Add pattern to push down extract slice through linalg generic op #154162


Open · wants to merge 1 commit into base: main
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);

/// Patterns to sink extract slice across other operations.
void populateExtractSliceSinkingPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation);

/// Pattern to remove dead operands and results of `linalg.generic` operations.
/// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
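For intuition, the sinking rewrite registered by the new entry point can be sketched on a small example (hypothetical IR, not taken from the patch): instead of slicing an input and running the generic op on the small tensor, the generic op is run on the full-size operands and the slice is taken from its result.

```mlir
// Before: the slice feeds the generic op.
%slice = tensor.extract_slice %src[0, 2] [8, 16] [1, 1]
    : tensor<8x32xf32> to tensor<8x16xf32>
%res = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%slice : tensor<8x16xf32>) outs(%init : tensor<8x16xf32>) {
^bb0(%in: f32, %out: f32):
  %0 = arith.addf %in, %in : f32
  linalg.yield %0 : f32
} -> tensor<8x16xf32>

// After: the generic op runs on the full domain (other operands are padded
// with ub.poison where needed) and the slice is extracted from the result.
%full = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%src : tensor<8x32xf32>) outs(%fullInit : tensor<8x32xf32>) {
^bb0(%in: f32, %out: f32):
  %0 = arith.addf %in, %in : f32
  linalg.yield %0 : f32
} -> tensor<8x32xf32>
%res = tensor.extract_slice %full[0, 2] [8, 16] [1, 1]
    : tensor<8x32xf32> to tensor<8x16xf32>
```

This trades a smaller iteration domain for the ability to keep the slice at the bottom of the chain, where it can fold with later consumers.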
272 changes: 272 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
ControlPropagationFn controlFn;
};

// This struct contains information about extract_slice dims.
struct SliceDimInfo {
OpFoldResult offset;
OpFoldResult sliceSize;
OpFoldResult outputSize;
};

/// Return the first input extract slice operand, if present, for the current
/// generic op.
static FailureOr<std::tuple<OpOperand *, unsigned>>
getSliceOperandAndIndex(GenericOp genericOp) {
OpOperand *sliceOperand = nullptr;
unsigned operandIndex;
for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!extractOp)
continue;
sliceOperand = operand;
operandIndex = idx;
break;
}
if (!sliceOperand) {
return failure();
}
return std::make_tuple(sliceOperand, operandIndex);
Contributor: I don't think you need to return the operandIdx really; this is already part of OpOperand *?

Contributor: +1 (sliceOperand->getOperandNumber())
}

// Return a map of dims that have non-full slices, so that other operands can
// use this information. Also return a bool indicating whether a reduction dim
// has a non-full slice, as that can be used to fold the original extract_slice.
static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
Contributor: Nit: you don't need to pass the producerSliceOp. You should be able to get it from the OpOperand *.
tensor::ExtractSliceOp producerSliceOp) {
llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
Contributor: I think I was confused by nonZeroSliceDimMap. I think you mean partialSliceDimMap, or something like that.
bool hasNonZeroReductionDimSlice = false;
Contributor: Same with this. hasPartialReductionDimSlice or something like that is easier to understand.
SmallVector<utils::IteratorType> iterators =
genericOp.getIteratorTypesArray();
SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();

SmallVector<OpFoldResult> shape = llvm::map_to_vector(
producerSliceOp.getSourceType().getShape(),
[&](int64_t sz) -> OpFoldResult {
return getAsIndexOpFoldResult(genericOp.getContext(), sz);
});
Comment on lines +1281 to +1285

Contributor: You can pass an ArrayRef to getAsIndexOpFoldResult, so you don't need to map_to_vector.

for (auto [idx, expr] : llvm::enumerate(
genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
if (isConstantIntValue(offsets[idx], 0) &&
Contributor: Can you add comments as to what each of these conditions is checking?
isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
continue;
}
if (!isa<AffineDimExpr>(expr)) {
return failure();
}
SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
nonZeroSliceDimMap[dimPos] = sliceDimInfo;
if (iterators[dimPos] == utils::IteratorType::reduction) {
hasNonZeroReductionDimSlice = true;
}
}
// Next, check whether the dims with non-zero slice info are used in non
// AffineDimExpr results; if they are, bail out.
for (OpOperand &operand : genericOp->getOpOperands()) {
if (operand == *sliceOperand) {
continue;
}
AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
if (isa<AffineDimExpr>(expr)) {
return false;
}
WalkResult status = expr.walk([&](AffineExpr expr) {
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
return WalkResult::interrupt();
Comment on lines +1316 to +1317

Contributor: nit: Use nonZeroSliceDimMap.contains(dimExpr.getPosition())?
}
}
return WalkResult::advance();
});
if (status.wasInterrupted()) {
return true;
}
return false;
})) {
return failure();
}
}
return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
}
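For intuition, with a hypothetical slice like the one below, only d1 would get a SliceDimInfo entry, since d0 covers its full extent and is skipped by the first check in the loop:

```mlir
// %src : tensor<16x32xf32>
// Dim 0: offset 0, size 16 == full extent 16, so it is skipped.
// Dim 1: offset 4, size 8 != full extent 32, so the map records
//        {offset = 4, sliceSize = 8, outputSize = 32} for d1.
%s = tensor.extract_slice %src[0, 4] [16, 8] [1, 1]
    : tensor<16x32xf32> to tensor<16x8xf32>
```

If d1 were a reduction iterator in the consuming generic op, hasNonZeroReductionDimSlice would also be set, since dropping the slice then changes the values being reduced.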

static FailureOr<std::tuple<GenericOp, Value>>
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
GenericOp genericOp,
ControlPropagationFn controlFn) {
Contributor: nit: using rewriter.notifyMatchFailure() is helpful for the less obvious restrictions on this pattern (things like rank-reducing, gather semantics, etc.)
if (genericOp.getNumResults() != 1)
return failure();
if (hasGatherSemantics(genericOp))
return failure();
// Collect the unPacked operand, if present.
Contributor: Suggested change:
- // Collect the unPacked operand, if present.
+ // Collect the sliced operand, if present.
auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
if (failed(maybeSliceOperandAndIndex))
return failure();
OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex);

if (!controlFn(sliceOperand))
return failure();

tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
assert(producerSliceOp && "expect a valid UnPackOp");
Contributor: Suggested change:
- assert(producerSliceOp && "expect a valid UnPackOp");
+ assert(producerSliceOp && "expect a valid extract_slice op");

if (producerSliceOp.getSource().getType().getRank() !=
producerSliceOp.getResult().getType().getRank()) {
return failure();
}

SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
if (!areAllConstantIntValue(strides, 1))
return failure();

SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();

// Check whether we can propagate this extract_slice through the generic op
// and, if so, collect the dimensions that have partial slices.

auto maybeNonZeroSliceDimMap =
getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp);

if (failed(maybeNonZeroSliceDimMap)) {
return failure();
}

auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
Comment on lines +1377 to +1378

Contributor: optional nit: Might be easier to follow if this just returns the dim map, and then we check if any reduction dims are present in the dim map afterwards.

// Build an affine helper to compute padding amounts:
// highPad = outputSize - offset - sliceSize.
Location loc = genericOp->getLoc();
AffineExpr dim0, dim1;
bindDims(rewriter.getContext(), dim0, dim1);
auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
{v1, v2});
};

MLIRContext *ctx = genericOp.getContext();
SmallVector<Value> paddedInputs;
for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
if (idx == OperandIndex && !hasNonZeroReductionDimSlice) {
paddedInputs.push_back(producerSliceOp.getSource());
continue;
}
AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
if (!isa<AffineDimExpr>(expr)) {
continue;
}
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
Contributor: nit: Use an early continue to save nesting.
SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
operandLowPads[idx] = sliceDimInfo.offset;
operandHighPads[idx] =
sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
sliceDimInfo.sliceSize);
}
}
auto paddingValue = ub::PoisonOp::create(
rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
auto paddedOperand = tensor::PadOp::create(
rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
paddingValue, /*nofold=*/false);
paddedInputs.push_back(paddedOperand);
}
AffineMap outputIndexingMap =
genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));

auto outputShapeType =
llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
outputShapeType.getShape(),
[&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
SmallVector<OpFoldResult> newSizes = OutputShape;
SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 0));
SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
getAsIndexOpFoldResult(ctx, 1));
for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
if (!isa<AffineDimExpr>(expr)) {
continue;
}
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
outputLowPads[idx] = sliceDimInfo.offset;
outputHighPads[idx] =
sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
sliceDimInfo.sliceSize);
OutputShape[idx] = sliceDimInfo.outputSize;
newSizes[idx] = sliceDimInfo.sliceSize;
}
}
Value newPadOutput;
auto outputElType =
getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
if (isGenericOutsNotUsed(genericOp)) {
newPadOutput =
tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
} else {
auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
newPadOutput = tensor::PadOp::create(
rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
outputHighPads, paddingValue, /*nofold=*/false);
}

auto newGenericOp = linalg::GenericOp::create(
rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
newGenericOp.getRegion().begin());

auto extractOp = tensor::ExtractSliceOp::create(
rewriter, loc,
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
outputLowPads, newSizes, newStrides);
Value extractRes = extractOp.getResult();

return std::make_tuple(newGenericOp, extractRes);
}
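The padding computed above by sub(sub(outputSize, offset), sliceSize) restores each partially sliced dim to its full extent: lowPad is the slice offset and highPad is whatever remains after the slice. A worked example with hypothetical sizes, for a dim of extent 32 sliced at offset 4 with size 8:

```mlir
// lowPad  = offset = 4
// highPad = outputSize - offset - sliceSize = 32 - 4 - 8 = 20
// Padded extent: 4 + 8 + 20 = 32, i.e. the original source extent.
// The padding value is ub.poison, since those elements are never read
// from the final extract_slice of the result.
%poison = ub.poison : f32
%padded = tensor.pad %operand low[0, 4] high[0, 20] {
^bb0(%i: index, %j: index):
  tensor.yield %poison : f32
} : tensor<16x8xf32> to tensor<16x32xf32>
```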

class PushDownExtractSliceOpThroughGenericOp final
: public OpRewritePattern<GenericOp> {
public:
PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
ControlPropagationFn fun)
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
auto genericAndRepl =
pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
if (failed(genericAndRepl))
return failure();
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
return success();
}

private:
ControlPropagationFn controlFn;
};

} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1512,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
patterns.getContext(), controlPackUnPackPropagation);
}

void mlir::linalg::populateExtractSliceSinkingPatterns(
RewritePatternSet &patterns,
const ControlPropagationFn &controlPackUnPackPropagation) {
patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
patterns.getContext(), controlPackUnPackPropagation);
}