
Commit e15c893: [intel] 2Dblock runtime HW checks
1 parent b526748
11 files changed: +543 -178 lines

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_F32_DEFAULT",
     "TRITON_PREFER_TMEM_16x256_LAYOUT",
     "TRITON_ENABLE_EXPERIMENTAL_CONSAN",
+    "TRITON_INTEL_2DBLOCK_ASSERT",
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
     "TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",
python/test/unit/intel/block_io_helper.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
import torch
import triton

import ctypes
import sys


def run_ir(device, temp_file):
    ir = r"""
    module attributes {
      ttg.target = "xpu",
      "ttg.num-warps" = 32 : i32,
      "ttg.num-ctas" = 1 : i32,
      "ttg.threads-per-warp" = 16 : i32
    } {
      tt.func @dyn_block(
          %iptr : i64, %base_width : i32,
          %base_height : i32, %base_pitch : i32,
          %x : i32, %y : i32) {
        %p0 = llvm.inttoptr %iptr : i64 to !llvm.ptr

        %v = triton_gen.2Dblockload %p0, %base_width, %base_height,
                 %base_pitch, %x, %y
             { elem_size_in_bits = 8, tile_width = 8, tile_height = 8,
               v_blocks = 1, transpose = false,
               vnni_transform = false, cache_control = Default }
             : (!llvm.ptr, i32, i32, i32, i32, i32)
             -> vector<2xi16>

        // prevent GluonInline
        %v_i32 = llvm.bitcast %v : vector<2xi16> to i32
        llvm.inline_asm has_side_effects asm_dialect = att
            "", "r" %v_i32 : (i32) -> ()

        tt.return
      }
    }
    """

    with open(temp_file, "w", encoding="utf-8") as f:
        f.write(ir)

    kernel = triton.compile(temp_file)

    a = torch.randn((256, 64), dtype=torch.float32, device=device)

    addr = ctypes.c_int64(a.data_ptr()).value

    kernel[(1, 1, 1)](addr, 64, 64, 1, 0, 0)


if __name__ == "__main__":
    fn = globals()[sys.argv[1]]
    fn(*sys.argv[2:])
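
Note: the helper above is meant to run in its own process (run_ir is dispatched through sys.argv), because a failing 2D-block runtime check aborts the whole process. A hedged sketch of a standalone invocation; the device string and temp-file path are placeholders chosen for illustration:

import subprocess
import sys

# Hypothetical standalone run of the helper; adjust device and paths as needed.
subprocess.run(
    [sys.executable, "python/test/unit/intel/block_io_helper.py",
     "run_ir", "xpu", "/tmp/dyn_block.ttgir"],
    check=False,  # the child is expected to die with SIGABRT if a check fires
)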

python/test/unit/intel/test_block_load.py

Lines changed: 27 additions & 0 deletions

@@ -207,3 +207,30 @@ def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
     result_tor = fn_tor()
     result_tri = fn_tri()
     torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)
+
+
+@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
+@pytest.mark.xfail(
+    not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
+         and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
+    reason="Block loads and/or DPAS not supported on this architecture", run=False)
+def test_block_load_asserts(monkeypatch, device, tmp_path: pathlib.Path):
+    monkeypatch.setenv("TRITON_INTEL_2DBLOCK_ASSERT", "1")
+
+    import os
+    import signal
+    import subprocess
+    import sys
+
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    helper_path = os.path.join(dir_path, "block_io_helper.py")
+
+    temp_file = tmp_path / "test_block_load_asserts.ttgir"
+
+    proc = subprocess.run(
+        [sys.executable, helper_path, "run_ir", device, str(temp_file)],
+        capture_output=True,
+    )
+
+    rc = proc.returncode
+    assert rc == -signal.SIGABRT or rc == (128 + signal.SIGABRT)
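
Note: the final assertion accepts both common ways an aborted child is reported: on POSIX, subprocess.run exposes a signal-terminated child as the negated signal number, while shell-style wrappers report 128 plus the signal number. A small self-contained sketch of the first convention, independent of Triton:

import signal
import subprocess
import sys

# The child deliberately aborts itself, standing in for a failed HW check.
proc = subprocess.run(
    [sys.executable, "-c",
     "import os, signal; os.kill(os.getpid(), signal.SIGABRT)"],
    capture_output=True,
)
# POSIX convention: returncode is the negated signal number (-6 for SIGABRT).
assert proc.returncode == -signal.SIGABRT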
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
// RUN: env TRITON_INTEL_2DBLOCK_ASSERT=1 triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s

module attributes {"ttg.threads-per-warp" = 16 : i32} {
  llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
    // CHECK: llvm.call spir_funccc @__assert_fail
    %0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<2xi16>
    llvm.return
  }
}

// -----

llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
  // CHECK: llvm.call spir_funccc @__assert_fail
  triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=16, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
  llvm.return
}

// -----

module attributes {"ttg.threads-per-warp" = 16 : i32} {
  llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
    // CHECK: llvm.call spir_funccc @__assert_fail
    triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
    llvm.return
  }
}

third_party/intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h

Lines changed: 7 additions & 2 deletions

@@ -21,8 +21,13 @@ namespace triton {
 #define GEN_PASS_DECL
 #include "intel/include/TritonGENToLLVM/Passes.h.inc"

-void populateTritonGENToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                               RewritePatternSet &patterns);
+namespace gpu::intel {
+class LibCallEmitter;
+} // namespace gpu::intel
+
+void populateTritonGENToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    const mlir::triton::gpu::intel::LibCallEmitter &emitter);

 void registerConvertTritonGENToLLVMInterface(DialectRegistry &registry);