Commit 0ef1308

Add horizontal batched reduce

1 parent e1e48a8 commit 0ef1308

File tree

6 files changed: +204 −0 lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Lines changed: 6 additions & 0 deletions

@@ -64,6 +64,12 @@ class TargetInfoBase {
                           unsigned numLaneToReduce,
                           unsigned interleave) const = 0;
 
+  virtual bool
+  warpBatchReduce(RewriterBase &rewriter, Location loc,
+                  std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                  triton::ReduceOp op, unsigned numLaneToReduce,
+                  unsigned interleave) const = 0;
+
   virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
   // Emits LLVM code with |rewriter| to print a message following the given
   // format from the device. |formatStrStart| is the pointer to the start of
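The new pure-virtual hook lets a target lower all of a warp's accumulators in one batched operation. The contract: an override either emits the batched reduction, rewriting each entry of `acc` in place, and returns true, or it leaves `acc` untouched and returns false so the caller falls back to the per-accumulator warpReduce path. A minimal declining override looks like this (it is exactly what the AMD and NVIDIA stubs below do):

  bool warpBatchReduce(RewriterBase &rewriter, Location loc,
                       std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
                       triton::ReduceOp op, unsigned numLaneToReduce,
                       unsigned interleave) const override {
    // Not handled on this target; the generic lowering still runs.
    return false;
  }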

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
Lines changed: 7 additions & 0 deletions

@@ -61,6 +61,13 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                   triton::ReduceOp op, unsigned numLaneToReduce,
                   unsigned interleave) const override;
 
+  bool warpBatchReduce(RewriterBase &rewriter, Location loc,
+                       std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                       triton::ReduceOp op, unsigned numLaneToReduce,
+                       unsigned interleave) const override {
+    return false;
+  };
+
   std::string getMulhiFuncName(Type resultElementTy) const override;
 
   void printf(RewriterBase &rewriter, Value formatStrStart,

third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp
Lines changed: 8 additions & 0 deletions

@@ -176,6 +176,14 @@ struct ReduceOpConversion
     unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
     unsigned threadOffsetOnReductionAxis =
         helper.getThreadOffsetOnReductionAxis();
+
+    auto ret =
+        targetInfo.warpBatchReduce(rewriter, op.getLoc(), accs, op,
+                                   sizeIntraWarps, threadOffsetOnReductionAxis);
+
+    if (ret)
+      return;
+
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> &acc = accs[key];
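The early exit above relies on warpBatchReduce rewriting every `accs[key][0]` in place. The round trip is safe because `std::map` iterates in a deterministic (lexicographic) key order, so values unpack to the same keys they were packed from. A standalone sketch of the pack/reduce/unpack pattern, with plain C++ stand-ins for the MLIR `Value`s (illustration only):

  #include <cstdio>
  #include <map>
  #include <vector>

  int main() {
    std::map<std::vector<unsigned>, std::vector<float>> accs;
    for (unsigned i = 0; i < 16; ++i)
      accs[{i}] = {float(i)}; // one partial accumulator per output key

    std::vector<float> grouped; // pack, in key order
    for (auto &kv : accs)
      grouped.push_back(kv.second[0]);

    // ... a batched reduce would rewrite `grouped` here ...

    unsigned it = 0; // unpack in the same key order
    for (auto &kv : accs)
      kv.second[0] = grouped[it++];
    std::printf("%zu accumulators round-tripped\n", accs.size());
  }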

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp
Lines changed: 171 additions & 0 deletions

@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "TargetInfo.h"
+#include "intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h"
+
 #include "Dialect/TritonIntelGPU/IR/Utils.h"
 #include "SPIRVTargetInfo.h"
 #include "Utility.h"
@@ -112,6 +114,175 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
   return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId);
 }
 
+bool TargetInfo::warpBatchReduce(
+    RewriterBase &rewriter, Location loc,
+    std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+    triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const {
+  // No horizontal reduce required.
+  if (numLaneToReduce == 1)
+    return false;
+  // Horizontal reduce with an interleave stride is not supported.
+  if (interleave > 1)
+    return false;
+  // Check that this is a simple reduce operation supported by
+  // TritonGEN::SubGroupReduceOp.
+  if (op.getNumOperands() != 1 || op.getNumResults() != 1)
+    return false;
+  Region &combineOp = op.getCombineOp();
+  if (combineOp.getBlocks().size() > 1)
+    return false;
+  Block &block = *combineOp.begin();
+  Operation *yield = block.getTerminator();
+  Operation *reduceOp = yield->getOperand(0).getDefiningOp();
+  if (!reduceOp || reduceOp->getNumOperands() != 2 ||
+      reduceOp->getNumResults() != 1)
+    return false;
+  if (reduceOp->getOperand(0) != block.getArgument(0) ||
+      reduceOp->getOperand(1) != block.getArgument(1))
+    return false;
+
+  auto mod = op->getParentOfType<ModuleOp>();
+  unsigned warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+
+  if (!isSupportedWarpReduceOp(reduceOp, numLaneToReduce, warpSize))
+    return false;
+
+  // Experimental code path: only threads_per_warp=16 is supported.
+  if (warpSize != 16)
+    return false;
+
+  if (acc.size() == 16 && isa<arith::AddFOp, arith::MaxNumFOp>(reduceOp)) {
+
+    // Group the accumulators into one batch.
+    SmallVector<Value> grouped_accs;
+    for (auto it : acc) {
+      const SmallVector<unsigned> &key = it.first;
+      SmallVector<Value> &val = acc[key];
+      assert(val.size() == 1 && "acc size has to be 1 for ungrouped input");
+      grouped_accs.push_back(val[0]);
+    }
+
+    VectorType reduceTy =
+        vec_ty(grouped_accs[0].getType(), grouped_accs.size());
+    Value batchedReduceVal = rewriter.create<LLVM::UndefOp>(loc, reduceTy);
+    auto b = TritonLLVMOpBuilder(loc, rewriter);
+    for (unsigned i = 0; i < grouped_accs.size(); ++i) {
+      batchedReduceVal = b.insert_element(reduceTy, batchedReduceVal,
+                                          grouped_accs[i], b.i32_val(i));
+    }
+    XeBuilder vISABuilder;
+    std::string batchedHorizontalReduce;
+    if (isa<arith::AddFOp>(reduceOp)) {
+      batchedHorizontalReduce =
+          "{\n"
+          ".decl temp_result v_type=G type=f num_elts=128 align=wordx32\n"
+          // 1st round: 2x8 + 2x8 -> 1x16
+          "add (M1_NM, 16) temp_result(0, 0)<1> $1(0, 0)<16;8,1> $1(0, "
+          "8)<16;8,1> \n"
+          "add (M1_NM, 16) temp_result(1, 0)<1> $1(2, 0)<16;8,1> $1(2, "
+          "8)<16;8,1> \n"
+          "add (M1_NM, 16) temp_result(2, 0)<1> $1(4, 0)<16;8,1> $1(4, "
+          "8)<16;8,1> \n"
+          "add (M1_NM, 16) temp_result(3, 0)<1> $1(6, 0)<16;8,1> $1(6, "
+          "8)<16;8,1> \n"
+          "add (M1_NM, 16) temp_result(4, 0)<1> $1(8, 0)<16;8,1> $1(8, "
+          "8)<16;8,1> \n"
+          "add (M1_NM, 16) temp_result(5, 0)<1> $1(10, 0)<16;8,1> $1(10, "
+          "8)<16;8,1> \n"
+          "add (M1_NM, 16) temp_result(6, 0)<1> $1(12, 0)<16;8,1> $1(12, "
+          "8)<16;8,1> \n"
+          "add (M1_NM, 16) temp_result(7, 0)<1> $1(14, 0)<16;8,1> $1(14, "
+          "8)<16;8,1> \n"
+
+          // 2nd round: 2x2x4 + 2x2x4 -> 1x16
+          "add (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<8;4,1> "
+          "temp_result(0, 4)<8;4,1> \n"
+          "add (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<8;4,1> "
+          "temp_result(2, 4)<8;4,1> \n"
+          "add (M1_NM, 16) temp_result(2, 0)<1> temp_result(4, 0)<8;4,1> "
+          "temp_result(4, 4)<8;4,1> \n"
+          "add (M1_NM, 16) temp_result(3, 0)<1> temp_result(6, 0)<8;4,1> "
+          "temp_result(6, 4)<8;4,1> \n"
+
+          // 3rd round: 4x2x2 + 4x2x2 -> 1x16
+          "add (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<4;2,1> "
+          "temp_result(0, 2)<4;2,1> \n"
+          "add (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<4;2,1> "
+          "temp_result(2, 2)<4;2,1> \n"
+
+          // 4th round: 8x2x1 + 8x2x1 -> 1x16
+          "add (M1_NM, 16) $0(0, 0)<1> temp_result(0, 0)<2;1,0> "
+          "temp_result(0, 1)<2;1,0> \n"
+          "}\n";
+    } else if (isa<arith::MaxNumFOp>(reduceOp)) {
+      batchedHorizontalReduce =
+          "{\n"
+          ".decl temp_result v_type=G type=f num_elts=128 align=wordx32\n"
+          // 1st round: 2x8 + 2x8 -> 1x16
+          "max (M1_NM, 16) temp_result(0, 0)<1> $1(0, 0)<16;8,1> $1(0, "
+          "8)<16;8,1> \n"
+          "max (M1_NM, 16) temp_result(1, 0)<1> $1(2, 0)<16;8,1> $1(2, "
+          "8)<16;8,1> \n"
+          "max (M1_NM, 16) temp_result(2, 0)<1> $1(4, 0)<16;8,1> $1(4, "
+          "8)<16;8,1> \n"
+          "max (M1_NM, 16) temp_result(3, 0)<1> $1(6, 0)<16;8,1> $1(6, "
+          "8)<16;8,1> \n"
+          "max (M1_NM, 16) temp_result(4, 0)<1> $1(8, 0)<16;8,1> $1(8, "
+          "8)<16;8,1> \n"
+          "max (M1_NM, 16) temp_result(5, 0)<1> $1(10, 0)<16;8,1> $1(10, "
+          "8)<16;8,1> \n"
+          "max (M1_NM, 16) temp_result(6, 0)<1> $1(12, 0)<16;8,1> $1(12, "
+          "8)<16;8,1> \n"
+          "max (M1_NM, 16) temp_result(7, 0)<1> $1(14, 0)<16;8,1> $1(14, "
+          "8)<16;8,1> \n"
+
+          // 2nd round: 2x2x4 + 2x2x4 -> 1x16
+          "max (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<8;4,1> "
+          "temp_result(0, 4)<8;4,1> \n"
+          "max (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<8;4,1> "
+          "temp_result(2, 4)<8;4,1> \n"
+          "max (M1_NM, 16) temp_result(2, 0)<1> temp_result(4, 0)<8;4,1> "
+          "temp_result(4, 4)<8;4,1> \n"
+          "max (M1_NM, 16) temp_result(3, 0)<1> temp_result(6, 0)<8;4,1> "
+          "temp_result(6, 4)<8;4,1> \n"
+
+          // 3rd round: 4x2x2 + 4x2x2 -> 1x16
+          "max (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<4;2,1> "
+          "temp_result(0, 2)<4;2,1> \n"
+          "max (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<4;2,1> "
+          "temp_result(2, 2)<4;2,1> \n"
+
+          // 4th round: 8x2x1 + 8x2x1 -> 1x16
+          "max (M1_NM, 16) $0(0, 0)<1> temp_result(0, 0)<2;1,0> "
+          "temp_result(0, 1)<2;1,0> \n"
+          "}\n";
+    } else {
+      llvm_unreachable("batched reduce WIP");
+    }
+
+    auto &bReduceOp = *vISABuilder.create<>(batchedHorizontalReduce);
+    auto res = vISABuilder.newOperand("=rw.u");
+    auto in = vISABuilder.newOperand(batchedReduceVal, "rw");
+    bReduceOp({res, in}, /*onlyAttachMLIRArgs=*/true);
+    Value ret = vISABuilder.launch(rewriter, loc, reduceTy, false);
+    Type resultTy = reduceTy.getElementType();
+    for (unsigned i = 0; i < grouped_accs.size(); ++i) {
+      grouped_accs[i] = b.extract_element(resultTy, ret, b.i32_val(i));
+    }
+
+    unsigned grouped_iter = 0;
+    for (auto it : acc) {
+      const SmallVector<unsigned> &key = it.first;
+      SmallVector<Value> &val = acc[key];
+      val[0] = grouped_accs[grouped_iter++];
+    }
+
+    return true;
+  }
+
+  return false;
+}
+
 bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                             SmallVector<Value> &acc, triton::ReduceOp op,
                             unsigned numLaneToReduce,
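In scalar terms, the inline vISA above treats the batched value as a 16x16 matrix: one row per accumulator, one column per lane. Each round halves the number of live columns (16 -> 8 -> 4 -> 2 -> 1) until column 0 of every row holds that row's reduction. A plain C++ model of the folding pattern, with GRF packing and region addressing such as `<16;8,1>` abstracted away (illustration only):

  #include <array>
  #include <cstdio>

  int main() {
    // m[a][l]: partial accumulator a as held by lane l.
    std::array<std::array<float, 16>, 16> m{};
    for (int a = 0; a < 16; ++a)
      for (int l = 0; l < 16; ++l)
        m[a][l] = 1.0f; // each row should reduce to 16.0

    // Four rounds of pairwise combines, mirroring the tree in the vISA;
    // use std::fmax instead of += for the MaxNumF variant.
    for (int stride = 8; stride >= 1; stride /= 2)
      for (int a = 0; a < 16; ++a)
        for (int l = 0; l < stride; ++l)
          m[a][l] += m[a][l + stride];

    for (int a = 0; a < 16; ++a)
      std::printf("acc %2d -> %g\n", a, m[a][0]); // prints 16 for every row
  }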

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h
Lines changed: 5 additions & 0 deletions

@@ -53,6 +53,11 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
   Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp,
                   int axis) const override;
 
+  bool warpBatchReduce(RewriterBase &rewriter, Location loc,
+                       std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                       triton::ReduceOp op, unsigned numLaneToReduce,
+                       unsigned interleave) const override;
+
   bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector<Value> &acc,
                   triton::ReduceOp op, unsigned numLaneToReduce,
                   unsigned interleave) const override;

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h
Lines changed: 7 additions & 0 deletions

@@ -51,6 +51,13 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                   triton::ReduceOp op, unsigned numLaneToReduce,
                   unsigned interleave) const override;
 
+  bool warpBatchReduce(RewriterBase &rewriter, Location loc,
+                       std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                       triton::ReduceOp op, unsigned numLaneToReduce,
+                       unsigned interleave) const override {
+    return false;
+  };
+
   std::string getMulhiFuncName(Type resultElementTy) const override;
 
   void printf(RewriterBase &rewriter, Value formatStrStart,
