Commit 8d4e6e0

Add horizontal batched reduce.

Signed-off-by: Lu,Chengjun <[email protected]>
1 parent dca7748 commit 8d4e6e0

File tree

11 files changed: +448 -1 lines changed

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 6 additions & 0 deletions
@@ -64,6 +64,12 @@ class TargetInfoBase {
                           unsigned numLaneToReduce,
                           unsigned interleave) const = 0;
 
+  virtual bool
+  warpBatchReduce(RewriterBase &rewriter, Location loc,
+                  std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                  triton::ReduceOp op, unsigned numLaneToReduce,
+                  unsigned interleave) const = 0;
+
   virtual std::string getMulhiFuncName(Type resultElementTy) const = 0;
 
   // Emits LLVM code with |rewriter| to print a message following the given
   // format from the device. |formatStrStart| is the pointer to the start of

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_FAST_MATH",
     "TRITON_INTEL_RAISE_BLOCK_POINTER",
     "TRITON_INTEL_REDUCE_TRANSPOSE",
+    "TRITON_INTEL_ENABLE_SIMD_REDUCE",
     // clang-format on
 };
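
The new variable is an opt-in switch read at kernel-compile time (see the ReduceOpToLLVM.cpp hunk below); registering it in CACHE_INVALIDATING_ENV_VARS keeps previously cached kernels from being reused when the flag changes. A minimal sketch of enabling it from the Python front end (hedged; only the flag name itself comes from this commit):

    import os

    # Must be set before the kernel is compiled; the lowering reads it via
    # triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_SIMD_REDUCE").
    os.environ["TRITON_INTEL_ENABLE_SIMD_REDUCE"] = "1"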

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
import re
import numpy as np
from numpy.random import RandomState
import pytest
import torch
import pathlib

import triton
from triton._internal_testing import numpy_random, to_numpy

MIN_GROUP_SIZE = torch.xpu.get_device_capability()['sub_group_sizes'][0]


class DpasLayout:

    def __init__(self, repeatCount, systolic_depth, execution_size, ops_per_chan, threads_per_warp, warps_per_cta,
                 rep_cluster):
        self.repeatCount = repeatCount
        self.systolic_depth = systolic_depth
        self.execution_size = execution_size
        self.ops_per_chan = ops_per_chan
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.rep_cluster = rep_cluster

    def __str__(self):
        return f"#triton_intel_gpu.dpas<{{repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, repCluster={self.rep_cluster}}}>"


class BlockedLayout:

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga=[1, 1],
                 cta_split_num=[1, 1], cta_order=[0, 1]):
        self.sz_per_thread = size_per_thread
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.order = order
        self.ctas_per_cga = ctas_per_cga
        self.cta_split_num = cta_split_num
        self.cta_order = cta_order

    def __str__(self):
        return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"


layouts = [
    DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=16,
               warps_per_cta=[2, 2], rep_cluster=[1, 1]),
    BlockedLayout([1, 1], [1, 16], [2, 2], [1, 0])
]

if MIN_GROUP_SIZE == 16:
    # Add threads_per_warp=32 cases.
    layouts += [
        DpasLayout(repeatCount=8, systolic_depth=8, execution_size=16, ops_per_chan=2, threads_per_warp=32,
                   warps_per_cta=[2, 2], rep_cluster=[1, 1]),
        BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0])
    ]


def warps_per_cta(layout, shape):
    return layout.warps_per_cta


GPU_DIALECT = "ttg"


@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32], [64, 32], [16, 16]])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"])
@pytest.mark.parametrize("reduce_op", ["sum", "max"])
def test_horizontal_simd_reduce(M, N, src_layout, dtype_str, reduce_op, device, tmp_path: pathlib.Path):
    ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str]
    arith_op = {
        "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"},  #
        "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"}
    }[reduce_op][dtype_str]
    numpy_op = {"max": np.max, "sum": np.sum}[reduce_op]
    rdims_1d = f"{M}"
    rdims_2d = f"{M}x1"
    store_range = "%1"
    warps = src_layout.warps_per_cta
    threads_per_warp = int(np.prod(src_layout.threads_per_warp))
    num_warps = int(np.prod(warps))
    blocked = BlockedLayout([1, 1], [16, threads_per_warp // 16], [4, num_warps // 4], [0, 1], [1, 1], [1, 1], [0, 1])
    one_d_layout = BlockedLayout([1], [threads_per_warp], [num_warps], [0], [1], [1], [0])

    ir = f"""
    #blocked = {blocked}
    #src = {src_layout}
    #one_d_layout = {one_d_layout}
    module attributes {{"ttg.num-warps" = {num_warps} : i32, "ttg.num-ctas" = 1 : i32, "ttg.threads-per-warp" = {threads_per_warp} : i32, "triton_intel_gpu.min_sg_size" = {MIN_GROUP_SIZE} }} {{
    tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
      %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
      %1 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked>
      %2 = tt.splat %arg1 : i32 -> tensor<{M}x1xi32, #blocked>
      %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
      %4 = tt.splat %arg0 : !tt.ptr<{ty}> -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked>
      %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked>
      %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
      %7 = tt.expand_dims %6 {{axis = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N}xi32, #blocked>
      %8 = tt.broadcast %5 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked> -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
      %9 = tt.broadcast %7 : tensor<1x{N}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked>
      %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked>
      %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
      %12 = {GPU_DIALECT}.convert_layout %11 : tensor<{M}x{N}x{ty}, #blocked> -> tensor<{M}x{N}x{ty}, #src>
      %13 = "tt.reduce"(%12) ({{
      ^bb0(%arg3: {ty}, %arg4: {ty}):
        %17 = {arith_op} %arg3, %arg4 : {ty}
        tt.reduce.return %17 : {ty}
      }}) {{axis = 1 : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
      %14 = tt.splat %arg2 : !tt.ptr<{ty}> -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
      %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
      %16 = {GPU_DIALECT}.convert_layout %13 : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
      %17 = tt.expand_dims %16 {{axis = 1 : i32}} : tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{rdims_2d}x{ty}, #blocked>
      tt.store %15, %17 : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
      tt.return
    }}
    }}
    """

    temp_file = tmp_path / "test_reduce_layouts.ttgir"
    temp_file.write_text(ir)
    kernel = triton.compile(str(temp_file))

    rs = RandomState(17)
    x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10)
    z_shape = (M, 1)
    z = np.zeros(z_shape).astype(dtype_str)

    x_tri = torch.tensor(x, device=device)
    z_tri = torch.tensor(z, device=device)

    kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri)
    z_ref = numpy_op(x, axis=1, keepdims=True)

    llir = kernel.asm['llir']
    assert re.search(r'call .* asm', llir), 'no inline visa in llir'  # the reduce must lower to inline vISA

    if dtype_str == 'float16':
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2)
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)

third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h

Lines changed: 7 additions & 0 deletions
@@ -61,6 +61,13 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                   triton::ReduceOp op, unsigned numLaneToReduce,
                   unsigned interleave) const override;
 
+  bool warpBatchReduce(RewriterBase &rewriter, Location loc,
+                       std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+                       triton::ReduceOp op, unsigned numLaneToReduce,
+                       unsigned interleave) const override {
+    return false;
+  }
+
   std::string getMulhiFuncName(Type resultElementTy) const override;
 
   void printf(RewriterBase &rewriter, Value formatStrStart,

third_party/intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h

Lines changed: 17 additions & 0 deletions
@@ -315,6 +315,23 @@ struct XeInstrExecution {
   bool onlyAttachMLIRArgs{};
 };
 
+enum XeArch {
+  Xe = 0,
+  Xe2 = 1,
+  Xe3 = 2,
+};
+
+struct XeVISAInstr : public XeInstrBase<XeVISAInstr> {
+  using XeInstrBase<XeVISAInstr>::XeInstrBase;
+
+  static std::optional<std::string> getTypeName(Type scalarTy);
+  static unsigned getGRFSizeInBytes(XeArch arch);
+  static unsigned getExecMaskLaneNum(XeArch arch);
+};
+
+std::string simdReduceAsm(std::string binOp, int warpSize, int accSize,
+                          Type elemTy, XeArch arch);
+
 } // namespace triton
 } // namespace mlir

third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 13 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "PatternTritonGPUOpToLLVM.h"
 #include "lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Tools/Sys/GetEnv.hpp"
 
 using namespace mlir;
 using namespace mlir::triton;
@@ -176,6 +177,18 @@ struct ReduceOpConversion
     unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
     unsigned threadOffsetOnReductionAxis =
         helper.getThreadOffsetOnReductionAxis();
+
+    bool simdReduce =
+        triton::tools::getBoolEnv("TRITON_INTEL_ENABLE_SIMD_REDUCE");
+
+    if (simdReduce) {
+      auto ret = targetInfo.warpBatchReduce(rewriter, op.getLoc(), accs, op,
+                                            sizeIntraWarps,
+                                            threadOffsetOnReductionAxis);
+      if (ret)
+        return;
+    }
+
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
       SmallVector<Value> &acc = accs[key];
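
With the flag set, a row reduction whose combine region is a single supported binary op can take the batched path. A minimal sketch of a kernel that would exercise it (hedged: the kernel, shapes, and device string are illustrative and not part of this commit; whether the batched path fires also depends on the layout checks in warpBatchReduce):

    import os
    os.environ["TRITON_INTEL_ENABLE_SIMD_REDUCE"] = "1"  # before compilation

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def row_sum(x_ptr, out_ptr, N: tl.constexpr):
        # One program per row; tl.sum lowers to tt.reduce with an
        # arith.addf combine region, the shape warpBatchReduce matches.
        row = tl.program_id(0)
        offs = tl.arange(0, N)
        x = tl.load(x_ptr + row * N + offs)
        tl.store(out_ptr + row, tl.sum(x, axis=0))

    x = torch.randn(128, 64, device="xpu")
    out = torch.empty(128, device="xpu")
    row_sum[(128,)](x, out, N=64)
    assert torch.allclose(out, x.sum(dim=1), atol=1e-4)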

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp

Lines changed: 145 additions & 0 deletions
@@ -7,6 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "TargetInfo.h"
+#include "intel/include/TritonIntelGPUToLLVM/XeAsmFormat.h"
+#include <llvm/ADT/TypeSwitch.h>
+#include <llvm/Support/FormatVariadic.h>
+
 #include "Dialect/TritonIntelGPU/IR/Utils.h"
 #include "SPIRVTargetInfo.h"
 #include "Utility.h"
@@ -15,6 +19,14 @@ using namespace mlir;
 
 namespace mlir::triton::intel {
 
+struct XeSIMDReduceInstr : public XeVISAInstr {
+  XeSIMDReduceInstr(XeBuilder *builder, std::string binOp, unsigned warpSize,
+                    unsigned accSize, Type elemTy, XeArch arch)
+      : XeVISAInstr(builder,
+                    simdReduceAsm(binOp, warpSize, accSize, elemTy, arch)) {}
+};
+
 bool TargetInfo::supportMaximumMinimum() const { return false; }
 Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type,
                          Value cmp) const {
@@ -112,6 +124,139 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
   return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId);
 }
 
+// Split |srcValues| into consecutive batches of |batchSize| values.
+static SmallVector<ArrayRef<Value>> splitInBatches(ArrayRef<Value> srcValues,
+                                                   size_t batchSize) {
+  SmallVector<ArrayRef<Value>> batches;
+  for (; !srcValues.empty(); srcValues = srcValues.drop_front(batchSize))
+    batches.push_back(srcValues.take_front(batchSize));
+  return batches;
+}
+
+bool TargetInfo::warpBatchReduce(
+    RewriterBase &rewriter, Location loc,
+    std::map<SmallVector<unsigned>, SmallVector<Value>> &acc,
+    triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const {
+  // No horizontal reduce required.
+  if (numLaneToReduce == 1)
+    return false;
+  // Horizontal reduce with an interleave stride is not supported.
+  if (interleave > 1)
+    return false;
+  // Check that this is a simple reduction: a single-block combine region
+  // whose terminator yields one binary op over the two block arguments.
+  if (op.getNumOperands() != 1 || op.getNumResults() != 1)
+    return false;
+  Region &combineOp = op.getCombineOp();
+  if (combineOp.getBlocks().size() > 1)
+    return false;
+  Block &block = *combineOp.begin();
+  Operation *yield = block.getTerminator();
+  Operation *reduceOp = yield->getOperand(0).getDefiningOp();
+  if (!reduceOp || reduceOp->getNumOperands() != 2 ||
+      reduceOp->getNumResults() != 1)
+    return false;
+  if (reduceOp->getOperand(0) != block.getArgument(0) ||
+      reduceOp->getOperand(1) != block.getArgument(1))
+    return false;
+
+  auto mod = op->getParentOfType<ModuleOp>();
+  unsigned warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+
+  // TODO: support clustered reduce.
+  if (numLaneToReduce != warpSize)
+    return false;
+
+  if (!isSupportedWarpReduceOp(reduceOp, numLaneToReduce, warpSize))
+    return false;
+
+  unsigned minSGSize =
+      mod->getAttrOfType<IntegerAttr>(
+             gpu::intel::TritonIntelGPUDialect::getMinSGSizeAttrName())
+          .getInt();
+
+  if (isa<arith::AddFOp, arith::MaxNumFOp>(reduceOp)) {
+    // Each SIMD reduce handles warpSize accumulators at a time, so the
+    // number of accumulators must be a multiple of the warp size.
+    if (acc.size() % warpSize)
+      return false;
+    Type elemType = acc.begin()->second[0].getType();
+    VectorType reduceTy = vec_ty(elemType, warpSize);
+
+    // Flatten the accumulator map into a list, preserving key order.
+    SmallVector<Value> inputAccs;
+    for (auto it : acc) {
+      const SmallVector<unsigned> &key = it.first;
+      SmallVector<Value> &val = acc[key];
+      assert(val.size() == 1 && "acc size has to be 1 for ungrouped input");
+      inputAccs.push_back(val[0]);
+    }
+    SmallVector<Value> resultAccs(inputAccs.size() / warpSize);
+
+    std::string batchedHorizontalReduce;
+    // TODO: support all possible reduction modes.
+    TypeSwitch<Operation *>(reduceOp)
+        .Case<arith::AddFOp>([&](auto) { batchedHorizontalReduce = "add"; })
+        .Case<arith::MaxNumFOp>([&](auto) { batchedHorizontalReduce = "max"; })
+        .Default(
+            [&](auto) { llvm_unreachable("Unhandled batched reduce kind"); });
+
+    llvm::transform(
+        splitInBatches(inputAccs, warpSize), std::begin(resultAccs),
+        [&](ArrayRef<Value> inputs) {
+          // Pack one accumulator per vector element.
+          auto inputRange = llvm::enumerate(inputs);
+          Value batchedReduceVal = std::accumulate(
+              std::begin(inputRange), std::end(inputRange),
+              rewriter.create<LLVM::PoisonOp>(loc, reduceTy).getRes(),
+              [reduceTy, loc, &rewriter](Value acc, auto entry) -> Value {
+                auto [index, src] = entry;
+                auto b = TritonLLVMOpBuilder(loc, rewriter);
+                return b.insert_element(reduceTy, acc, src, b.i32_val(index));
+              });
+          XeBuilder xeBuilder;
+          XeSIMDReduceInstr &bReduceOp = *xeBuilder.create<XeSIMDReduceInstr>(
              batchedHorizontalReduce, warpSize, warpSize, elemType,
              minSGSize == 8 ? Xe : Xe2);
+          // The vISA inline asm does not support a uniform result operand
+          // ("=rw.u"), so use a non-uniform result and shuffle afterwards.
+          XeBuilder::Operand *res = xeBuilder.newOperand("=rw");
+          XeBuilder::Operand *in = xeBuilder.newOperand(batchedReduceVal, "rw");
+          bReduceOp({res, in}, /*onlyAttachMLIRArgs=*/true);
+          Type resultTy = reduceTy.getElementType();
+          return xeBuilder.launch(rewriter, loc, resultTy, true);
+        });
+
+    unsigned grouped_iter = 0;
+    for (unsigned i = 0; i < resultAccs.size(); ++i) {
+      // The output of the inline vISA is a non-uniform value; shuffle lane
+      // j of each batched result to recover the j-th reduced value.
+      Value ret = resultAccs[i];
+      for (unsigned j = 0; j < warpSize; ++j) {
+        inputAccs[grouped_iter++] =
+            LLVM::intel::shuffleIdx(loc, rewriter, ret, j);
+      }
+    }
+    grouped_iter = 0;
+    for (auto it : acc) {
+      const SmallVector<unsigned> &key = it.first;
+      SmallVector<Value> &val = acc[key];
+      val[0] = inputAccs[grouped_iter++];
+    }
+
+    return true;
+  }
+
+  return false;
+}
+
 bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                             SmallVector<Value> &acc, triton::ReduceOp op,
                             unsigned numLaneToReduce,
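
To summarize the data movement in warpBatchReduce: each of the warpSize lanes holds one scalar per accumulator, packing one accumulator per vector element turns warpSize independent cross-lane reductions into a single horizontal SIMD reduction, and the trailing shuffles hand result j back to every lane. A NumPy model of the intended semantics (an illustration only, not the vISA lowering):

    import numpy as np

    warp_size = 16
    # vals[l, a] models lane l's partial value for accumulator a.
    vals = np.random.rand(warp_size, warp_size).astype(np.float32)

    # Generic path: reduce each accumulator across lanes, one at a time.
    reference = vals.sum(axis=0)

    # Batched path: one horizontal reduction over the lane axis yields all
    # warp_size results at once; shuffleIdx(ret, j) then broadcasts result
    # j to every lane, matching the loop over j in warpBatchReduce.
    batched = np.add.reduce(vals, axis=0)
    assert np.allclose(reference, batched, rtol=1e-5)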
