Skip to content

Commit 6e12ca4

Browse files
committed
[AMDGPU] fold memref.subview into amdgpu.gather_to_lds
1 parent 6932080 commit 6e12ca4

File tree

5 files changed

+105
-2
lines changed

5 files changed

+105
-2
lines changed

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ class ConversionTarget;
2222
namespace amdgpu {
2323

2424
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
25-
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
25+
#define GEN_PASS_DECL_AMDGPUFOLDSUBVIEWOPSPASS
2626
#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
27+
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
2728
#define GEN_PASS_REGISTRATION
2829
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
2930

@@ -38,6 +39,9 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
3839
void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
3940
PatternBenefit benefit = 1);
4041

42+
void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
43+
PatternBenefit benefit = 1);
44+
4145
} // namespace amdgpu
4246
} // namespace mlir
4347

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,16 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
7070
"memref::MemRefDialect"
7171
];
7272
}
73+
74+
def AmdgpuFoldSubviewOpsPass : Pass<"amdgpu-fold-subview-ops"> {
  let summary = "Fold subview operations into their parent operations";
  let description = [{
    This pass identifies `memref.subview` sources of `GatherToLDSOp` and
    attempts to fold them into the op, potentially simplifying the overall
    operation and improving performance.
  }];
  let dependentDialects = [
    "memref::MemRefDialect"
  ];
}
7385
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_

mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
add_mlir_dialect_library(MLIRAMDGPUTransforms
22
EmulateAtomics.cpp
3-
ResolveStridedMetadata.cpp
3+
FoldSubviewOps.cpp
44
MaskedloadToLoad.cpp
5+
ResolveStridedMetadata.cpp
56

67
ADDITIONAL_HEADER_DIRS
78
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===- FoldSubviewOps.cpp - AMDGPU fold subview ops ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
10+
11+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
12+
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
13+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
14+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15+
16+
namespace mlir::amdgpu {
17+
#define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
18+
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
19+
} // namespace mlir::amdgpu
20+
21+
using namespace mlir;
22+
using namespace mlir::amdgpu;
23+
24+
namespace {
25+
struct AmdgpuFoldSubviewOpsPass
26+
: public amdgpu::impl::AmdgpuFoldSubviewOpsPassBase<
27+
AmdgpuFoldSubviewOpsPass> {
28+
void runOnOperation() override {
29+
RewritePatternSet patterns(&getContext());
30+
populateAmdgpuFoldSubviewOpsPatterns(patterns);
31+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
32+
signalPassFailure();
33+
}
34+
};
35+
36+
/// Folds a `memref.subview` that feeds the source operand of an
/// `amdgpu.gather_to_lds` into the op itself, by rebasing the gather's source
/// indices onto the subview's underlying memref.
struct FoldSubviewIntoGatherToLDSOp : public OpRewritePattern<GatherToLDSOp> {
  using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherToLDSOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Check if the source is a subview operation. Use the null-safe
    // getDefiningOp<OpTy>() form: the source may be a block argument, in
    // which case getDefiningOp() returns null and a plain
    // dyn_cast(op.getSrc().getDefiningOp()) would dereference a null pointer.
    auto subviewOp = op.getSrc().getDefiningOp<memref::SubViewOp>();
    if (!subviewOp)
      return rewriter.notifyMatchFailure(
          loc, "GatherToLDSOp can only be folded if the source is a SubviewOp");

    // Rebase the gather's source indices through the subview's offsets and
    // strides, accounting for dimensions the subview rank-reduced away.
    SmallVector<Value> sourceIndices;
    mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, loc, subviewOp.getMixedOffsets(), subviewOp.getMixedStrides(),
        subviewOp.getDroppedDims(), op.getSrcIndices(), sourceIndices);

    // Re-create the gather directly on the subview's source memref.
    rewriter.replaceOpWithNewOp<GatherToLDSOp>(
        op, subviewOp.getSource(), sourceIndices, op.getDst(),
        op.getDstIndices(), op.getTransferType());

    return success();
  }
};
60+
} // namespace
61+
62+
/// Registers the subview-into-gather_to_lds folding pattern with the given
/// benefit.
void mlir::amdgpu::populateAmdgpuFoldSubviewOpsPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldSubviewIntoGatherToLDSOp>(context, benefit);
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-opt -amdgpu-fold-subview-ops -split-input-file %s | FileCheck %s
2+
3+
#gpu_lds_addrspace = 3
4+
5+
// CHECK: func @test_memref
// CHECK-SAME: %arg0: index, %arg1: index
func.func @test_memref(%offset_i: index, %offset_j: index) {
  // CHECK: %[[C0:.*]] = arith.constant 0 : index
  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
  // Pin the consuming op's mnemonic: without `amdgpu.gather_to_lds` the CHECK
  // only matched operand substrings and would not verify which op survived
  // the fold.
  // CHECK: amdgpu.gather_to_lds %[[MEM]][%arg0, %arg1], %[[LOCAL]][%[[C0]], %[[C0]]]
  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>

  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
  %mem = memref.alloc() : memref<64x128xf16>
  %subview = memref.subview %mem[0, 0][32, 64][1, 1] : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1]>>
  %c0 = arith.constant 0 : index
  amdgpu.gather_to_lds %subview[%offset_i, %offset_j], %alloc[%c0, %c0]
    : vector<8xf16>, memref<32x64xf16, strided<[128, 1]>>, memref<64x64xf16, #gpu_lds_addrspace>
  func.return
}

0 commit comments

Comments
 (0)