Commit 3db555d

update tests.
1 parent cf50f5f commit 3db555d

2 files changed: +94 −34 lines changed

mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp

Lines changed: 45 additions & 31 deletions
@@ -18,12 +18,7 @@
 namespace mlir::amdgpu {
 #define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
-} // namespace mlir::amdgpu
-
-using namespace mlir;
-using namespace mlir::amdgpu;
 
-namespace {
 struct AmdgpuFoldSubviewOpsPass
     : public amdgpu::impl::AmdgpuFoldSubviewOpsPassBase<
           AmdgpuFoldSubviewOpsPass> {
@@ -43,32 +38,51 @@ struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
 
     Value memrefSource;
     SmallVector<Value> sourceIndices;
-    llvm::TypeSwitch<Operation *>(op.getSrc().getDefiningOp())
-        .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
-          // If the source is a SubViewOp, we can directly rewrite the
-          // GatherToLDSOp.
-          mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-              rewriter, loc, subviewOp.getMixedOffsets(),
-              subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-              op.getSrcIndices(), sourceIndices);
-          memrefSource = subviewOp.getSource();
-        })
-        .Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
-          mlir::memref::resolveSourceIndicesExpandShape(
-              loc, rewriter, expandShapeOp, op.getSrcIndices(), sourceIndices,
-              false);
-          memrefSource = expandShapeOp.getViewSource();
-        })
-        .Case<memref::CollapseShapeOp>(
-            [&](memref::CollapseShapeOp collapseShapeOp) {
-              mlir::memref::resolveSourceIndicesCollapseShape(
-                  loc, rewriter, collapseShapeOp, op.getSrcIndices(),
-                  sourceIndices);
-              memrefSource = collapseShapeOp.getViewSource();
+    auto foldResult =
+        llvm::TypeSwitch<Operation *, LogicalResult>(
+            op.getSrc().getDefiningOp())
+            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+              // If the source is a SubViewOp, we can directly rewrite the
+              // GatherToLDSOp.
+              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                  rewriter, loc, subviewOp.getMixedOffsets(),
+                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                  op.getSrcIndices(), sourceIndices);
+              memrefSource = subviewOp.getSource();
+              return success();
+            })
+            .Case<memref::ExpandShapeOp>(
+                [&](memref::ExpandShapeOp expandShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
+                          sourceIndices, false))) {
+                    return failure();
+                  }
+                  memrefSource = expandShapeOp.getViewSource();
+                  return success();
+                })
+            .Case<memref::CollapseShapeOp>(
+                [&](memref::CollapseShapeOp collapseShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
+                          sourceIndices))) {
+                    return failure();
+                  }
+                  memrefSource = collapseShapeOp.getViewSource();
+                  return success();
+                })
+            .Default([&](Operation *op) {
+              // If the source is not a SubViewOp, ExpandShapeOp, or
+              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
+              return rewriter.notifyMatchFailure(
+                  op,
+                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
+                  "CollapseShapeOp");
             });
 
-    if (!memrefSource)
+    if (failed(foldResult)) {
       return failure();
+    }
 
     rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
                                                op.getDst(), op.getDstIndices(),
@@ -77,9 +91,9 @@ struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
     return success();
   }
 };
-} // namespace
 
-void mlir::amdgpu::populateAmdgpuFoldSubviewOpsPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
+void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit) {
   patterns.add<FoldSubviewIntoGatherToLDSOp>(patterns.getContext(), benefit);
 }
+} // namespace mlir::amdgpu
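
Note that populateAmdgpuFoldSubviewOpsPatterns is now defined inside namespace mlir::amdgpu instead of via a qualified out-of-namespace definition. A minimal sketch of how a downstream pass might drive these patterns, assuming the declaring header path (guessed here as mlir/Dialect/AMDGPU/Transforms/Transforms.h), the greedy driver applyPatternsGreedily, and the helper name foldAmdgpuViewOps, none of which are part of this commit:

// Sketch only: header path, driver API, and helper name are assumptions.
#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h" // declares populateAmdgpuFoldSubviewOpsPatterns (path assumed)
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult foldAmdgpuViewOps(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Collect the GatherToLDSOp view-folding pattern defined in this file.
  mlir::amdgpu::populateAmdgpuFoldSubviewOpsPatterns(patterns, /*benefit=*/1);
  // Apply the patterns greedily; success indicates the rewrite converged.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}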

mlir/test/Dialect/AMDGPU/amdgpu-fold-subviews.mlir

Lines changed: 49 additions & 3 deletions
@@ -1,10 +1,10 @@
-// RUN: mlir-opt -amdgpu-fold-subview-ops -split-input-file %s | FileCheck %s
+// RUN: mlir-opt --amdgpu-fold-subview-ops --split-input-file %s | FileCheck %s
 
 #gpu_lds_addrspace = 3
 
-// CHECK: func @test_memref
+// CHECK: func @test_subview_folding
 // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
-func.func @test_memref(%offset_i: index, %offset_j: index) {
+func.func @test_subview_folding(%offset_i: index, %offset_j: index) {
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
   // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
@@ -48,3 +48,49 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
     : vector<8xf16>, memref<32x64xf16, strided<[128, 1], offset: 4160>>, memref<64x64xf16, #gpu_lds_addrspace>
   func.return
 }
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_expand_shape
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
+  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<8192xf16>
+  %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
+
+// -----
+
+#gpu_lds_addrspace = 3
+
+// CHECK: func @test_collapse_shape
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
+  // CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
+  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
+  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
+  // CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>
+
+  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
+  %mem = memref.alloc() : memref<64x128xf16>
+  %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
+    : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
+  func.return
+}
