Skip to content

Commit be1d8cc

Browse files
committed
Add prefetch functionality for gfx1250
1 parent 2f9b83b commit be1d8cc

14 files changed

+445
-15
lines changed

mlir/include/mlir/Dialect/Rock/IR/AmdArchDb.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ struct AmdArchInfo {
5252
};
5353

5454
AmdArchInfo lookupArchInfo(StringRef arch);
55+
bool isGlobalPrefetchSupported(StringRef arch);
5556
} // namespace rock
5657
} // namespace mlir
5758

mlir/include/mlir/Dialect/Rock/IR/RockOps.td

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,26 @@ def Rock_InBoundsStoreOp
10171017
let hasVerifier = 1;
10181018
}
10191019

1020+
// global_prefetch
1021+
def Rock_GlobalPrefetchOp
1022+
: Rock_Op<"global_prefetch">,
1023+
Arguments<(
1024+
ins Arg<MemRefOf<SupportedMemoryElems>, "source memory">:$source,
1025+
Variadic<Index>:$sourceCoord)> {
1026+
let summary = "Prefetch from global memory";
1027+
let description = [{
1028+
`global_prefetch` prefetches an item from the provided memref
1029+
(applying no coordinate transformations), starting at `sourceCoord`.
1030+
1031+
Note that it is acceptable for `sourceCoord` to be out of bounds because we use Speculative Prefetch.
1032+
}];
1033+
let assemblyFormat = [{
1034+
$source `[` $sourceCoord `]` attr-dict
1035+
`:` type($source)
1036+
}];
1037+
let hasVerifier = 1;
1038+
}
1039+
10201040
// global_load
10211041
def Rock_GlobalLoadOp
10221042
: Rock_Op<
@@ -1188,7 +1208,7 @@ def Rock_ThreadwiseReadIntoOp
11881208
L is the length of `%dest`, V is the maximum vectorization computed
11891209
for MAPS.
11901210

1191-
The input to extraViews ; (the transforms on %source) must have the form
1211+
The input to extraViews: (the transforms on %source) must have the form
11921212
(extraIdx0, ... , extraIdxN, iteration_number)
11931213

11941214
Primarily, extraIndices would be used to pass in tid and bid. This would need
@@ -1235,6 +1255,59 @@ def Rock_ThreadwiseReadIntoOp
12351255
let hasVerifier = 1;
12361256
}
12371257

1258+
// threadwise_prefetch
1259+
def Rock_ThreadwisePrefetchOp
1260+
: Rock_Op<"threadwise_prefetch", [DeclareOpInterfaceMethods<
1261+
RockAcceptingViewOpInterface>]>,
1262+
Arguments<(ins Arg<MemRefOf<SupportedMemoryElems>, "source view">:$source,
1263+
TransformMapArrayAttr:$extraViews, Variadic<Index>:$extraIndices,
1264+
UnitAttr:$forceUnroll, UnitAttr:$useIndexDiffs)> {
1265+
let summary = "Prefetch values from transformed source";
1266+
1267+
let description = [{
1268+
A high-level representation of a global prefetch loop that
1269+
accounts for coordinate transformations.
1270+
1271+
If `%source = rock.transform #transform_mapN %buffer`
1272+
(with the one transformation representing an entire sequence),
1273+
the operation
1274+
1275+
```mlir
1276+
rock.threadwise_prefetch [#transform_mapM](%source)
1277+
```
1278+
1279+
will lower to
1280+
```mlir
1281+
%bid = rock.workgroup_id
1282+
%tid = rock.workitem_id
1283+
rock.transforming_for
1284+
(%args, ...) = MAPS(%bid, %tid, %c0)
1285+
(%_, %_, %i) = [](%c0, %c0, %c0)
1286+
bounds = [1, 1, L], strides = [1, 1, 1] {
1287+
rock.global_prefetch %buffer[%args]
1288+
}
1289+
```
1290+
1291+
where MAPS is `[#transform_mapM, #transform_mapN]`,
1292+
L is the length of the upper view last dimension (number of elements).
1293+
1294+
The input to extraViews: (the transforms on %source) must have the form
1295+
(extraIdx0, ... , extraIdxN, iteration_number)
1296+
1297+
Primarily, extraIndices would be used to pass in tid and bid. This would need
1298+
to have a matching view in [extraViews]source. The extraIndices could be used
1299+
to integrate loop induction variables that are defined outside of the op.
1300+
If extraIndices are used, the [extraViews]source must have the form
1301+
(extraIdx0, ... , extraIdxN, iteration_number).
1302+
}];
1303+
1304+
let assemblyFormat = [{
1305+
attr-dict $extraViews `(` $source `)` (`[` $extraIndices^ `]`)?
1306+
`:` type($source)
1307+
}];
1308+
let hasVerifier = 1;
1309+
}
1310+
12381311
// threadwise_write_all
12391312
def Rock_ThreadwiseWriteAllOp
12401313
: Rock_Op<"threadwise_write_all",

mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,18 @@ static constexpr AmdArchInfo
112112
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/12,
113113
/*hasFp8ConversionInstrs=*/false,
114114
/*hasOcpFp8ConversionInstrs=*/true, /*hasScaledGemm=*/false,
115-
/*maxNumXCC=*/1);
115+
/*maxNumXCC=*/1),
116+
// TODO: update with right information
117+
gfx1250Info(GemmFeatures::dot | GemmFeatures::atomic_add |
118+
GemmFeatures::atomic_fmax_f32 | GemmFeatures::wmma |
119+
GemmFeatures::atomic_add_f16 |
120+
GemmFeatures::atomic_add_bf16,
121+
/*waveSize=*/32, /*maxWavesPerEU*/ 16, /*totalSGPRPerEU*/ 800,
122+
/*totalVGPRPerEU*/ 1536, /*totalSharedMemPerCU*/ 131072,
123+
/*maxSharedMemPerWG*/ 65536, /*numEUPerCU=*/4, /*minNumCU=*/12,
124+
/*hasFp8ConversionInstrs=*/false,
125+
/*hasOcpFp8ConversionInstrs=*/true, /*hasScaledGemm=*/false,
126+
/*maxNumXCC=*/1);
116127

117128
static std::tuple<StringRef, unsigned> parseArchString(StringRef arch) {
118129
std::tuple<StringRef, unsigned> ret("", 0);
@@ -367,7 +378,9 @@ AmdArchInfo mlir::rock::lookupArchInfo(StringRef arch) {
367378
return rdna3Info;
368379
}
369380
if (major == "gfx12") {
370-
return rdna4Info;
381+
return llvm::StringSwitch<AmdArchInfo>(minor)
382+
.Case("50", gfx1250Info)
383+
.Default(rdna4Info);
371384
}
372385
auto msg = "Unsupported architecture: " + arch.str();
373386
llvm_unreachable(msg.c_str());
@@ -413,3 +426,7 @@ GemmFeatures mlir::rock::AmdArchInfo::getDefaultFeatures(Type dataType) {
413426
}
414427
return theseFeatures;
415428
}
429+
430+
bool mlir::rock::isGlobalPrefetchSupported(StringRef arch) {
431+
return arch == "gfx1250";
432+
}

mlir/lib/Dialect/Rock/IR/RockDialect.cpp

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1835,15 +1835,13 @@ LogicalResult IndexDiffUpdateOp::verify() {
18351835
return success();
18361836
}
18371837

1838-
template <typename Load>
1839-
static LogicalResult verifyGlobalLoad(Load op) {
1838+
template <typename LoadOrPrefetch>
1839+
static LogicalResult verifyGlobalLoadAndPrefetch(LoadOrPrefetch op) {
18401840
MemRefType sourceType = op.getSource().getType();
18411841
size_t nDims = sourceType.getRank();
18421842

18431843
if (op.getSourceCoord().size() != nDims)
1844-
return op.emitOpError("Expected " + Twine(nDims) + " coordinates for load");
1845-
if (op.getCanReadOffEnd() && nDims != 1)
1846-
return op.emitOpError("can only have one dimension in canReadOffEnd loads");
1844+
return op.emitOpError("Expected " + Twine(nDims) + " coordinates");
18471845
Attribute memSpaceAttr = sourceType.getMemorySpace();
18481846
auto gpuMemSpaceAttr = dyn_cast_or_null<gpu::AddressSpaceAttr>(memSpaceAttr);
18491847
if (memSpaceAttr && (!gpuMemSpaceAttr ||
@@ -1852,6 +1850,27 @@ static LogicalResult verifyGlobalLoad(Load op) {
18521850
return success();
18531851
}
18541852

1853+
template <typename Load>
1854+
static LogicalResult verifyGlobalLoad(Load op) {
1855+
if (failed(verifyGlobalLoadAndPrefetch(op)))
1856+
return failure();
1857+
1858+
MemRefType sourceType = op.getSource().getType();
1859+
size_t nDims = sourceType.getRank();
1860+
1861+
if (op.getCanReadOffEnd() && nDims != 1)
1862+
return op.emitOpError("can only have one dimension in canReadOffEnd loads");
1863+
return success();
1864+
}
1865+
1866+
//===-----------------------------------------------------===//
1867+
// GlobalPrefetchOp
1868+
//===-----------------------------------------------------===//
1869+
1870+
LogicalResult GlobalPrefetchOp::verify() {
1871+
return verifyGlobalLoadAndPrefetch(*this);
1872+
}
1873+
18551874
//===-----------------------------------------------------===//
18561875
// GlobalLoadOp
18571876
//===-----------------------------------------------------===//
@@ -1956,6 +1975,70 @@ LogicalResult InBoundsStoreOp::verify() {
19561975
return success();
19571976
}
19581977

1978+
//===-----------------------------------------------------===//
1979+
// ThreadwisePrefetchOp
1980+
//===-----------------------------------------------------===//
1981+
1982+
SmallPtrSet<OpOperand *, 2> ThreadwisePrefetchOp::getAcceptingViewOperands() {
1983+
auto operands = getOperation()->getOpOperands();
1984+
return {operands.begin()};
1985+
}
1986+
1987+
std::optional<OperandRange>
1988+
ThreadwisePrefetchOp::getExtraIndices(OpOperand &operand) {
1989+
if (!getAcceptingViewOperands().contains(&operand)) {
1990+
return std::nullopt;
1991+
}
1992+
// Only one operand supports view
1993+
return getExtraIndices();
1994+
}
1995+
1996+
Operation *
1997+
ThreadwisePrefetchOp::cloneWithExtraIndices(OpBuilder &builder,
1998+
OpOperand &operand, Value view,
1999+
ArrayRef<Value> newExtraIndices) {
2000+
if (!getAcceptingViewOperands().contains(&operand)) {
2001+
return getOperation();
2002+
}
2003+
2004+
// Only one operand supports view
2005+
auto newOp = ThreadwisePrefetchOp::create(
2006+
builder, getLoc(), view, getExtraViews(), newExtraIndices,
2007+
getForceUnroll(), getUseIndexDiffs());
2008+
return newOp.getOperation();
2009+
}
2010+
2011+
LogicalResult ThreadwisePrefetchOp::verify() {
2012+
MemRefType srcType = getSource().getType();
2013+
Attribute srcMemSpaceAttr = srcType.getMemorySpace();
2014+
auto gpuSrcMemSpaceAttr =
2015+
dyn_cast_or_null<gpu::AddressSpaceAttr>(srcMemSpaceAttr);
2016+
if (srcMemSpaceAttr &&
2017+
(!gpuSrcMemSpaceAttr ||
2018+
gpuSrcMemSpaceAttr.getValue() != gpu::AddressSpace::Global))
2019+
return emitOpError("prefetching only works for global");
2020+
2021+
// we are checking below if extra indices match the upper view bounds.
2022+
// we should expect zero extra indices if we are prefetching a scalar.
2023+
// And upperBounds.size() + 1 otherwise.
2024+
ArrayAttr extraViews = getExtraViews();
2025+
ArrayRef<int64_t> inputShape;
2026+
if (extraViews.empty())
2027+
inputShape = srcType.getShape();
2028+
else
2029+
inputShape = cast<TransformMapAttr>(extraViews[0]).getUpperBounds();
2030+
2031+
size_t extraIdxCount = getExtraIndices().size();
2032+
if (inputShape.empty()) {
2033+
if (extraIdxCount != 0)
2034+
return emitOpError("read from a scalar value cannot have coordinates");
2035+
} else if (inputShape.size() != extraIdxCount + 1) {
2036+
return emitOpError("source view must be extraIndices + 1");
2037+
}
2038+
2039+
return success();
2040+
}
2041+
19592042
//===-----------------------------------------------------===//
19602043
// ThreadwiseReadIntoOp
19612044
//===-----------------------------------------------------===//

mlir/lib/Dialect/Rock/Transforms/AlignTiling.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,8 @@ traceToNonViewReaders(Operation *op, Value parentVal,
478478
if (copyOp.getSource() == parentVal) {
479479
nonViewReaders.push_back(op);
480480
}
481+
} else if (isa<ThreadwisePrefetchOp>(op)) {
482+
// ignore because ThreadwisePrefetchOp is not a reader.
481483
} else {
482484
return op->emitError() << "Found an unsupported operator that needs to "
483485
"be added reader checks \n"

mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818

1919
#include "GridLayoutEmitter.h"
2020
#include "mlir/Dialect/Affine/IR/AffineOps.h"
21+
#include "mlir/Dialect/Arith/IR/Arith.h"
2122
#include "mlir/Dialect/Func/IR/FuncOps.h"
2223
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2324
#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
25+
#include "mlir/Dialect/Rock/IR/AmdArchDb.h"
2426
#include "mlir/Dialect/Rock/IR/GetRockInfo.h"
2527
#include "mlir/Dialect/Rock/IR/Rock.h"
2628
#include "mlir/Dialect/Rock/IR/RockGemmGemmWrapperInterface.h"
@@ -263,6 +265,22 @@ class LoweringBlockwiseLoadTileOp final
263265
/*dynamicValidities=*/ValueRange{},
264266
/*extraViews=*/b.getArrayAttr({}),
265267
/*extraIndices=*/indices, forceUnroll, true);
268+
269+
if (rock::isGlobalPrefetchSupported(arch)) {
270+
// add one to k_loop to prefetch next iteration
271+
SmallVector<Value> indicesNext(indices.begin(), indices.end());
272+
Value one = b.createOrFold<arith::ConstantIndexOp>(loc, 1);
273+
indicesNext[0] =
274+
arith::AddIOp::create(b, loc, indicesNext[0], one).getResult();
275+
276+
// it's acceptable if the indices are out of bounds because we use
277+
// GLOBAL_PREFETCH_B8 with Speculative Prefetch. See llvm.prefetch
278+
// documentation in AMDGPUUsage.rst
279+
rock::ThreadwisePrefetchOp::create(b, loc, wrappedSource,
280+
/*extraViews=*/b.getArrayAttr({}),
281+
/*extraIndices=*/indicesNext,
282+
forceUnroll, true);
283+
}
266284
if (stageGlobalReadNew)
267285
rock::YieldOp::create(b, loc);
268286
}

mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,26 @@ Value selectDataIf4b(Location loc, PatternRewriter &b,
11641164
return b.createOrFold<vector::ExtractOp>(loc, loadedVec, lsb);
11651165
}
11661166

1167+
struct GlobalPrefetchRewritePattern
1168+
: public OpRewritePattern<GlobalPrefetchOp> {
1169+
using OpRewritePattern<GlobalPrefetchOp>::OpRewritePattern;
1170+
LogicalResult matchAndRewrite(GlobalPrefetchOp op,
1171+
PatternRewriter &b) const override {
1172+
Value source = op.getSource();
1173+
1174+
source = asGlobal(b, source);
1175+
1176+
// it's acceptable if the indices are out of bounds because we use
1177+
// GLOBAL_PREFETCH_B8 with Speculative Prefetch. See llvm.prefetch
1178+
// documentation in AMDGPUUsage.rst localityHint=3 is translated to memory
1179+
// scope SCOPE_SE.
1180+
b.replaceOpWithNewOp<memref::PrefetchOp>(
1181+
op, source, op.getSourceCoord(), /*isWrite=*/false, /*localityHint=*/3,
1182+
/*isDataCache=*/true);
1183+
return success();
1184+
}
1185+
};
1186+
11671187
struct GlobalLoadRewritePattern : public OpRewritePattern<GlobalLoadOp> {
11681188
using OpRewritePattern<GlobalLoadOp>::OpRewritePattern;
11691189
LogicalResult matchAndRewrite(GlobalLoadOp op,
@@ -1620,8 +1640,8 @@ void RockSugarToLoopsPass::runOnOperation() {
16201640
RewritePatternSet patterns(ctx);
16211641
patterns.add<ExtractSliceRewritePattern, InsertSliceRewritePattern,
16221642
GlobalLoadRewritePattern, GlobalLoadToLDSRewritePattern,
1623-
GlobalStoreRewritePattern, InBoundsLoadRewritePattern,
1624-
InBoundsStoreRewritePattern>(ctx);
1643+
GlobalPrefetchRewritePattern, GlobalStoreRewritePattern,
1644+
InBoundsLoadRewritePattern, InBoundsStoreRewritePattern>(ctx);
16251645
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
16261646
signalPassFailure();
16271647

0 commit comments

Comments
 (0)