
[Linalg] Add pattern to push down extract slice through linalg generic op #154162


Open. nirvedhmeshram wants to merge 1 commit into main from push_down_extract.

Conversation

nirvedhmeshram (Contributor)

This PR adds a data-layout propagation pattern that pushes a tensor.extract_slice down through a linalg.generic op. The pattern lives in a separate populate function because a user may want the other propagation patterns without this one, e.g. when extract_slice is treated as a special op during tiling.
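
As a rough usage sketch (not part of the patch): a downstream pass could opt into just this pattern through the new populate function. The greedy-driver wiring and the always-true control function below are illustrative assumptions.

    // Hypothetical wiring inside a pass's runOnOperation(); the control
    // function here accepts every candidate slice operand.
    RewritePatternSet patterns(&getContext());
    linalg::populateExtractSlicePropagationPatterns(
        patterns, /*controlPackUnPackPropagation=*/
        [](OpOperand *) { return true; });
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();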

@llvmbot (Member)

llvmbot commented Aug 18, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Nirvedh Meshram (nirvedhmeshram)

Changes

This PR adds a data-layout propagation pattern that pushes a tensor.extract_slice down through a linalg.generic op. The pattern lives in a separate populate function because a user may want the other propagation patterns without this one, e.g. when extract_slice is treated as a special op during tiling.


Patch is 20.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154162.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+269)
  • (modified) mlir/test/Dialect/Linalg/data-layout-propagation.mlir (+110)
  • (modified) mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index d4ffe0a91fcfe..046920f5ccd54 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1914,6 +1914,11 @@ void populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation);
 
+/// Patterns to bubble up or down extract slice across other operations.
+void populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation);
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c1766425bd..16d6ac23b0208 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,266 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+// This struct contains information about extract_slice dims.
+struct SliceDimInfo {
+  OpFoldResult offset;
+  OpFoldResult sliceSize;
+  OpFoldResult outputSize;
+};
+
+/// Return the first input extract slice operand, if present, for the current
+/// generic op.
+static FailureOr<std::tuple<OpOperand *, unsigned>>
+getSliceOperandAndIndex(GenericOp genericOp) {
+  OpOperand *sliceOperand = nullptr;
+  unsigned operandIndex;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      continue;
+    sliceOperand = operand;
+    operandIndex = idx;
+    break;
+  }
+  if (!sliceOperand) {
+    return failure();
+  }
+  return std::make_tuple(sliceOperand, operandIndex);
+}
+
+static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
+getNonZeroSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
+                       tensor::ExtractSliceOp producerSliceOp) {
+  llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
+  bool hasNonZeroReductionDimSlice = false;
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = llvm::map_to_vector(
+      producerSliceOp.getSourceType().getShape(),
+      [&](int64_t sz) -> OpFoldResult {
+        return getAsIndexOpFoldResult(genericOp.getContext(), sz);
+      });
+
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    nonZeroSliceDimMap[dimPos] = sliceDimInfo;
+    if (iterators[dimPos] == utils::IteratorType::reduction) {
+      hasNonZeroReductionDimSlice = true;
+    }
+  }
+  // Next, check if the dims with non-zero slice info are used in non
+  // AffineDimExpr results; if they are, then bail out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (operand == *sliceOperand) {
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(indexingMap.getResults(), [&](AffineExpr expr) {
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          if (status.wasInterrupted()) {
+            return true;
+          }
+          return false;
+        })) {
+      return failure();
+    }
+  }
+  return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice);
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return failure();
+  if (hasGatherSemantics(genericOp))
+    return failure();
+  // Collect the unPacked operand, if present.
+  auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp);
+  if (failed(maybeSliceOperandAndIndex))
+    return failure();
+  OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex);
+  unsigned operandIndex = std::get<1>(*maybeSliceOperandAndIndex);
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid UnPackOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return failure();
+
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  // Check if we can support the propagation of this extract_slice through
+  // the generic op and, if so, return the dimensions for which we can use
+  // this information. Also return a bool mentioning if a reduction dim has
+  // a non-full slice, as that can be used to fold the original extract
+  // slice.
+
+  auto maybeNonZeroSliceDimMap =
+      getNonZeroSliceDimInfo(genericOp, sliceOperand, producerSliceOp);
+
+  if (failed(maybeNonZeroSliceDimMap)) {
+    return failure();
+  }
+
+  auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
+  bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
+
+  // Helper to emit `v1 - v2` as a composed affine.apply; used below to
+  // compute high pads as outputSize - offset - sliceSize.
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (idx == operandIndex && !hasNonZeroReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(indexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(indexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(indexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+        SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+        operandLowPads[idx] = sliceDimInfo.offset;
+        operandHighPads[idx] =
+            sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+                sliceDimInfo.sliceSize);
+      }
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
+        paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+  auto outputShapeType =
+      llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
+  SmallVector<OpFoldResult> outputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = outputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+    if (!isa<AffineDimExpr>(expr)) {
+      continue;
+    }
+    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+    if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
+      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
+      outputLowPads[idx] = sliceDimInfo.offset;
+      outputHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
+      outputShape[idx] = sliceDimInfo.outputSize;
+      newSizes[idx] = sliceDimInfo.sliceSize;
+    }
+  }
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, outputShape, outputElType);
+  } else {
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
+
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1509,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
               PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
           patterns.getContext(), controlPackUnPackPropagation);
 }
+
+void mlir::linalg::populateExtractSlicePropagationPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..723eecb52351b 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 // CHECK:         %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME:    into %[[ARG1]]
 // CHECK:         return %[[UNPACK2]] : tensor<?x64xf32>
+
+// -----
+
+module {
+  func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+    ^bb0(%in: f32, %in_0: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x5x128xbf16>
+    return %0 : tensor<?x5x128xbf16>
+  }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON:.+]] = ub.poison : f32
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[ARG1]]
+// CHECK:           tensor.yield %[[POISON]] : f32
+// CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[ARG0]], %[[PADDED]]
+// CHECK-SAME:    outs(%[[EMPTY]]
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK:         return %[[EXTRACT]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x5x128xbf16>
+  return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<128x?x128xbf16>
+  return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]
+
+// -----
+
+func.func @push_reduction_extract_through_generic_with_outs_used_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_reduction_extract_through_generic_with_outs_used_2
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK:         %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK:         %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK:         %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK:           tensor.yield %[[POISON_F32]] : f32
+// CHECK:         %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK:         %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK:           tensor.yield %[[POISON_BF16]] : bf16
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:    ins(%[[PADDED]]
+// CHECK-SAME:    outs(%[[PADDED1]]
+// CHECK:         %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK:         return %[[EXTRACT1]]
+
+
+// -----
+
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK:         return %[[GENERIC]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9a4f1a4..449d28fc528b1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6...
[truncated]

@llvmbot (Member)

llvmbot commented Aug 18, 2025

@llvm/pr-subscribers-mlir

nirvedhmeshram force-pushed the push_down_extract branch 2 times, most recently from c2a5eb9 to 7b9d96e on August 18, 2025 at 17:40
@MaheshRavishankar (Contributor) left a comment:

I think the overall logic makes sense. Just have a few clarifying comments and cleanups.

if (!sliceOperand) {
return failure();
}
return std::make_tuple(sliceOperand, operandIndex);
Reviewer (Contributor):

I don't think you really need to return the operandIdx; it is already available from the OpOperand *.

Reviewer (Contributor):

+1 (sliceOperand->getOperandNumber())
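
A minimal sketch of that simplification (names from the patch), assuming the helper is changed to return only the OpOperand *:

    // The operand index is recoverable from the operand itself, so the
    // tuple return is unnecessary.
    unsigned operandIndex = sliceOperand->getOperandNumber();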

// can use this information. Also return a bool mentioning if a reduction dim
// has a non full slice as that can be used to fold the original extract slice.
static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
Reviewer (Contributor):

Nit: you don't need to pass the producerSliceOp. You should be able to get it from the OpOperand *.


for (auto [idx, expr] : llvm::enumerate(
genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
if (isConstantIntValue(offsets[idx], 0) &&
Reviewer (Contributor):

Can you add comments as to what each of these conditions is checking?

static FailureOr<std::tuple<llvm::DenseMap<int64_t, SliceDimInfo>, bool>>
getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
tensor::ExtractSliceOp producerSliceOp) {
llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
Reviewer (Contributor):

I think I was confused by nonZeroSliceDimMap. Do you mean partialSliceDimMap, or something like that?

getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand,
tensor::ExtractSliceOp producerSliceOp) {
llvm::DenseMap<int64_t, SliceDimInfo> nonZeroSliceDimMap;
bool hasNonZeroReductionDimSlice = false;
Reviewer (Contributor):

Same with this. hasPartialReductionDimSlice or something like that is easier to understand.

if (!sliceOperand) {
return failure();
}
return std::make_tuple(sliceOperand, operandIndex);
Reviewer (Contributor):

+1 (sliceOperand->getOperandNumber())

Comment on lines +1281 to +1285
SmallVector<OpFoldResult> shape = llvm::map_to_vector(
producerSliceOp.getSourceType().getShape(),
[&](int64_t sz) -> OpFoldResult {
return getAsIndexOpFoldResult(genericOp.getContext(), sz);
});
Reviewer (Contributor):

You can pass an ArrayRef to getAsIndexOpFoldResult, so you don't need to map_to_vector.
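
Sketched under that suggestion, assuming the ArrayRef overload of getAsIndexOpFoldResult:

    // Builds the same vector of index OpFoldResults without map_to_vector.
    SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
        genericOp.getContext(), producerSliceOp.getSourceType().getShape());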

Comment on lines +1316 to +1317
if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) {
return WalkResult::interrupt();
Reviewer (Contributor):

nit: Use nonZeroSliceDimMap.contains(dimExpr.getPosition())?

return failure();
if (hasGatherSemantics(genericOp))
return failure();
// Collect the unPacked operand, if present.
Reviewer (Contributor):

Suggested change:
-// Collect the unPacked operand, if present.
+// Collect the sliced operand, if present.


tensor::ExtractSliceOp producerSliceOp =
sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
assert(producerSliceOp && "expect a valid UnPackOp");
Reviewer (Contributor):

Suggested change:
-assert(producerSliceOp && "expect a valid UnPackOp");
+assert(producerSliceOp && "expect a valid extract_slice op");

static FailureOr<std::tuple<GenericOp, Value>>
pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
GenericOp genericOp,
ControlPropagationFn controlFn) {
Reviewer (Contributor):

nit: using rewriter.notifyMatchFailure() is helpful for the less obvious restrictions on this pattern (things like rank-reducing, gather semantics, etc.)
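
For example, the rank-reducing bail-out could be rewritten as below; this is a sketch of the suggestion, and the message text is illustrative:

    if (producerSliceOp.getSource().getType().getRank() !=
        producerSliceOp.getResult().getType().getRank())
      return rewriter.notifyMatchFailure(
          genericOp, "rank-reducing extract_slice is not supported");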

Comment on lines +1377 to +1378
auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap);
bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap);
Reviewer (Contributor):

optional nit: Might be easier to follow if this just returns the dim map, and then we check if any reduction dims are present in the dim map afterwards.
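
A sketch of that restructuring, reusing names from the patch:

    // Derive the reduction-dim flag at the call site instead of returning
    // it from the helper alongside the dim map.
    SmallVector<utils::IteratorType> iterators =
        genericOp.getIteratorTypesArray();
    bool hasNonZeroReductionDimSlice =
        llvm::any_of(nonZeroSliceDimMap, [&](const auto &entry) {
          return iterators[entry.first] == utils::IteratorType::reduction;
        });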

continue;
}
AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) {
Reviewer (Contributor):

nit: Use early continue to save nesting
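
For example, a sketch of the early-continue form of the loop body over the output indexing map:

    for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
      auto dimExpr = dyn_cast<AffineDimExpr>(expr);
      if (!dimExpr || !nonZeroSliceDimMap.contains(dimExpr.getPosition()))
        continue;
      SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()];
      outputLowPads[idx] = sliceDimInfo.offset;
      // ... remaining body unchanged.
    }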
