From 4eebe2174cc773b213a2f512b7405e14174c4714 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Fri, 8 Aug 2025 14:44:54 -0700 Subject: [PATCH] [Linalg] Add pattern to push down extract slice through generic Signed-off-by: Nirvedh Meshram --- .../Dialect/Linalg/Transforms/Transforms.h | 5 + .../Transforms/DataLayoutPropagation.cpp | 272 ++++++++++++++++++ .../Linalg/data-layout-propagation.mlir | 110 +++++++ .../Linalg/TestDataLayoutPropagation.cpp | 2 + 4 files changed, 389 insertions(+) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 8d5306dca43e3..680fdffa9e587 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns( RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation); +/// Patterns to sink extract slice across other operations. +void populateExtractSliceSinkingPatterns( + RewritePatternSet &patterns, + const ControlPropagationFn &controlPackUnPackPropagation); + /// Pattern to remove dead operands and results of `linalg.generic` operations. /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`. 
void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 0a9c1766425bd..d50ab8cf03714 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -6,10 +6,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Dominance.h" #include "llvm/ADT/SetOperations.h" @@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern { ControlPropagationFn controlFn; }; +// This struct contains information about extract_slice dims. +struct SliceDimInfo { + OpFoldResult offset; + OpFoldResult sliceSize; + OpFoldResult outputSize; +}; + +/// Return the first input extract slice operand, if present, for the current +/// generic op. +static FailureOr> +getSliceOperandAndIndex(GenericOp genericOp) { + OpOperand *sliceOperand = nullptr; + unsigned operandIndex; + for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { + auto extractOp = operand->get().getDefiningOp(); + if (!extractOp) + continue; + sliceOperand = operand; + operandIndex = idx; + break; + } + if (!sliceOperand) { + return failure(); + } + return std::make_tuple(sliceOperand, operandIndex); +} + +// Return a map of dims that have non full slices on them so that other operands +// can use this information. Also return a bool mentioning if a reduction dim +// has a non full slice as that can be used to fold the original extract slice.
+static FailureOr, bool>> +getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand, + tensor::ExtractSliceOp producerSliceOp) { + llvm::DenseMap nonZeroSliceDimMap; + bool hasNonZeroReductionDimSlice = false; + SmallVector iterators = + genericOp.getIteratorTypesArray(); + SmallVector offsets = producerSliceOp.getMixedOffsets(); + SmallVector sizes = producerSliceOp.getMixedSizes(); + + SmallVector shape = llvm::map_to_vector( + producerSliceOp.getSourceType().getShape(), + [&](int64_t sz) -> OpFoldResult { + return getAsIndexOpFoldResult(genericOp.getContext(), sz); + }); + + for (auto [idx, expr] : llvm::enumerate( + genericOp.getMatchingIndexingMap(sliceOperand).getResults())) { + if (isConstantIntValue(offsets[idx], 0) && + isEqualConstantIntOrValue(sizes[idx], shape[idx])) { + continue; + } + if (!isa(expr)) { + return failure(); + } + SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]}; + int64_t dimPos = cast(expr).getPosition(); + nonZeroSliceDimMap[dimPos] = sliceDimInfo; + if (iterators[dimPos] == utils::IteratorType::reduction) { + hasNonZeroReductionDimSlice = true; + } + } + // Next check if the dims with non zero slice info are used as non + // AffineDimExpr and if they are then bail-out. 
+ for (OpOperand &operand : genericOp->getOpOperands()) { + if (operand == *sliceOperand) { + continue; + } + AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand); + if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) { + if (isa(expr)) { + return false; + } + WalkResult status = expr.walk([&](AffineExpr expr) { + if (auto dimExpr = dyn_cast(expr)) { + if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) { + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + if (status.wasInterrupted()) { + return true; + } + return false; + })) { + return failure(); + } + } + return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice); +} + +static FailureOr> +pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, + GenericOp genericOp, + ControlPropagationFn controlFn) { + if (genericOp.getNumResults() != 1) + return failure(); + if (hasGatherSemantics(genericOp)) + return failure(); + // Collect the extract_slice operand, if present.
+ auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp); + if (failed(maybeSliceOperandAndIndex)) + return failure(); + OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex); + unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex); + + if (!controlFn(sliceOperand)) + return failure(); + + tensor::ExtractSliceOp producerSliceOp = + sliceOperand->get().getDefiningOp(); + assert(producerSliceOp && "expect a valid ExtractSliceOp"); + + if (producerSliceOp.getSource().getType().getRank() != + producerSliceOp.getResult().getType().getRank()) { + return failure(); + } + + SmallVector strides = producerSliceOp.getMixedStrides(); + if (!areAllConstantIntValue(strides, 1)) + return failure(); + + SmallVector offsets = producerSliceOp.getMixedOffsets(); + SmallVector sizes = producerSliceOp.getMixedSizes(); + + // Check if the propagation of this extractSlice through the generic op + // is supported, and if so collect info about the non-full slice dims. + + auto maybeNonZeroSliceDimMap = + getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp); + + if (failed(maybeNonZeroSliceDimMap)) { + return failure(); + } + + auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap); + bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap); + + // Helper to compute (v1 - v2) as a folded affine.apply on OpFoldResults.
+ Location loc = genericOp->getLoc(); + AffineExpr dim0, dim1; + bindDims(rewriter.getContext(), dim0, dim1); + auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); + auto sub = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap, + {v1, v2}); + }; + + MLIRContext *ctx = genericOp.getContext(); + SmallVector paddedInputs; + for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { + if (idx == OperandIndex && !hasNonZeroReductionDimSlice) { + paddedInputs.push_back(producerSliceOp.getSource()); + continue; + } + AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand); + SmallVector operandLowPads(IndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector operandHighPads(IndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) { + if (!isa(expr)) { + continue; + } + AffineDimExpr dimExpr = cast(expr); + if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) { + SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()]; + operandLowPads[idx] = sliceDimInfo.offset; + operandHighPads[idx] = + sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); + } + } + auto paddingValue = ub::PoisonOp::create( + rewriter, loc, getElementTypeOrSelf(operand->get().getType())); + auto paddedOperand = tensor::PadOp::create( + rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads, + paddingValue, /*nofold=*/false); + paddedInputs.push_back(paddedOperand); + } + AffineMap outputIndexingMap = + genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0)); + + auto outputShapeType = + llvm::cast(genericOp.getDpsInitOperand(0)->get().getType()); + SmallVector OutputShape = llvm::map_to_vector( + outputShapeType.getShape(), + [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); }); + SmallVector newSizes = OutputShape; + SmallVector 
outputLowPads(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector outputHighPads(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector newStrides(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 1)); + for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) { + if (!isa(expr)) { + continue; + } + AffineDimExpr dimExpr = cast(expr); + if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) { + SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()]; + outputLowPads[idx] = sliceDimInfo.offset; + outputHighPads[idx] = + sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); + OutputShape[idx] = sliceDimInfo.outputSize; + newSizes[idx] = sliceDimInfo.sliceSize; + } + } + Value newPadOutput; + auto outputElType = + getElementTypeOrSelf(genericOp.getDpsInits()[0].getType()); + if (isGenericOutsNotUsed(genericOp)) { + newPadOutput = + tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType); + + } else { + + auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType); + newPadOutput = tensor::PadOp::create( + rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads, + outputHighPads, paddingValue, /*nofold=*/false); + } + + auto newGenericOp = linalg::GenericOp::create( + rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput}, + genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(), + /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); + rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), + newGenericOp.getRegion().begin()); + + auto extractOp = tensor::ExtractSliceOp::create( + rewriter, loc, + newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)), + outputLowPads, newSizes, newStrides); + Value extractRes = extractOp.getResult(); + + return std::make_tuple(newGenericOp, extractRes); +} + +class PushDownExtractSliceOpThroughGenericOp 
final + : public OpRewritePattern { +public: + PushDownExtractSliceOpThroughGenericOp(MLIRContext *context, + ControlPropagationFn fun) + : OpRewritePattern(context), controlFn(std::move(fun)) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + auto genericAndRepl = + pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn); + if (failed(genericAndRepl)) + return failure(); + rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); + return success(); + } + +private: + ControlPropagationFn controlFn; +}; + } // namespace void mlir::linalg::populateDataLayoutPropagationPatterns( @@ -1247,3 +1512,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns( PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( patterns.getContext(), controlPackUnPackPropagation); } + +void mlir::linalg::populateExtractSliceSinkingPatterns( + RewritePatternSet &patterns, + const ControlPropagationFn &controlPackUnPackPropagation) { + patterns.insert( + patterns.getContext(), controlPackUnPackPropagation); +} diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index cc26fa48abf4b..723eecb52351b 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar // CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]] // CHECK-SAME: into %[[ARG1]] // CHECK: return %[[UNPACK2]] : tensor + +// ----- + +module { + func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor, %arg2: tensor, %arg3: index) -> tensor { + %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32> + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, 
d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f32, %in_0: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + linalg.yield %1 : bf16 + } -> tensor + return %0 : tensor + } +} + +// CHECK-LABEL: func.func @push_extract_through_generic +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] +// CHECK: %[[POISON:.+]] = ub.poison : f32 +// CHECK: %[[PADDED:.+]] = tensor.pad %arg1 +// CHECK: tensor.yield %[[POISON]] : f32 +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16> +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]] +// CHECK-SAME: outs(%[[EMPTY]] +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor +// CHECK: return %[[EXTRACT]] + +// ----- + +func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor, %arg2: tensor, %arg3: index) -> tensor { + %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32> + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f32, %in_0: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + linalg.yield %1 : bf16 + } -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1 +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: return %[[GENERIC]] + +// 
----- + +func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> { + %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32> + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) { + ^bb0(%in: f32, %in_0: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + linalg.yield %1 : bf16 + } -> tensor<128x?x128xbf16> + return %0 : tensor<128x?x128xbf16> +} + +// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2 +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: return %[[GENERIC]] + +// ----- + +func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor, %arg2: index) -> tensor { + %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + %2 = arith.addf %1, %out : bf16 + linalg.yield %2 : bf16 + } -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %[[POISON_BF16:.+]] = ub.poison : bf16 +// CHECK: %[[POISON_F32:.+]] = ub.poison : f32 +// CHECK: %[[EXTRACT:.+]] 
= tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor +// CHECK: %[[PADDED:.+]] = tensor.pad %[[EXTRACT]] +// CHECK: tensor.yield %[[POISON_F32]] : f32 +// CHECK: %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]] +// CHECK: %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]] +// CHECK: tensor.yield %[[POISON_BF16]] : bf16 +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[PADDED]] +// CHECK-SAME: outs(%[[PADDED1]] +// CHECK: %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor to tensor +// CHECK: return %[[EXTRACT1]] + + +// ----- + +func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor, %arg2: index) -> tensor { + %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + %2 = arith.addf %1, %out : bf16 + linalg.yield %2 : bf16 + } -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @nopush_rankreducingextract +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: return %[[GENERIC]] diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp index d0700f9a4f1a4..2cf25d8fc8c19 100644 --- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp @@ -34,6 +34,8 @@ struct TestDataLayoutPropagationPass RewritePatternSet patterns(context); linalg::populateDataLayoutPropagationPatterns( patterns, [](OpOperand *opOperand) { return true; }); + linalg::populateExtractSliceSinkingPatterns( + patterns, [](OpOperand *opOperand) { 
return true; }); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); }