[AMDGPU] fold memref.subview into amdgpu.gather_to_lds #149851

Open · wants to merge 7 commits into main
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -22,8 +22,9 @@ class ConversionTarget;
namespace amdgpu {

#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
#define GEN_PASS_DECL_AMDGPUFOLDSUBVIEWOPSPASS
#define GEN_PASS_DECL_AMDGPUMASKEDLOADTOLOADPASS
#define GEN_PASS_DECL_AMDGPURESOLVESTRIDEDMETADATAPASS
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"

@@ -38,6 +39,9 @@ void populateAmdgpuResolveStridedMetadataPatterns(RewritePatternSet &patterns,
void populateAmdgpuMaskedloadToLoadPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit = 1);

} // namespace amdgpu
} // namespace mlir

12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -70,4 +70,16 @@ def AmdgpuMaskedloadToLoadPass : Pass<"amdgpu-maskedload-to-load"> {
"memref::MemRefDialect"
];
}

def AmdgpuFoldSubviewOpsPass : Pass<"amdgpu-fold-subview-ops"> {
let summary = "Fold memref view-like ops into their consumer AMDGPU ops";
let description = [{
This pass identifies `memref.subview`, `memref.expand_shape`, and
`memref.collapse_shape` sources of `GatherToLDSOp` and folds them into the
op by rewriting its source and indices, eliminating the intermediate view
and potentially improving performance.
}];
let dependentDialects = [
"memref::MemRefDialect"
];
}
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
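
For context, a minimal sketch of the subview fold this pass performs. The shapes, the `vector<8xf16>` transfer type, and the `#gpu.address_space<workgroup>` LDS buffer are illustrative assumptions, not taken from this patch:

```mlir
// Before: the gather reads through a subview of %mem.
%view = memref.subview %mem[8, 16] [32, 64] [1, 1]
  : memref<64x128xf16> to memref<32x64xf16, strided<[128, 1], offset: 1040>>
amdgpu.gather_to_lds %view[%i, %j], %lds[%c0, %c0]
  : vector<8xf16>, memref<32x64xf16, strided<[128, 1], offset: 1040>>,
    memref<64x64xf16, #gpu.address_space<workgroup>>

// After: the source indices are rebased onto %mem; the subview goes dead.
%i0 = affine.apply affine_map<(d0) -> (d0 + 8)>(%i)
%j0 = affine.apply affine_map<(d0) -> (d0 + 16)>(%j)
amdgpu.gather_to_lds %mem[%i0, %j0], %lds[%c0, %c0]
  : vector<8xf16>, memref<64x128xf16>,
    memref<64x64xf16, #gpu.address_space<workgroup>>
```

The rewrite only touches the consumer op; the now-unused `memref.subview` is left for dead-code elimination to clean up.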
37 changes: 37 additions & 0 deletions mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -116,6 +116,43 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
/// the source memref (i.e. implements ViewLikeOpInterface).
MemrefValue skipViewLikeOps(MemrefValue source);

/// Given the 'indices' of a load/store operation where the memref is the
/// result of an expand_shape op, returns the indices w.r.t. the source
/// memref of the expand_shape op. For example,
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
///     : memref<12x42xf32> into memref<2x6x42xf32>
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32>
///
/// could be folded into
///
/// %2 = load %0[6 * %i1 + %i2, %i3] : memref<12x42xf32>
LogicalResult resolveSourceIndicesExpandShape(
Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds);

/// Given the 'indices' of a load/store operation where the memref is the
/// result of a collapse_shape op, returns the indices w.r.t. the source
/// memref of the collapse_shape op. For example,
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
///     : memref<2x6x42xf32> into memref<12x42xf32>
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] : memref<2x6x42xf32>
LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices);

} // namespace memref
} // namespace mlir

3 changes: 2 additions & 1 deletion mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,8 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
ResolveStridedMetadata.cpp
FoldSubviewOps.cpp
MaskedloadToLoad.cpp
ResolveStridedMetadata.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
99 changes: 99 additions & 0 deletions mlir/lib/Dialect/AMDGPU/Transforms/FoldSubviewOps.cpp
@@ -0,0 +1,99 @@
//===- FoldSubviewOps.cpp - AMDGPU fold subview ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUFOLDSUBVIEWOPSPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"

struct AmdgpuFoldSubviewOpsPass
: public amdgpu::impl::AmdgpuFoldSubviewOpsPassBase<
AmdgpuFoldSubviewOpsPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateAmdgpuFoldSubviewOpsPatterns(patterns);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};

struct FoldSubviewIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherToLDSOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

Value memrefSource;
SmallVector<Value> sourceIndices;
// Bail out when the source is a block argument: the TypeSwitch below
// requires a non-null defining op.
Operation *srcDef = op.getSrc().getDefiningOp();
if (!srcDef)
return rewriter.notifyMatchFailure(op, "source has no defining op");
auto foldResult =
llvm::TypeSwitch<Operation *, LogicalResult>(srcDef)
.Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
// If the source is a SubViewOp, we can directly rewrite the
// GatherToLDSOp.
mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, loc, subviewOp.getMixedOffsets(),
subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
op.getSrcIndices(), sourceIndices);
memrefSource = subviewOp.getSource();
return success();
})
.Case<memref::ExpandShapeOp>(
[&](memref::ExpandShapeOp expandShapeOp) {
if (failed(mlir::memref::resolveSourceIndicesExpandShape(
loc, rewriter, expandShapeOp, op.getSrcIndices(),
sourceIndices, /*startsInbounds=*/false))) {
return failure();
}
memrefSource = expandShapeOp.getViewSource();
return success();
})
.Case<memref::CollapseShapeOp>(
[&](memref::CollapseShapeOp collapseShapeOp) {
if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
loc, rewriter, collapseShapeOp, op.getSrcIndices(),
sourceIndices))) {
return failure();
}
memrefSource = collapseShapeOp.getViewSource();
return success();
})
.Default([&](Operation *op) {
// If the source is not a SubViewOp, ExpandShapeOp, or
// CollapseShapeOp, we cannot fold the GatherToLDSOp.
return rewriter.notifyMatchFailure(
op,
"source producer is not one of SubViewOp, ExpandShapeOp, or "
"CollapseShapeOp");
});

if (failed(foldResult)) {
return failure();
}

rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
op.getDst(), op.getDstIndices(),
op.getTransferType());

return success();
}
};

void populateAmdgpuFoldSubviewOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<FoldSubviewIntoGatherToLDSOp>(patterns.getContext(), benefit);
}
} // namespace mlir::amdgpu
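
A sketch of the expand_shape case handled above, under the same illustrative assumptions about shapes and transfer type: the indices of each multi-index reassociation group are linearized with `affine.linearize_index` (without `disjoint`, since `startsInbounds` is passed as `false`), and the op is rewritten against the original source:

```mlir
// Before: the gather reads through an expand_shape.
%e = memref.expand_shape %mem [[0, 1], [2]]
  : memref<12x42xf16> into memref<2x6x42xf16>
amdgpu.gather_to_lds %e[%i, %j, %k], %lds[%c0]
  : vector<4xf16>, memref<2x6x42xf16>,
    memref<64xf16, #gpu.address_space<workgroup>>

// After: %i and %j are linearized against the group basis (2, 6).
%lin = affine.linearize_index [%i, %j] by (2, 6) : index
amdgpu.gather_to_lds %mem[%lin, %k], %lds[%c0]
  : vector<4xf16>, memref<12x42xf16>,
    memref<64xf16, #gpu.address_space<workgroup>>
```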
91 changes: 0 additions & 91 deletions mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -44,97 +44,6 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a expand_shape op, returns the indices w.r.t to the source memref of the
/// expand_shape op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.expand_shape %0 [[0, 1], [2]]
/// : memref<12x42xf32> into memref<2x6x42xf32>
/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32
///
/// could be folded into
///
/// %2 = load %0[6 * i1 + i2, %i3] :
/// memref<12x42xf32>
static LogicalResult resolveSourceIndicesExpandShape(
Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();

// Traverse all reassociation groups to determine the appropriate indices
// corresponding to each one of them post op folding.
for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
assert(!group.empty() && "association indices groups cannot be empty");
int64_t groupSize = group.size();
if (groupSize == 1) {
sourceIndices.push_back(indices[group[0]]);
continue;
}
SmallVector<OpFoldResult> groupBasis =
llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
SmallVector<Value> groupIndices =
llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
sourceIndices.push_back(collapsedIndex);
}
return success();
}

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a collapse_shape op, returns the indices w.r.t to the source memref of
/// the collapse_shape op. For example
///
/// %0 = ... : memref<2x6x42xf32>
/// %1 = memref.collapse_shape %0 [[0, 1], [2]]
/// : memref<2x6x42xf32> into memref<12x42xf32>
/// %2 = load %1[%i1, %i2] : memref<12x42xf32>
///
/// could be folded into
///
/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] :
/// memref<2x6x42xf32>
static LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
// Note: collapse_shape requires a strided memref, we can do this.
auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
loc, collapseShapeOp.getSrc());
SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
for (auto [index, group] :
llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
assert(!group.empty() && "association indices groups cannot be empty");
int64_t groupSize = group.size();

if (groupSize == 1) {
sourceIndices.push_back(index);
continue;
}

SmallVector<OpFoldResult> basis =
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, index, basis, /*hasOuterBound=*/true);
llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
}
return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
66 changes: 66 additions & 0 deletions mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -12,6 +12,7 @@

#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
@@ -217,5 +218,70 @@ MemrefValue skipViewLikeOps(MemrefValue source) {
return source;
}

LogicalResult resolveSourceIndicesExpandShape(
Location loc, PatternRewriter &rewriter,
memref::ExpandShapeOp expandShapeOp, ValueRange indices,
SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();

// Traverse all reassociation groups to determine the appropriate indices
// corresponding to each one of them post op folding.
for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
assert(!group.empty() && "association indices groups cannot be empty");
int64_t groupSize = group.size();
if (groupSize == 1) {
sourceIndices.push_back(indices[group[0]]);
continue;
}
SmallVector<OpFoldResult> groupBasis =
llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
SmallVector<Value> groupIndices =
llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
sourceIndices.push_back(collapsedIndex);
}
return success();
}

LogicalResult
resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
memref::CollapseShapeOp collapseShapeOp,
ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
// Note: collapse_shape requires a strided memref, so extracting its
// strided metadata here is always valid.
auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
loc, collapseShapeOp.getSrc());
SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
for (auto [index, group] :
llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
assert(!group.empty() && "association indices groups cannot be empty");
int64_t groupSize = group.size();

if (groupSize == 1) {
sourceIndices.push_back(index);
continue;
}

SmallVector<OpFoldResult> basis =
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, index, basis, /*hasOuterBound=*/true);
llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseShapeOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
}
}
return success();
}

} // namespace memref
} // namespace mlir
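
As a sketch of what these two utilities emit (operand names assumed for illustration): `resolveSourceIndicesExpandShape` produces one `affine.linearize_index` per multi-index reassociation group, while `resolveSourceIndicesCollapseShape` produces one `affine.delinearize_index` whose basis comes from the source's constified strided-metadata sizes:

```mlir
// expand_shape group [0, 1] with output shape (2, 6): combine two indices.
%lin = affine.linearize_index [%i1, %i2] by (2, 6) : index

// collapse_shape group [0, 1] with source sizes (2, 6): split one index.
%parts:2 = affine.delinearize_index %i1 into (2, 6) : index, index
```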