-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[Linalg] Add pattern to push down extract slice through linalg generic op #154162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -6,10 +6,12 @@ | |||||
// | ||||||
//===----------------------------------------------------------------------===// | ||||||
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||||||
#include "mlir/Dialect/Linalg/Utils/Utils.h" | ||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||||||
#include "mlir/Dialect/UB/IR/UBOps.h" | ||||||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||||||
#include "mlir/IR/Dominance.h" | ||||||
#include "llvm/ADT/SetOperations.h" | ||||||
|
@@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> { | |||||
ControlPropagationFn controlFn; | ||||||
}; | ||||||
|
||||||
// This struct contains information about extract_slice dims. | ||||||
struct SliceDimInfo { | ||||||
OpFoldResult offset; | ||||||
OpFoldResult sliceSize; | ||||||
OpFoldResult outputSize; | ||||||
}; | ||||||
|
||||||
/// Return the first input extract slice operand, if present, for the current | ||||||
/// generic op. | ||||||
static FailureOr<std::tuple<OpOperand *, unsigned>> | ||||||
getSliceOperandAndIndex(GenericOp genericOp) { | ||||||
OpOperand *sliceOperand = nullptr; | ||||||
unsigned operandIndex; | ||||||
for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { | ||||||
auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>(); | ||||||
if (!extractOp) | ||||||
continue; | ||||||
sliceOperand = operand; | ||||||
operandIndex = idx; | ||||||
break; | ||||||
} | ||||||
if (!sliceOperand) { | ||||||
return failure(); | ||||||
} | ||||||
return std::make_tuple(sliceOperand, operandIndex); | ||||||
} | ||||||
|
||||||
// Return a map of dims that have non-full slices on them so that other | ||||||
// operands can use this information. Also return a bool indicating whether a | ||||||
// reduction dim has a non-full slice, as that can be used to fold the original extract slice. | ||||||
static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>> | ||||||
getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: you dont need to pass the |
||||||
tensor::ExtractSliceOp producerSliceOp) { | ||||||
llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I was confused by |
||||||
bool hasNonZeroReductionDimSlice = false; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same with this. |
||||||
SmallVector<utils::IteratorType> iterators = | ||||||
genericOp.getIteratorTypesArray(); | ||||||
SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets(); | ||||||
SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes(); | ||||||
|
||||||
SmallVector<OpFoldResult> shape = llvm::map_to_vector( | ||||||
producerSliceOp.getSourceType().getShape(), | ||||||
[&](int64_t sz) -> OpFoldResult { | ||||||
return getAsIndexOpFoldResult(genericOp.getContext(), sz); | ||||||
}); | ||||||
Comment on lines
+1281
to
+1285
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can pass an ArrayRef to |
||||||
|
||||||
for (auto [idx, expr] : llvm::enumerate( | ||||||
genericOp.getMatchingIndexingMap(sliceOperand).getResults())) { | ||||||
if (isConstantIntValue(offsets[idx], 0) && | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add comments as to what each of these conditions are checking? |
||||||
isEqualConstantIntOrValue(sizes[idx], shape[idx])) { | ||||||
continue; | ||||||
} | ||||||
if (!isa<AffineDimExpr>(expr)) { | ||||||
return failure(); | ||||||
} | ||||||
SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]}; | ||||||
int64_t dimPos = cast<AffineDimExpr>(expr).getPosition(); | ||||||
nonZeroSliceDimMap[dimPos] = sliceDimInfo; | ||||||
if (iterators[dimPos] == utils::IteratorType::reduction) { | ||||||
hasNonZeroReductionDimSlice = true; | ||||||
} | ||||||
} | ||||||
// Next check if the dims with non zero slice info are used as non | ||||||
// AffineDimExpr and if they are then bail-out. | ||||||
for (OpOperand &operand : genericOp->getOpOperands()) { | ||||||
if (operand == *sliceOperand) { | ||||||
continue; | ||||||
} | ||||||
AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand); | ||||||
if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) { | ||||||
if (isa<AffineDimExpr>(expr)) { | ||||||
return false; | ||||||
} | ||||||
WalkResult status = expr.walk([&](AffineExpr expr) { | ||||||
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { | ||||||
if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) { | ||||||
return WalkResult::interrupt(); | ||||||
Comment on lines
+1316
to
+1317
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Use |
||||||
} | ||||||
} | ||||||
return WalkResult::advance(); | ||||||
}); | ||||||
if (status.wasInterrupted()) { | ||||||
return true; | ||||||
} | ||||||
return false; | ||||||
})) { | ||||||
return failure(); | ||||||
} | ||||||
} | ||||||
return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice); | ||||||
} | ||||||
|
||||||
static FailureOr<std::tuple<GenericOp, Value>> | ||||||
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, | ||||||
GenericOp genericOp, | ||||||
ControlPropagationFn controlFn) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: using |
||||||
if (genericOp.getNumResults() != 1) | ||||||
return failure(); | ||||||
if (hasGatherSemantics(genericOp)) | ||||||
return failure(); | ||||||
// Collect the input operand produced by an extract_slice, if present. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp); | ||||||
if (failed(maybeSliceOperandAndIndex)) | ||||||
return failure(); | ||||||
OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex); | ||||||
unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex); | ||||||
|
||||||
if (!controlFn(sliceOperand)) | ||||||
return failure(); | ||||||
|
||||||
tensor::ExtractSliceOp producerSliceOp = | ||||||
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); | ||||||
assert(producerSliceOp && "expect a valid ExtractSliceOp"); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
if (producerSliceOp.getSource().getType().getRank() != | ||||||
producerSliceOp.getResult().getType().getRank()) { | ||||||
return failure(); | ||||||
} | ||||||
|
||||||
SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides(); | ||||||
if (!areAllConstantIntValue(strides, 1)) | ||||||
return failure(); | ||||||
|
||||||
SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets(); | ||||||
SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes(); | ||||||
|
||||||
// Check if we can support the propagation of this extract_slice | ||||||
// through the generic op and, if so, collect the dims that have non-full slices. | ||||||
|
||||||
auto maybeNonZeroSliceDimMap = | ||||||
getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp); | ||||||
|
||||||
if (failed(maybeNonZeroSliceDimMap)) { | ||||||
return failure(); | ||||||
} | ||||||
|
||||||
auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap); | ||||||
bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap); | ||||||
Comment on lines
+1377
to
+1378
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. optional nit: Might be easier to follow if this just returns the dim map, and then we check if any reduction dims are present in the dim map afterwards. |
||||||
|
||||||
// Store the padding information as (dimPos, lowPad, highPad, PaddedShape). | ||||||
Location loc = genericOp->getLoc(); | ||||||
AffineExpr dim0, dim1; | ||||||
bindDims(rewriter.getContext(), dim0, dim1); | ||||||
auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); | ||||||
auto sub = [&](OpFoldResult v1, OpFoldResult v2) { | ||||||
return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap, | ||||||
{v1, v2}); | ||||||
}; | ||||||
|
||||||
MLIRContext *ctx = genericOp.getContext(); | ||||||
SmallVector<Value> paddedInputs; | ||||||
for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { | ||||||
if (idx == OperandIndex && !hasNonZeroReductionDimSlice) { | ||||||
paddedInputs.push_back(producerSliceOp.getSource()); | ||||||
continue; | ||||||
} | ||||||
AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand); | ||||||
SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(), | ||||||
getAsIndexOpFoldResult(ctx, 0)); | ||||||
SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(), | ||||||
getAsIndexOpFoldResult(ctx, 0)); | ||||||
for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) { | ||||||
if (!isa<AffineDimExpr>(expr)) { | ||||||
continue; | ||||||
} | ||||||
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr); | ||||||
if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Use early |
||||||
SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()]; | ||||||
operandLowPads[idx] = sliceDimInfo.offset; | ||||||
operandHighPads[idx] = | ||||||
sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), | ||||||
sliceDimInfo.sliceSize); | ||||||
} | ||||||
} | ||||||
auto paddingValue = ub::PoisonOp::create( | ||||||
rewriter, loc, getElementTypeOrSelf(operand->get().getType())); | ||||||
auto paddedOperand = tensor::PadOp::create( | ||||||
rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads, | ||||||
paddingValue, /*nofold=*/false); | ||||||
paddedInputs.push_back(paddedOperand); | ||||||
} | ||||||
AffineMap outputIndexingMap = | ||||||
genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0)); | ||||||
|
||||||
auto outputShapeType = | ||||||
llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType()); | ||||||
SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector( | ||||||
outputShapeType.getShape(), | ||||||
[&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); }); | ||||||
SmallVector<OpFoldResult> newSizes = OutputShape; | ||||||
SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(), | ||||||
getAsIndexOpFoldResult(ctx, 0)); | ||||||
SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(), | ||||||
getAsIndexOpFoldResult(ctx, 0)); | ||||||
SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(), | ||||||
getAsIndexOpFoldResult(ctx, 1)); | ||||||
for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) { | ||||||
if (!isa<AffineDimExpr>(expr)) { | ||||||
continue; | ||||||
} | ||||||
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr); | ||||||
if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) { | ||||||
SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()]; | ||||||
outputLowPads[idx] = sliceDimInfo.offset; | ||||||
outputHighPads[idx] = | ||||||
sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), | ||||||
sliceDimInfo.sliceSize); | ||||||
OutputShape[idx] = sliceDimInfo.outputSize; | ||||||
newSizes[idx] = sliceDimInfo.sliceSize; | ||||||
} | ||||||
} | ||||||
Value newPadOutput; | ||||||
auto outputElType = | ||||||
getElementTypeOrSelf(genericOp.getDpsInits()[0].getType()); | ||||||
if (isGenericOutsNotUsed(genericOp)) { | ||||||
newPadOutput = | ||||||
tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType); | ||||||
|
||||||
} else { | ||||||
|
||||||
auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType); | ||||||
newPadOutput = tensor::PadOp::create( | ||||||
rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads, | ||||||
outputHighPads, paddingValue, /*nofold=*/false); | ||||||
} | ||||||
|
||||||
auto newGenericOp = linalg::GenericOp::create( | ||||||
rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput}, | ||||||
genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(), | ||||||
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); | ||||||
rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), | ||||||
newGenericOp.getRegion().begin()); | ||||||
|
||||||
auto extractOp = tensor::ExtractSliceOp::create( | ||||||
rewriter, loc, | ||||||
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)), | ||||||
outputLowPads, newSizes, newStrides); | ||||||
Value extractRes = extractOp.getResult(); | ||||||
|
||||||
return std::make_tuple(newGenericOp, extractRes); | ||||||
} | ||||||
|
||||||
class PushDownExtractSliceOpThroughGenericOp final | ||||||
: public OpRewritePattern<GenericOp> { | ||||||
public: | ||||||
PushDownExtractSliceOpThroughGenericOp(MLIRContext *context, | ||||||
ControlPropagationFn fun) | ||||||
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {} | ||||||
|
||||||
LogicalResult matchAndRewrite(GenericOp genericOp, | ||||||
PatternRewriter &rewriter) const override { | ||||||
auto genericAndRepl = | ||||||
pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn); | ||||||
if (failed(genericAndRepl)) | ||||||
return failure(); | ||||||
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); | ||||||
return success(); | ||||||
} | ||||||
|
||||||
private: | ||||||
ControlPropagationFn controlFn; | ||||||
}; | ||||||
|
||||||
} // namespace | ||||||
|
||||||
void mlir::linalg::populateDataLayoutPropagationPatterns( | ||||||
|
@@ -1247,3 +1512,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns( | |||||
PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( | ||||||
patterns.getContext(), controlPackUnPackPropagation); | ||||||
} | ||||||
|
||||||
void mlir::linalg::populateExtractSliceSinkingPatterns( | ||||||
RewritePatternSet &patterns, | ||||||
const ControlPropagationFn &controlPackUnPackPropagation) { | ||||||
patterns.insert<PushDownExtractSliceOpThroughGenericOp>( | ||||||
patterns.getContext(), controlPackUnPackPropagation); | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dont think you need to return the
operandIdx
really, this is already part ofOpOperand *
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 (
sliceOperand->getOperandNumber()
)