Skip to content

Commit 8d0d554

Browse files
committed
[intel] 2Dblock runtime HW checks
1 parent c437c95 commit 8d0d554

File tree

11 files changed

+574
-206
lines changed

11 files changed

+574
-206
lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
4444
"TRITON_F32_DEFAULT",
4545
"TRITON_PREFER_TMEM_16x256_LAYOUT",
4646
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
47+
"TRITON_INTEL_2DBLOCK_ASSERT",
4748
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
4849
"TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
4950
"TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
import triton
3+
4+
import ctypes
5+
import sys
6+
7+
8+
def run_load_ir(temp_file, elem_size, *args):
9+
out_type = f"i{int(elem_size) * 4}"
10+
ir = f"""
11+
module attributes {{
12+
ttg.target = "xpu",
13+
"ttg.num-warps" = 32 : i32,
14+
"ttg.num-ctas" = 1 : i32,
15+
"ttg.threads-per-warp" = 16 : i32
16+
}} {{
17+
tt.func @dyn_block(
18+
%iptr : i64, %base_width : i32,
19+
%base_height : i32, %base_pitch : i32,
20+
%x : i32, %y : i32) {{
21+
%p0 = llvm.inttoptr %iptr : i64 to !llvm.ptr
22+
23+
%v = triton_gen.2Dblockload %p0, %base_width, %base_height,
24+
%base_pitch, %x, %y
25+
{{ elem_size_in_bits = {elem_size}, tile_width = 8, tile_height = 8,
26+
v_blocks = 1, transpose = false,
27+
vnni_transform = false, cache_control = Default }}
28+
: (!llvm.ptr, i32, i32, i32, i32, i32)
29+
-> vector<1x{out_type}>
30+
31+
// prevent GluonInline
32+
%v_cast = llvm.bitcast %v : vector<1x{out_type}> to {out_type}
33+
llvm.inline_asm has_side_effects asm_dialect = att
34+
"", "r" %v_cast : ({out_type}) -> ()
35+
36+
tt.return
37+
}}
38+
}}
39+
"""
40+
41+
with open(temp_file, "w", encoding="utf-8") as f:
42+
f.write(ir)
43+
44+
kernel = triton.compile(temp_file)
45+
46+
a = torch.zeros((256, 64), dtype=torch.float32, device="xpu")
47+
48+
addr = ctypes.c_int64(a.data_ptr()).value
49+
50+
kernel[(1, 1, 1)](addr, *map(int, args), 0)
51+
52+
53+
if __name__ == "__main__":
54+
fn = globals()[sys.argv[1]]
55+
fn(*sys.argv[2:])

python/test/unit/intel/test_block_load.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import pytest
22
import torch
3+
4+
import os
5+
import signal
6+
import subprocess
7+
import sys
38
import pathlib
49
from functools import partial
510

@@ -207,3 +212,45 @@ def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
207212
result_tor = fn_tor()
208213
result_tri = fn_tri()
209214
torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)
215+
216+
217+
@pytest.mark.parametrize("elem_size, width, height, pitch, x",
218+
[[8, 16777216, 64, 16777216, 0], # width <= 24 bits
219+
[8, 32, 64, 128, 0], # width >= 64
220+
[8, 66, 64, 128, 0], # width % max(4,elemSize) == 0
221+
[8, 128, 16777216, 128, 0], # height <= 24 bits
222+
[8, 128, 64, 16777216, 0], # pitch <= 24 bits
223+
[8, 128, 64, 32, 0], # pitch >= 64
224+
[8, 128, 64, 70, 0], # pitch % 16 == 0
225+
[8, 128, 64, 120, 0], # pitch >= width
226+
[8, 128, 64, 128, 1], # x*elemSize % 4 == 0 (alignment for 8-bit)
227+
[16, 128, 64, 128, 1], # x*elemSize % 4 == 0 (alignment for 16-bit)
228+
])
229+
@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
230+
@pytest.mark.xfail(
231+
not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
232+
and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
233+
reason="Block loads and/or DPAS not supported on this architecture", run=False)
234+
def test_block_load_asserts(elem_size, width, height, pitch, x, monkeypatch, tmp_path: pathlib.Path):
235+
monkeypatch.setenv("TRITON_INTEL_2DBLOCK_ASSERT", "1")
236+
237+
dir_path = os.path.dirname(os.path.realpath(__file__))
238+
helper_path = os.path.join(dir_path, "block_load_helper.py")
239+
240+
temp_file = tmp_path / "test_block_load_asserts.ttgir"
241+
242+
proc = subprocess.run(
243+
[
244+
sys.executable, helper_path, "run_load_ir",
245+
str(temp_file),
246+
str(elem_size),
247+
str(width),
248+
str(height),
249+
str(pitch),
250+
str(x)
251+
],
252+
capture_output=True,
253+
)
254+
255+
rc = proc.returncode
256+
assert rc == -signal.SIGABRT
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: env TRITON_INTEL_2DBLOCK_ASSERT=1 triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s --check-prefix=ASSERT
2+
// RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s --check-prefix=NOASSERT
3+
4+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
5+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
6+
// ASSERT: llvm.call spir_funccc @__assert_fail
7+
// NOASSERT-NOT: __assert_fail
8+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<2xi16>
9+
llvm.return
10+
}
11+
}
12+
13+
// -----
14+
15+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
16+
llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
17+
// ASSERT: llvm.call spir_funccc @__assert_fail
18+
// NOASSERT-NOT: __assert_fail
19+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
20+
llvm.return
21+
}
22+
}
23+
24+
// -----
25+
26+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
27+
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
28+
// ASSERT: llvm.call spir_funccc @__assert_fail
29+
// NOASSERT-NOT: __assert_fail
30+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
31+
llvm.return
32+
}
33+
}

third_party/intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,13 @@ namespace triton {
2121
#define GEN_PASS_DECL
2222
#include "intel/include/TritonGENToLLVM/Passes.h.inc"
2323

24-
void populateTritonGENToLLVMConversionPatterns(LLVMTypeConverter &converter,
25-
RewritePatternSet &patterns);
24+
namespace gpu::intel {
25+
class LibCallEmitter;
26+
} // namespace gpu::intel
27+
28+
void populateTritonGENToLLVMConversionPatterns(
29+
LLVMTypeConverter &converter, RewritePatternSet &patterns,
30+
const mlir::triton::gpu::intel::LibCallEmitter &emitter);
2631

2732
void registerConvertTritonGENToLLVMInterface(DialectRegistry &registry);
2833

0 commit comments

Comments
 (0)