Skip to content

Commit 06e2831

Browse files
committed
Move to FoldMemRefAliasOps
1 parent 9552f4e commit 06e2831

File tree

4 files changed

+31
-76
lines changed

4 files changed

+31
-76
lines changed

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ class ConversionTarget;
2222
namespace amdgpu {
2323

2424
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
25-
#define GEN_PASS_DECL_AMDGPUFOLDSUBVIEWOPSPASS
26-
#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
2725
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
26+
#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
2827
#define GEN_PASS_REGISTRATION
2928
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
3029

@@ -39,9 +38,6 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
3938
void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
4039
PatternBenefit benefit = 1);
4140

42-
void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
43-
PatternBenefit benefit = 1);
44-
4541
} // namespace amdgpu
4642
} // namespace mlir
4743

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
add_mlir_dialect_library(MLIRAMDGPUTransforms
22
EmulateAtomics.cpp
3-
FoldSubviewOps.cpp
4-
MaskedloadToLoad.cpp
53
ResolveStridedMetadata.cpp
4+
MaskedloadToLoad.cpp
65

76
ADDITIONAL_HEADER_DIRS
87
{$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms

mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp

Lines changed: 0 additions & 67 deletions
This file was deleted.

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
2121
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
2222
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
2324
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2425
#include "mlir/IR/AffineMap.h"
2526
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -732,6 +733,32 @@ LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
732733
return success();
733734
}
734735

736+
struct FoldSubviewIntoGatherToLDSOp
737+
: public OpRewritePattern<amdgpu::GatherToLDSOp> {
738+
using OpRewritePattern<amdgpu::GatherToLDSOp>::OpRewritePattern;
739+
LogicalResult
740+
matchAndRewrite(amdgpu::GatherToLDSOp op, PatternRewriter &rewriter) const override {
741+
Location loc = op.getLoc();
742+
743+
// Check if the source is a subview operation:
744+
auto subviewOp = dyn_cast<memref::SubViewOp>(op.getSrc().getDefiningOp());
745+
if (!subviewOp)
746+
return rewriter.notifyMatchFailure(
747+
loc, "GatherToLDSOp can only be folded if the source is a SubviewOp");
748+
749+
SmallVector<Value> sourceIndices;
750+
mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
751+
rewriter, loc, subviewOp.getMixedOffsets(), subviewOp.getMixedStrides(),
752+
subviewOp.getDroppedDims(), op.getSrcIndices(), sourceIndices);
753+
754+
rewriter.replaceOpWithNewOp<admgpu::GatherToLDSOp>(
755+
op, subviewOp.getSource(), sourceIndices, op.getDst(), op.getDstIndices(),
756+
op.getTransferType());
757+
758+
return success();
759+
}
760+
};
761+
735762
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
736763
patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
737764
LoadOpOfSubViewOpFolder<memref::LoadOp>,
@@ -762,8 +789,8 @@ void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
762789
StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
763790
StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
764791
StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>,
765-
SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
766-
patterns.getContext());
792+
SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder,
793+
FoldSubviewIntoGatherToLDSOp>(patterns.getContext());
767794
}
768795

769796
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)