Skip to content

Commit cf50f5f

Browse files
committed
Support Expandshape and collapse shape.
1 parent 5d6483d commit cf50f5f

File tree

4 files changed

+137
-107
lines changed

4 files changed

+137
-107
lines changed

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,43 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
116116
/// the source memref (i.e. implements ViewLikeOpInterface).
117117
MemrefValue skipViewLikeOps(MemrefValue source);
118118

119+
/// Given the 'indices' of a load/store operation where the memref is a result
120+
/// of a expand_shape op, returns the indices w.r.t to the source memref of the
121+
/// expand_shape op. For example
122+
///
123+
/// %0 = ... : memref<12x42xf32>
124+
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
125+
/// : memref<12x42xf32> into memref<2x6x42xf32>
126+
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
127+
///
128+
/// could be folded into
129+
///
130+
/// %2 = load %0[6 * i1 + i2, %i3] :
131+
/// memref<12x42xf32>
132+
LogicalResult resolveSourceIndicesExpandShape(
133+
Location loc, PatternRewriter &rewriter,
134+
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
135+
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);
136+
137+
/// Given the 'indices' of a load/store operation where the memref is a result
138+
/// of a collapse_shape op, returns the indices w.r.t to the source memref of
139+
/// the collapse_shape op. For example
140+
///
141+
/// %0 = ... : memref<2x6x42xf32>
142+
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
143+
/// : memref<2x6x42xf32> into memref<12x42xf32>
144+
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
145+
///
146+
/// could be folded into
147+
///
148+
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
149+
/// memref<2x6x42xf32>
150+
LogicalResult
151+
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
152+
memref::CollapseShapeOp collapseShapeOp,
153+
ValueRange indices,
154+
SmallVectorImpl<Value> &sourceIndices);
155+
119156
} // namespace memref
120157
} // namespace mlir
121158

mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
1212
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
1313
#include "mlir/Dialect/MemRef/IR/MemRef.h"
14+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1415
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16+
#include "llvm/ADT/TypeSwitch.h"
1517

1618
namespace mlir::amdgpu {
1719
#define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
@@ -33,28 +35,44 @@ struct AmdgpuFoldSubviewOpsPass
3335
}
3436
};
3537

36-
struct FoldSubviewIntoGatherToLDSOp : public OpRewritePattern<GatherToLDSOp> {
37-
using OpRewritePattern<GatherToLDSOp>::OpRewritePattern;
38+
struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
39+
using OpRewritePattern::OpRewritePattern;
3840
LogicalResult matchAndRewrite(GatherToLDSOp op,
3941
PatternRewriter &rewriter) const override {
4042
Location loc = op.getLoc();
4143

42-
// Check if the source is a subview operation:
43-
auto subviewOp = dyn_cast<memref::SubViewOp>(op.getSrc().getDefiningOp());
44-
if (!subviewOp)
45-
return rewriter.notifyMatchFailure(
46-
loc, "GatherToLDSOp folding is currently supported only when the "
47-
"source is a SubviewOp. This is one specific pattern, and other "
48-
"scenarios may be added in the future.");
49-
44+
Value memrefSource;
5045
SmallVector<Value> sourceIndices;
51-
mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
52-
rewriter, loc, subviewOp.getMixedOffsets(), subviewOp.getMixedStrides(),
53-
subviewOp.getDroppedDims(), op.getSrcIndices(), sourceIndices);
46+
llvm::TypeSwitch<Operation *>(op.getSrc().getDefiningOp())
47+
.Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
48+
// If the source is a SubViewOp, we can directly rewrite the
49+
// GatherToLDSOp.
50+
mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
51+
rewriter, loc, subviewOp.getMixedOffsets(),
52+
subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
53+
op.getSrcIndices(), sourceIndices);
54+
memrefSource = subviewOp.getSource();
55+
})
56+
.Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
57+
mlir::memref::resolveSourceIndicesExpandShape(
58+
loc, rewriter, expandShapeOp, op.getSrcIndices(), sourceIndices,
59+
false);
60+
memrefSource = expandShapeOp.getViewSource();
61+
})
62+
.Case<memref::CollapseShapeOp>(
63+
[&](memref::CollapseShapeOp collapseShapeOp) {
64+
mlir::memref::resolveSourceIndicesCollapseShape(
65+
loc, rewriter, collapseShapeOp, op.getSrcIndices(),
66+
sourceIndices);
67+
memrefSource = collapseShapeOp.getViewSource();
68+
});
69+
70+
if (!memrefSource)
71+
return failure();
5472

55-
rewriter.replaceOpWithNewOp<GatherToLDSOp>(
56-
op, subviewOp.getSource(), sourceIndices, op.getDst(),
57-
op.getDstIndices(), op.getTransferType());
73+
rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
74+
op.getDst(), op.getDstIndices(),
75+
op.getTransferType());
5876

5977
return success();
6078
}

mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp

Lines changed: 0 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -44,97 +44,6 @@ using namespace mlir;
4444
// Utility functions
4545
//===----------------------------------------------------------------------===//
4646

47-
/// Given the 'indices' of a load/store operation where the memref is a result
48-
/// of a expand_shape op, returns the indices w.r.t to the source memref of the
49-
/// expand_shape op. For example
50-
///
51-
/// %0 = ... : memref<12x42xf32>
52-
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
53-
/// : memref<12x42xf32> into memref<2x6x42xf32>
54-
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
55-
///
56-
/// could be folded into
57-
///
58-
/// %2 = load %0[6 * i1 + i2, %i3] :
59-
/// memref<12x42xf32>
60-
static LogicalResult resolveSourceIndicesExpandShape(
61-
Location loc, PatternRewriter &rewriter,
62-
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
63-
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
64-
SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
65-
66-
// Traverse all reassociation groups to determine the appropriate indices
67-
// corresponding to each one of them post op folding.
68-
for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
69-
assert(!group.empty() && "association indices groups cannot be empty");
70-
int64_t groupSize = group.size();
71-
if (groupSize == 1) {
72-
sourceIndices.push_back(indices[group[0]]);
73-
continue;
74-
}
75-
SmallVector<OpFoldResult> groupBasis =
76-
llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
77-
SmallVector<Value> groupIndices =
78-
llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
79-
Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
80-
loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
81-
sourceIndices.push_back(collapsedIndex);
82-
}
83-
return success();
84-
}
85-
86-
/// Given the 'indices' of a load/store operation where the memref is a result
87-
/// of a collapse_shape op, returns the indices w.r.t to the source memref of
88-
/// the collapse_shape op. For example
89-
///
90-
/// %0 = ... : memref<2x6x42xf32>
91-
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
92-
/// : memref<2x6x42xf32> into memref<12x42xf32>
93-
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
94-
///
95-
/// could be folded into
96-
///
97-
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
98-
/// memref<2x6x42xf32>
99-
static LogicalResult
100-
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
101-
memref::CollapseShapeOp collapseShapeOp,
102-
ValueRange indices,
103-
SmallVectorImpl<Value> &sourceIndices) {
104-
// Note: collapse_shape requires a strided memref, we can do this.
105-
auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
106-
loc, collapseShapeOp.getSrc());
107-
SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
108-
for (auto [index, group] :
109-
llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
110-
assert(!group.empty() && "association indices groups cannot be empty");
111-
int64_t groupSize = group.size();
112-
113-
if (groupSize == 1) {
114-
sourceIndices.push_back(index);
115-
continue;
116-
}
117-
118-
SmallVector<OpFoldResult> basis =
119-
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
120-
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
121-
loc, index, basis, /*hasOuterBound=*/true);
122-
llvm::append_range(sourceIndices, delinearize.getResults());
123-
}
124-
if (collapseShapeOp.getReassociationIndices().empty()) {
125-
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
126-
int64_t srcRank =
127-
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
128-
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
129-
rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
130-
for (int64_t i = 0; i < srcRank; i++) {
131-
sourceIndices.push_back(
132-
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
133-
}
134-
}
135-
return success();
136-
}
137-
13847
/// Helpers to access the memref operand for each op.
13948
template <typename LoadOrStoreOpTy>
14049
static Value getMemRefOperand(LoadOrStoreOpTy op) {

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1414
#include "mlir/Dialect/Affine/IR/AffineOps.h"
15+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1516
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1617
#include "mlir/Interfaces/ViewLikeInterface.h"
1718
#include "llvm/ADT/STLExtras.h"
@@ -217,5 +218,70 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
217218
return source;
218219
}
219220

221+
LogicalResult resolveSourceIndicesExpandShape(
222+
Location loc, PatternRewriter &rewriter,
223+
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
224+
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
225+
SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
226+
227+
// Traverse all reassociation groups to determine the appropriate indices
228+
// corresponding to each one of them post op folding.
229+
for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
230+
assert(!group.empty() && "association indices groups cannot be empty");
231+
int64_t groupSize = group.size();
232+
if (groupSize == 1) {
233+
sourceIndices.push_back(indices[group[0]]);
234+
continue;
235+
}
236+
SmallVector<OpFoldResult> groupBasis =
237+
llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
238+
SmallVector<Value> groupIndices =
239+
llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
240+
Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
241+
loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
242+
sourceIndices.push_back(collapsedIndex);
243+
}
244+
return success();
245+
}
246+
247+
LogicalResult
248+
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
249+
memref::CollapseShapeOp collapseShapeOp,
250+
ValueRange indices,
251+
SmallVectorImpl<Value> &sourceIndices) {
252+
// Note: collapse_shape requires a strided memref, we can do this.
253+
auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
254+
loc, collapseShapeOp.getSrc());
255+
SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
256+
for (auto [index, group] :
257+
llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
258+
assert(!group.empty() && "association indices groups cannot be empty");
259+
int64_t groupSize = group.size();
260+
261+
if (groupSize == 1) {
262+
sourceIndices.push_back(index);
263+
continue;
264+
}
265+
266+
SmallVector<OpFoldResult> basis =
267+
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
268+
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
269+
loc, index, basis, /*hasOuterBound=*/true);
270+
llvm::append_range(sourceIndices, delinearize.getResults());
271+
}
272+
if (collapseShapeOp.getReassociationIndices().empty()) {
273+
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
274+
int64_t srcRank =
275+
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
276+
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
277+
rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
278+
for (int64_t i = 0; i < srcRank; i++) {
279+
sourceIndices.push_back(
280+
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
281+
}
282+
}
283+
return success();
284+
}
285+
220286
} // namespace memref
221287
} // namespace mlir

0 commit comments

Comments
 (0)