Skip to content

Conversation

@lialan
Copy link
Member

@lialan lialan commented Jul 23, 2025

This is a reapply of patch #149851. The reapply also fixes a CMake/Bazel build issue, which was the reason of the revert. (Thanks @rupprecht )

Original patch (#149851) message:

This PR adds a new optimization pass to fold memref.subview/expand_shape/collapse_shape ops into consumer amdgpu.gather_to_lds operations.

  • Implements a new pass AmdgpuFoldMemRefOpsPass with pattern FoldMemRefOpsIntoGatherToLDSOp
  • Adds corresponding folding tests

@llvmbot llvmbot added backend:AMDGPU mlir:gpu mlir mlir:memref bazel "Peripheral" support tier build system: utils/bazel mlir:amdgpu labels Jul 23, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 23, 2025

@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-amdgpu

Author: Alan Li (lialan)

Changes

This is a reapply of patch #149851. The reapply also fixes a Bazel build issue, which was the reason of the revert. (Thanks @rupprecht )

Original patch (#149851) message:

This PR adds a new optimization pass to fold memref.subview/expand_shape/collapse_shape ops into consumer amdgpu.gather_to_lds operations.

  • Implements a new pass AmdgpuFoldMemRefOpsPass with pattern FoldMemRefOpsIntoGatherToLDSOp
  • Adds corresponding folding tests

Patch is 21.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150334.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h (+5-1)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td (+12)
  • (modified) mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h (+37)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt (+2-1)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp (+97)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (-91)
  • (modified) mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp (+66)
  • (added) mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir (+94)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+1)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index cc2f543e79f69..58b9c74b2f8e0 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -22,8 +22,9 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
-#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
+#define GEN_PASS_DECL_AMDGPUFOLDMEMREFOPSPASS
 #define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -38,6 +39,9 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
 void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);
 
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
 } // namespace amdgpu
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 8d0e6829ab0cc..8664f971cabde 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -70,4 +70,16 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
     "memref::MemRefDialect"
   ];
 }
+
+def AmdgpuFoldMemRefOpsPass : Pass<"amdgpu-fold-memrefs-ops"> {
+  let summary = "Fold memref operations into their parent operations";
+  let description = [{
+    This pass identifies memref operations (subview, expand_shape, collapse_shape)
+    that are sources of `GatherToLDSOp` and attempts to fold the source ops,
+    potentially simplifying the overall operation and improving performance.
+  }];
+  let dependentDialects = [
+    "memref::MemRefDialect"
+  ];
+}
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 34ad279a07a8b..dd3b3dea6ef26 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -116,6 +116,43 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
 /// the source memref (i.e. implements ViewLikeOpInterface).
 MemrefValue skipViewLikeOps(MemrefValue source);
 
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a expand_shape op, returns the indices w.r.t to the source memref of the
+/// expand_shape op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.expand_shape %0 [[0, 1], [2]]
+///    : memref<12x42xf32> into memref<2x6x42xf32>
+/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
+///
+/// could be folded into
+///
+/// %2 = load %0[6 * i1 + i2, %i3] :
+///          memref<12x42xf32>
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a collapse_shape op, returns the indices w.r.t to the source memref of
+/// the collapse_shape op. For example
+///
+/// %0 = ... : memref<2x6x42xf32>
+/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
+///    : memref<2x6x42xf32> into memref<12x42xf32>
+/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
+///          memref<2x6x42xf32>
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices);
+
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 17bbe54ea6c0c..3b0c072ed1217 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,8 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
-  ResolveStridedMetadata.cpp
+  FoldMemRefsOps.cpp
   MaskedloadToLoad.cpp
+  ResolveStridedMetadata.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
new file mode 100644
index 0000000000000..a3fdc7ee385ed
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -0,0 +1,97 @@
+//===- FoldSubviewOps.cpp - AMDGPU fold subview ops -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+
+struct AmdgpuFoldMemRefOpsPass final
+    : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuFoldMemRefOpsPatterns(patterns);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
+struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherToLDSOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value memrefSource;
+    SmallVector<Value> sourceIndices;
+    auto foldResult =
+        llvm::TypeSwitch<Operation *, LogicalResult>(
+            op.getSrc().getDefiningOp())
+            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+              // If the source is a SubViewOp, we can directly rewrite the
+              // GatherToLDSOp.
+              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                  rewriter, loc, subviewOp.getMixedOffsets(),
+                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                  op.getSrcIndices(), sourceIndices);
+              memrefSource = subviewOp.getSource();
+              return success();
+            })
+            .Case<memref::ExpandShapeOp>(
+                [&](memref::ExpandShapeOp expandShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
+                          sourceIndices, false))) {
+                    return failure();
+                  }
+                  memrefSource = expandShapeOp.getViewSource();
+                  return success();
+                })
+            .Case<memref::CollapseShapeOp>(
+                [&](memref::CollapseShapeOp collapseShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
+                          sourceIndices))) {
+                    return failure();
+                  }
+                  memrefSource = collapseShapeOp.getViewSource();
+                  return success();
+                })
+            .Default([&](Operation *op) {
+              // If the source is not a SubViewOp, ExpandShapeOp, or
+              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
+              return rewriter.notifyMatchFailure(
+                  op,
+                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
+                  "CollapseShapeOp");
+            });
+
+    if (failed(foldResult)) {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
+                                               op.getDst(), op.getDstIndices(),
+                                               op.getTransferType());
+
+    return success();
+  }
+};
+
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit) {
+  patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(patterns.getContext(), benefit);
+}
+} // namespace mlir::amdgpu
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 89be188af9129..24da447ad7685 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -44,97 +44,6 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a expand_shape op, returns the indices w.r.t to the source memref of the
-/// expand_shape op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = memref.expand_shape %0 [[0, 1], [2]]
-///    : memref<12x42xf32> into memref<2x6x42xf32>
-/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
-///
-/// could be folded into
-///
-/// %2 = load %0[6 * i1 + i2, %i3] :
-///          memref<12x42xf32>
-static LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
-  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
-
-  // Traverse all reassociation groups to determine the appropriate indices
-  // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-    if (groupSize == 1) {
-      sourceIndices.push_back(indices[group[0]]);
-      continue;
-    }
-    SmallVector<OpFoldResult> groupBasis =
-        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
-    SmallVector<Value> groupIndices =
-        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
-    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
-        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
-    sourceIndices.push_back(collapsedIndex);
-  }
-  return success();
-}
-
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a collapse_shape op, returns the indices w.r.t to the source memref of
-/// the collapse_shape op. For example
-///
-/// %0 = ... : memref<2x6x42xf32>
-/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
-///    : memref<2x6x42xf32> into memref<12x42xf32>
-/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
-///
-/// could be folded into
-///
-/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
-///          memref<2x6x42xf32>
-static LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
-                                  memref::CollapseShapeOp collapseShapeOp,
-                                  ValueRange indices,
-                                  SmallVectorImpl<Value> &sourceIndices) {
-  // Note: collapse_shape requires a strided memref, we can do this.
-  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-      loc, collapseShapeOp.getSrc());
-  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
-  for (auto [index, group] :
-       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-
-    if (groupSize == 1) {
-      sourceIndices.push_back(index);
-      continue;
-    }
-
-    SmallVector<OpFoldResult> basis =
-        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
-    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        loc, index, basis, /*hasOuterBound=*/true);
-    llvm::append_range(sourceIndices, delinearize.getResults());
-  }
-  if (collapseShapeOp.getReassociationIndices().empty()) {
-    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
-    int64_t srcRank =
-        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
-    for (int64_t i = 0; i < srcRank; i++) {
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
-    }
-  }
-  return success();
-}
-
 /// Helpers to access the memref operand for each op.
 template <typename LoadOrStoreOpTy>
 static Value getMemRefOperand(LoadOrStoreOpTy op) {
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index a50b4cfc74708..97fe3cb5b4705 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
@@ -217,5 +218,70 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
   return source;
 }
 
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
+
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
+    }
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
+  }
+  return success();
+}
+
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices) {
+  // Note: collapse_shape requires a strided memref, we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
+    }
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
+  }
+  if (collapseShapeOp.getReassociationIndices().empty()) {
+    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+    int64_t srcRank =
+        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
+    for (int64_t i = 0; i < srcRank; i++) {
+      sourceIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    }
+  }
+  return success();
+}
+
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
new file mode 100644
index 0000000000000..57afa127c9da8
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt --amdgpu-fold-memrefs-ops --split-input-file %s | FileCheck %s
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_subview_folding
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_subview_folding(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[ARG0]], %[[ARG1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %subview = memref.subview %mem[0, 0][32, 64][1, 1] : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1]>>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<32x64xf16, strided<[128, 1]>>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 + 32)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 64)>
+
+// CHECK: func @subview_folding_offset
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+  // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX0]], %[[IDX1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %subview = memref.subview %mem[32, 64][32, 64][1, 1] : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1], offset: 4160>>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<32x64xf16, strided<[128, 1], offset: 4160>>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_expand_shape
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+
+  %alloc = memre...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 23, 2025

@llvm/pr-subscribers-mlir

Author: Alan Li (lialan)

Changes

This is a reapply of patch #149851. The reapply also fixes a Bazel build issue, which was the reason of the revert. (Thanks @rupprecht )

Original patch (#149851) message:

This PR adds a new optimization pass to fold memref.subview/expand_shape/collapse_shape ops into consumer amdgpu.gather_to_lds operations.

  • Implements a new pass AmdgpuFoldMemRefOpsPass with pattern FoldMemRefOpsIntoGatherToLDSOp
  • Adds corresponding folding tests

Patch is 21.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150334.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h (+5-1)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td (+12)
  • (modified) mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h (+37)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt (+2-1)
  • (added) mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp (+97)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (-91)
  • (modified) mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp (+66)
  • (added) mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir (+94)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+1)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index cc2f543e79f69..58b9c74b2f8e0 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -22,8 +22,9 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
-#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
+#define GEN_PASS_DECL_AMDGPUFOLDMEMREFOPSPASS
 #define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
+#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
 
@@ -38,6 +39,9 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
 void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
                                             PatternBenefit benefit = 1);
 
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit = 1);
+
 } // namespace amdgpu
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 8d0e6829ab0cc..8664f971cabde 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -70,4 +70,16 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
     "memref::MemRefDialect"
   ];
 }
+
+def AmdgpuFoldMemRefOpsPass : Pass<"amdgpu-fold-memrefs-ops"> {
+  let summary = "Fold memref operations into their parent operations";
+  let description = [{
+    This pass identifies memref operations (subview, expand_shape, collapse_shape)
+    that are sources of `GatherToLDSOp` and attempts to fold the source ops,
+    potentially simplifying the overall operation and improving performance.
+  }];
+  let dependentDialects = [
+    "memref::MemRefDialect"
+  ];
+}
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 34ad279a07a8b..dd3b3dea6ef26 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -116,6 +116,43 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
 /// the source memref (i.e. implements ViewLikeOpInterface).
 MemrefValue skipViewLikeOps(MemrefValue source);
 
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a expand_shape op, returns the indices w.r.t to the source memref of the
+/// expand_shape op. For example
+///
+/// %0 = ... : memref<12x42xf32>
+/// %1 = memref.expand_shape %0 [[0, 1], [2]]
+///    : memref<12x42xf32> into memref<2x6x42xf32>
+/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
+///
+/// could be folded into
+///
+/// %2 = load %0[6 * i1 + i2, %i3] :
+///          memref<12x42xf32>
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);
+
+/// Given the 'indices' of a load/store operation where the memref is a result
+/// of a collapse_shape op, returns the indices w.r.t to the source memref of
+/// the collapse_shape op. For example
+///
+/// %0 = ... : memref<2x6x42xf32>
+/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
+///    : memref<2x6x42xf32> into memref<12x42xf32>
+/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
+///
+/// could be folded into
+///
+/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
+///          memref<2x6x42xf32>
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices);
+
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 17bbe54ea6c0c..3b0c072ed1217 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,8 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
-  ResolveStridedMetadata.cpp
+  FoldMemRefsOps.cpp
   MaskedloadToLoad.cpp
+  ResolveStridedMetadata.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
new file mode 100644
index 0000000000000..a3fdc7ee385ed
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp
@@ -0,0 +1,97 @@
+//===- FoldSubviewOps.cpp - AMDGPU fold subview ops -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+
+struct AmdgpuFoldMemRefOpsPass final
+    : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuFoldMemRefOpsPatterns(patterns);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
+struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherToLDSOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value memrefSource;
+    SmallVector<Value> sourceIndices;
+    auto foldResult =
+        llvm::TypeSwitch<Operation *, LogicalResult>(
+            op.getSrc().getDefiningOp())
+            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+              // If the source is a SubViewOp, we can directly rewrite the
+              // GatherToLDSOp.
+              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                  rewriter, loc, subviewOp.getMixedOffsets(),
+                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                  op.getSrcIndices(), sourceIndices);
+              memrefSource = subviewOp.getSource();
+              return success();
+            })
+            .Case<memref::ExpandShapeOp>(
+                [&](memref::ExpandShapeOp expandShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
+                          sourceIndices, false))) {
+                    return failure();
+                  }
+                  memrefSource = expandShapeOp.getViewSource();
+                  return success();
+                })
+            .Case<memref::CollapseShapeOp>(
+                [&](memref::CollapseShapeOp collapseShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
+                          sourceIndices))) {
+                    return failure();
+                  }
+                  memrefSource = collapseShapeOp.getViewSource();
+                  return success();
+                })
+            .Default([&](Operation *op) {
+              // If the source is not a SubViewOp, ExpandShapeOp, or
+              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
+              return rewriter.notifyMatchFailure(
+                  op,
+                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
+                  "CollapseShapeOp");
+            });
+
+    if (failed(foldResult)) {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
+                                               op.getDst(), op.getDstIndices(),
+                                               op.getTransferType());
+
+    return success();
+  }
+};
+
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit) {
+  patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(patterns.getContext(), benefit);
+}
+} // namespace mlir::amdgpu
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index 89be188af9129..24da447ad7685 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -44,97 +44,6 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a expand_shape op, returns the indices w.r.t to the source memref of the
-/// expand_shape op. For example
-///
-/// %0 = ... : memref<12x42xf32>
-/// %1 = memref.expand_shape %0 [[0, 1], [2]]
-///    : memref<12x42xf32> into memref<2x6x42xf32>
-/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
-///
-/// could be folded into
-///
-/// %2 = load %0[6 * i1 + i2, %i3] :
-///          memref<12x42xf32>
-static LogicalResult resolveSourceIndicesExpandShape(
-    Location loc, PatternRewriter &rewriter,
-    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
-    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
-  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
-
-  // Traverse all reassociation groups to determine the appropriate indices
-  // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-    if (groupSize == 1) {
-      sourceIndices.push_back(indices[group[0]]);
-      continue;
-    }
-    SmallVector<OpFoldResult> groupBasis =
-        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
-    SmallVector<Value> groupIndices =
-        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
-    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
-        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
-    sourceIndices.push_back(collapsedIndex);
-  }
-  return success();
-}
-
-/// Given the 'indices' of a load/store operation where the memref is a result
-/// of a collapse_shape op, returns the indices w.r.t to the source memref of
-/// the collapse_shape op. For example
-///
-/// %0 = ... : memref<2x6x42xf32>
-/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
-///    : memref<2x6x42xf32> into memref<12x42xf32>
-/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
-///
-/// could be folded into
-///
-/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
-///          memref<2x6x42xf32>
-static LogicalResult
-resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
-                                  memref::CollapseShapeOp collapseShapeOp,
-                                  ValueRange indices,
-                                  SmallVectorImpl<Value> &sourceIndices) {
-  // Note: collapse_shape requires a strided memref, we can do this.
-  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-      loc, collapseShapeOp.getSrc());
-  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
-  for (auto [index, group] :
-       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
-    assert(!group.empty() && "association indices groups cannot be empty");
-    int64_t groupSize = group.size();
-
-    if (groupSize == 1) {
-      sourceIndices.push_back(index);
-      continue;
-    }
-
-    SmallVector<OpFoldResult> basis =
-        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
-    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
-        loc, index, basis, /*hasOuterBound=*/true);
-    llvm::append_range(sourceIndices, delinearize.getResults());
-  }
-  if (collapseShapeOp.getReassociationIndices().empty()) {
-    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
-    int64_t srcRank =
-        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
-    for (int64_t i = 0; i < srcRank; i++) {
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
-    }
-  }
-  return success();
-}
-
 /// Helpers to access the memref operand for each op.
 template <typename LoadOrStoreOpTy>
 static Value getMemRefOperand(LoadOrStoreOpTy op) {
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index a50b4cfc74708..97fe3cb5b4705 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
@@ -217,5 +218,70 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
   return source;
 }
 
+LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
+
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
+    }
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
+  }
+  return success();
+}
+
+LogicalResult
+resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
+                                  memref::CollapseShapeOp collapseShapeOp,
+                                  ValueRange indices,
+                                  SmallVectorImpl<Value> &sourceIndices) {
+  // Note: collapse_shape requires a strided memref, we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
+    }
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
+  }
+  if (collapseShapeOp.getReassociationIndices().empty()) {
+    auto zeroAffineMap = rewriter.getConstantAffineMap(0);
+    int64_t srcRank =
+        cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
+    for (int64_t i = 0; i < srcRank; i++) {
+      sourceIndices.push_back(
+          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    }
+  }
+  return success();
+}
+
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
new file mode 100644
index 0000000000000..57afa127c9da8
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir
@@ -0,0 +1,94 @@
+// RUN: mlir-opt --amdgpu-fold-memrefs-ops --split-input-file %s | FileCheck %s
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_subview_folding
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_subview_folding(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[ARG0]], %[[ARG1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %subview = memref.subview %mem[0, 0][32, 64][1, 1] : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1]>>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<32x64xf16, strided<[128, 1]>>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0] -> (s0 + 32)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 64)>
+
+// CHECK: func @subview_folding_offset
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDX0:.*]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+  // CHECK: %[[IDX1:.*]] = affine.apply #[[MAP1]]()[%[[ARG1]]]
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX0]], %[[IDX1]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %subview = memref.subview %mem[32, 64][32, 64][1, 1] : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1], offset: 4160>>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<32x64xf16, strided<[128, 1], offset: 4160>>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_expand_shape
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+
+  %alloc = memre...
[truncated]

@rupprecht
Copy link
Collaborator

The reapply also fixes a Bazel build issue, which was the reason of the revert

The buildbot failures I see on the original commit are all from non-Bazel builds, so the more important thing to do in this PR is to update the relevant CMakeLists.txt file w/ whatever dep needs to be included. Bazel failures are generally non-blocking.

The link error looks similar to the bazel error, so I think you need to add MLIRAffineUtils to the deps list in mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt.

The bazel change here LGTM. Thanks!

@lialan lialan requested review from krzysz00 and removed request for aaronmondal, keith and rupprecht July 23, 2025 23:53
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@lialan lialan merged commit 1c3e4e9 into llvm:main Jul 24, 2025
9 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jul 24, 2025

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building mlir,utils at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/16026

Here is the relevant piece of the build log for the reference
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/GPU/CUDA/async.mlir' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-kernel-outlining  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary="format=fatbin"  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_async_runtime.so    --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so    --entry-point-result=void -O0  | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-kernel-outlining
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt '-pass-pipeline=builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm),nvvm-attach-target)'
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -gpu-async-region -gpu-to-llvm -reconcile-unrealized-casts -gpu-module-to-binary=format=fatbin
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -async-to-async-runtime -async-runtime-ref-counting
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt -convert-async-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_cuda_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_async_runtime.so --shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_runner_utils.so --entry-point-result=void -O0
# .---command stderr------------
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuStreamWaitEvent(stream, event, 0)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventSynchronize(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# | 'cuEventDestroy(event)' failed with 'CUDA_ERROR_CONTEXT_IS_DESTROYED'
# `-----------------------------
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# .---command stderr------------
# | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir:68:12: error: CHECK: expected string not found in input
# |  // CHECK: [84, 84]
# |            ^
# | <stdin>:1:1: note: scanning from here
# | Unranked Memref base@ = 0x5adc35727a00 rank = 1 offset = 0 sizes = [2] strides = [1] data = 
# | ^
# | <stdin>:2:1: note: possible intended match here
# | [42, 42]
# | ^
# | 
# | Input file: <stdin>
# | Check file: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/GPU/CUDA/async.mlir
# | 
# | -dump-input=help explains the following input dump.
# | 
# | Input was:
# | <<<<<<
# |             1: Unranked Memref base@ = 0x5adc35727a00 rank = 1 offset = 0 sizes = [2] strides = [1] data =  
# | check:68'0     X~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ error: no match found
# |             2: [42, 42] 
# | check:68'0     ~~~~~~~~~
# | check:68'1     ?         possible intended match
...

@lialan lialan deleted the lialan/fold_memrefs branch July 24, 2025 13:35
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jul 24, 2025

LLVM Buildbot has detected a new failure on builder premerge-monolithic-linux running on premerge-linux-1 while building mlir,utils at step 7 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/153/builds/39210

Here is the relevant piece of the build log for the reference
Step 7 (test-build-unified-tree-check-all) failure: test (failure)
...
PASS: lld :: COFF/duplicate-dwarf.s (98851 of 101866)
PASS: lld :: COFF/duplicate-cv.s (98852 of 101866)
PASS: lld :: COFF/delayimports-error.test (98853 of 101866)
PASS: lld :: COFF/duplicate-absolute.s (98854 of 101866)
PASS: lld :: COFF/baserel.test (98855 of 101866)
PASS: lld :: COFF/defparser.test (98856 of 101866)
PASS: lld :: COFF/arm64x-import.test (98857 of 101866)
PASS: lld :: COFF/def-name.test (98858 of 101866)
PASS: lld :: COFF/duplicate.test (98859 of 101866)
TIMEOUT: MLIR :: Examples/standalone/test.toy (98860 of 101866)
******************** TEST 'MLIR :: Examples/standalone/test.toy' FAILED ********************
Exit Code: 1
Timeout: Reached timeout of 60 seconds

Command Output (stdout):
--
# RUN: at line 1
"/etc/cmake/bin/cmake" "/build/buildbot/premerge-monolithic-linux/llvm-project/mlir/examples/standalone" -G "Ninja"  -DCMAKE_CXX_COMPILER=/usr/bin/clang++ -DCMAKE_C_COMPILER=/usr/bin/clang  -DLLVM_ENABLE_LIBCXX=OFF -DMLIR_DIR=/build/buildbot/premerge-monolithic-linux/build/lib/cmake/mlir  -DLLVM_USE_LINKER=lld  -DPython3_EXECUTABLE="/usr/bin/python3.10"
# executed command: /etc/cmake/bin/cmake /build/buildbot/premerge-monolithic-linux/llvm-project/mlir/examples/standalone -G Ninja -DCMAKE_CXX_COMPILER=/usr/bin/clang++ -DCMAKE_C_COMPILER=/usr/bin/clang -DLLVM_ENABLE_LIBCXX=OFF -DMLIR_DIR=/build/buildbot/premerge-monolithic-linux/build/lib/cmake/mlir -DLLVM_USE_LINKER=lld -DPython3_EXECUTABLE=/usr/bin/python3.10
# .---command stdout------------
# | -- The CXX compiler identification is Clang 16.0.6
# | -- The C compiler identification is Clang 16.0.6
# | -- Detecting CXX compiler ABI info
# | -- Detecting CXX compiler ABI info - done
# | -- Check for working CXX compiler: /usr/bin/clang++ - skipped
# | -- Detecting CXX compile features
# | -- Detecting CXX compile features - done
# | -- Detecting C compiler ABI info
# | -- Detecting C compiler ABI info - done
# | -- Check for working C compiler: /usr/bin/clang - skipped
# | -- Detecting C compile features
# | -- Detecting C compile features - done
# | -- Looking for histedit.h
# | -- Looking for histedit.h - found
# | -- Found LibEdit: /usr/include (found version "2.11") 
# | -- Found ZLIB: /usr/lib/x86_64-linux-gnu/libz.so (found version "1.2.11") 
# | -- Found LibXml2: /usr/lib/x86_64-linux-gnu/libxml2.so (found version "2.9.13") 
# | -- Using MLIRConfig.cmake in: /build/buildbot/premerge-monolithic-linux/build/lib/cmake/mlir
# | -- Using LLVMConfig.cmake in: /build/buildbot/premerge-monolithic-linux/build/lib/cmake/llvm
# | -- Linker detection: unknown
# | -- Performing Test LLVM_LIBSTDCXX_MIN
# | -- Performing Test LLVM_LIBSTDCXX_MIN - Success
# | -- Performing Test LLVM_LIBSTDCXX_SOFT_ERROR
# | -- Performing Test LLVM_LIBSTDCXX_SOFT_ERROR - Success
# | -- Performing Test CXX_SUPPORTS_CUSTOM_LINKER
# | -- Performing Test CXX_SUPPORTS_CUSTOM_LINKER - Success
# | -- Performing Test C_SUPPORTS_FPIC
# | -- Performing Test C_SUPPORTS_FPIC - Success
# | -- Performing Test CXX_SUPPORTS_FPIC

mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
…nto `amdgpu.gather_to_lds`" (llvm#150334)

This is a reapply of patch llvm#149851. The reapply also fixes a CMake/Bazel
build issue, which was the reason of the revert. (Thanks @rupprecht )

Original patch (llvm#149851) message:
-----
This PR adds a new optimization pass to fold
`memref.subview/expand_shape/collapse_shape` ops into consumer
`amdgpu.gather_to_lds` operations.
* Implements a new pass `AmdgpuFoldMemRefOpsPass` with pattern
`FoldMemRefOpsIntoGatherToLDSOp`
* Adds corresponding folding tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend:AMDGPU bazel "Peripheral" support tier build system: utils/bazel mlir:amdgpu mlir:gpu mlir:memref mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants