18
18
namespace mlir ::amdgpu {
19
19
#define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
20
20
#include " mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
21
- } // namespace mlir::amdgpu
22
-
23
- using namespace mlir ;
24
- using namespace mlir ::amdgpu;
25
21
26
- namespace {
27
22
struct AmdgpuFoldSubviewOpsPass
28
23
: public amdgpu::impl::AmdgpuFoldSubviewOpsPassBase<
29
24
AmdgpuFoldSubviewOpsPass> {
@@ -43,32 +38,51 @@ struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
43
38
44
39
Value memrefSource;
45
40
SmallVector<Value> sourceIndices;
46
- llvm::TypeSwitch<Operation *>(op.getSrc ().getDefiningOp ())
47
- .Case <memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
48
- // If the source is a SubViewOp, we can directly rewrite the
49
- // GatherToLDSOp.
50
- mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides (
51
- rewriter, loc, subviewOp.getMixedOffsets (),
52
- subviewOp.getMixedStrides (), subviewOp.getDroppedDims (),
53
- op.getSrcIndices (), sourceIndices);
54
- memrefSource = subviewOp.getSource ();
55
- })
56
- .Case <memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
57
- mlir::memref::resolveSourceIndicesExpandShape (
58
- loc, rewriter, expandShapeOp, op.getSrcIndices (), sourceIndices,
59
- false );
60
- memrefSource = expandShapeOp.getViewSource ();
61
- })
62
- .Case <memref::CollapseShapeOp>(
63
- [&](memref::CollapseShapeOp collapseShapeOp) {
64
- mlir::memref::resolveSourceIndicesCollapseShape (
65
- loc, rewriter, collapseShapeOp, op.getSrcIndices (),
66
- sourceIndices);
67
- memrefSource = collapseShapeOp.getViewSource ();
41
+ auto foldResult =
42
+ llvm::TypeSwitch<Operation *, LogicalResult>(
43
+ op.getSrc ().getDefiningOp ())
44
+ .Case <memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
45
+ // If the source is a SubViewOp, we can directly rewrite the
46
+ // GatherToLDSOp.
47
+ mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides (
48
+ rewriter, loc, subviewOp.getMixedOffsets (),
49
+ subviewOp.getMixedStrides (), subviewOp.getDroppedDims (),
50
+ op.getSrcIndices (), sourceIndices);
51
+ memrefSource = subviewOp.getSource ();
52
+ return success ();
53
+ })
54
+ .Case <memref::ExpandShapeOp>(
55
+ [&](memref::ExpandShapeOp expandShapeOp) {
56
+ if (failed (mlir::memref::resolveSourceIndicesExpandShape (
57
+ loc, rewriter, expandShapeOp, op.getSrcIndices (),
58
+ sourceIndices, false ))) {
59
+ return failure ();
60
+ }
61
+ memrefSource = expandShapeOp.getViewSource ();
62
+ return success ();
63
+ })
64
+ .Case <memref::CollapseShapeOp>(
65
+ [&](memref::CollapseShapeOp collapseShapeOp) {
66
+ if (failed (mlir::memref::resolveSourceIndicesCollapseShape (
67
+ loc, rewriter, collapseShapeOp, op.getSrcIndices (),
68
+ sourceIndices))) {
69
+ return failure ();
70
+ }
71
+ memrefSource = collapseShapeOp.getViewSource ();
72
+ return success ();
73
+ })
74
+ .Default ([&](Operation *op) {
75
+ // If the source is not a SubViewOp, ExpandShapeOp, or
76
+ // CollapseShapeOp, we cannot fold the GatherToLDSOp.
77
+ return rewriter.notifyMatchFailure (
78
+ op,
79
+ " source producer is not one of SubViewOp, ExpandShapeOp, or "
80
+ " CollapseShapeOp" );
68
81
});
69
82
70
- if (!memrefSource)
83
+ if (failed (foldResult)) {
71
84
return failure ();
85
+ }
72
86
73
87
rewriter.replaceOpWithNewOp <GatherToLDSOp>(op, memrefSource, sourceIndices,
74
88
op.getDst (), op.getDstIndices (),
@@ -77,9 +91,9 @@ struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
77
91
return success ();
78
92
}
79
93
};
80
- } // namespace
81
94
82
- void mlir::amdgpu:: populateAmdgpuFoldSubviewOpsPatterns (
83
- RewritePatternSet &patterns, PatternBenefit benefit) {
95
+ void populateAmdgpuFoldSubviewOpsPatterns (RewritePatternSet &patterns,
96
+ PatternBenefit benefit) {
84
97
patterns.add <FoldSubviewIntoGatherToLDSOp>(patterns.getContext (), benefit);
85
98
}
99
+ } // namespace mlir::amdgpu
0 commit comments