Skip to content

Commit d34a1f3

Browse files
authored
Revert "Revert "[frontend] Remove Complex Regex for MLIR Parsing (#4924)"" (#2681)
Closes #2653 This reverts commit ecc9bd4. The issue looks like: ```bash loc("/tmp/pytest-of-runner/pytest-0/popen-gw0/test_convert2d_dst_layout8_int21/test_convert2d.ttgir":4:30): error: #"triton_intel_gpu"<"dpas<{repeatCount=8, systolicDepth=8, executionSize = 8, opsPerChan = 1, threadsPerWarp = 32, warpsPerCTA=[4, 1], repCluster=[1, 1]}>"> : 'none' attribute created with unregistered dialect. If this is intended, please call allowUnregisteredDialects() on the MLIRContext, or use -allow-unregistered-dialect with the MLIR opt tool used ``` --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent fd5b82a commit d34a1f3

File tree

4 files changed

+182
-32
lines changed

4 files changed

+182
-32
lines changed

python/src/ir.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2525
#include "mlir/Transforms/LocationSnapshot.h"
2626

27+
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
2728
#include "triton/Dialect/Triton/IR/Dialect.h"
2829
#include "triton/Dialect/Triton/IR/Types.h"
2930
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -502,6 +503,16 @@ void init_triton_ir(py::module &&m) {
502503
[](ModuleOp &self, FuncOp &funcOp) -> void {
503504
self.push_back(funcOp);
504505
})
506+
.def("get_entry_func_name",
507+
[](ModuleOp &self) -> std::string {
508+
for (auto &op : self.getOps()) {
509+
if (auto func = dyn_cast<FuncOp>(op)) {
510+
if (LLVM::isKernel(func))
511+
return func.getName().str();
512+
}
513+
}
514+
return "";
515+
})
505516
.def("has_function",
506517
[](ModuleOp &self, std::string &funcName) -> bool {
507518
if (self.lookupSymbol(funcName))
@@ -512,6 +523,43 @@ void init_triton_ir(py::module &&m) {
512523
[](ModuleOp &self, std::string &funcName) -> FuncOp {
513524
return self.lookupSymbol<FuncOp>(funcName);
514525
})
526+
/*
527+
* def ty_to_cpp(ty) is the consumer of this function.
528+
* If the type is a ptr it expects ty[0] == '*', else the type itself.
529+
*/
530+
531+
.def("get_function_signature",
532+
[](ModuleOp &self, FuncOp &func) -> std::vector<std::string> {
533+
std::vector<std::string> strVec;
534+
535+
auto type = func.getFunctionType();
536+
unsigned numArgs = type.getNumInputs();
537+
for (unsigned i = 0; i != numArgs; ++i) {
538+
std::string tempType;
539+
llvm::raw_string_ostream os(tempType);
540+
541+
auto ty = type.getInput(i);
542+
if (auto attributes = func.getCallableArgAttrs()) {
543+
Attribute attr = attributes[i];
544+
// Check for tt.nv_tma_desc = 1
545+
if (auto dAttr = dyn_cast<DictionaryAttr>(attr)) {
546+
if (dAttr.contains("tt.nv_tma_desc")) {
547+
strVec.push_back("nvTmaDesc");
548+
continue;
549+
}
550+
}
551+
}
552+
if (auto ptrType = dyn_cast<PointerType>(ty)) {
553+
auto pType = ptrType.getPointeeType();
554+
os << "*";
555+
pType.print(os);
556+
} else {
557+
ty.print(os);
558+
}
559+
strVec.push_back(tempType);
560+
}
561+
return strVec;
562+
})
515563
.def("get_int_attr",
516564
[](ModuleOp &self, std::string name) -> py::object {
517565
auto ret = self->getAttrOfType<IntegerAttr>(name);
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
import tempfile
import triton
from triton.compiler import IRSource, make_backend
from triton._C.libtriton import ir

# Module-level setup: resolve the active target once and build its backend,
# shared by every test in this file.
target = triton.runtime.driver.active.get_current_target()
backend = make_backend(target)


def test_mlir_attribute_parsing() -> None:
    '''
    Verify that kernel metadata is recovered from ttir/ttgir input via the
    MLIR parser (rather than regexes).

    Covered:
    1. Entry-function name and argument type signature are extracted correctly.
    2. num_warps is read from the module attributes by parse_options().
    3. Arguments carrying tt.nv_tma_desc are reported as "nvTmaDesc".
    '''

    sample_ttgir = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 2], order = [1, 0]}>
#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 2, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @matmul_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                                %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                                %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                                %arg3: i32 {tt.divisibility = 16 : i32},
                                %arg4: i32 {tt.divisibility = 16 : i32},
                                %arg5: i32 {tt.divisibility = 16 : i32},
                                %arg6: i32 {tt.divisibility = 16 : i32},
                                %arg7: i32 {tt.divisibility = 16 : i32},
                                %arg8: i32 {tt.divisibility = 16 : i32, tt.nv_tma_desc = 0 : i32},
                                %desc: !tt.ptr<i8, 0> {tt.nv_tma_desc = 1 : i32}) attributes {noinline = false} {
    tt.return
  }
}
"""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as tmp:
        tmp.write(sample_ttgir)
        tmp.flush()
        ctx = ir.context()
        ir_src = IRSource(tmp.name, ctx, backend)

        # Name and type signature must match what ty_to_cpp(...) consumes.
        expected_signature = {
            0: "*f32", 1: "*f32", 2: "*f32", 3: "i32",
            4: "i32", 5: "i32", 6: "i32", 7: "i32",
            8: "nvTmaDesc", 9: "nvTmaDesc",
        }
        assert ir_src.signature == expected_signature
        assert ir_src.name == "@matmul_kernel"

        # num_warps comes from the module's "triton_gpu.num-warps" attribute.
        assert ir_src.parse_options()['num_warps'] == 8

    sample_ttgir_vector_add = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @add_kernel(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32},
                             %arg1: !tt.ptr<i32> {tt.divisibility = 16 : i32},
                             %arg2: !tt.ptr<i32> {tt.divisibility = 16 : i32},
                             %arg3: i32 {tt.divisibility = 16 : i32})
                             attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
    %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked>
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked>
    %7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
    %10 = tt.splat %arg1 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
    %13 = arith.addi %9, %12 : tensor<1024xi32, #blocked>
    %14 = tt.splat %arg2 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<i32>, #blocked>
    tt.return
  }
}
"""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as tmp:
        tmp.write(sample_ttgir_vector_add)
        tmp.flush()
        ctx = ir.context()
        ir_src = IRSource(tmp.name, ctx, backend)

        # End-to-end: the parsed IR must also compile.
        triton.compile(tmp.name, target=target)

python/triton/compiler/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from .compiler import CompiledKernel, ASTSource, compile, make_backend, LazyDict
1+
from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict
22
from .errors import CompilationError
33

4-
__all__ = ["compile", "make_backend", "ASTSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", "LazyDict"]
4+
__all__ = [
5+
"compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError",
6+
"LazyDict"
7+
]

python/triton/compiler/compiler.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,13 @@
2525
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
2626
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
2727
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
28-
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
2928
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
3029
prototype_pattern = {
31-
"ttir": mlir_prototype_pattern,
32-
"ttgir": mlir_prototype_pattern,
3330
"ptx": ptx_prototype_pattern,
3431
}
3532

36-
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<)]+|<[^>]+>)+(?: {[^}]+})?),?'
3733
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
3834
arg_type_pattern = {
39-
"ttir": mlir_arg_type_pattern,
40-
"ttgir": mlir_arg_type_pattern,
4135
"ptx": ptx_arg_type_pattern,
4236
}
4337

@@ -55,16 +49,6 @@ def convert_type_repr(x):
5549
return x
5650

5751

58-
def _get_num_warps_from_ir_str(src: str):
59-
ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:'
60-
# TODO(jlebar): Using a regex to get num-warps is a hack, and will break if
61-
# e.g. someone has an instruction (not module) attribute named "num-warps".
62-
num_warps_matches = re.findall(ttgir_num_warps_pattern, src)
63-
assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps"
64-
num_warps = int(num_warps_matches[0])
65-
return num_warps
66-
67-
6852
class ASTSource:
6953

7054
def __init__(self, fn, signature, constants=None, attrs=None) -> None:
@@ -107,28 +91,42 @@ def parse_options(self):
10791

10892
class IRSource:
10993

110-
def __init__(self, path):
94+
def __init__(self, path, context, backend):
11195
self.path = path
11296
path = Path(path)
11397
self.ext = path.suffix[1:]
11498
self.src = path.read_text()
115-
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
116-
self.name = match.group(1)
117-
signature = match.group(2)
118-
types = re.findall(arg_type_pattern[self.ext], signature)
119-
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
99+
ir.load_dialects(context)
100+
backend.load_dialects(context)
101+
102+
# We don't have an easy-to-use PTX parser, so keep the regex for PTX for now.
103+
# TODO - replace with a proper parser
104+
if self.ext == "ptx":
105+
match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
106+
self.name = match.group(1)
107+
signature = match.group(2)
108+
types = re.findall(arg_type_pattern[self.ext], signature)
109+
self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
110+
else:
111+
self.module = ir.parse_mlir_module(self.path, context)
112+
fn_name = self.module.get_entry_func_name()
113+
self.name = "@" + fn_name
114+
funcOp = self.module.get_function(fn_name)
115+
func_ty = self.module.get_function_signature(funcOp)
116+
self.signature = {k: ty for k, ty in enumerate(func_ty)}
120117

121118
def hash(self):
122119
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()
123120

124121
def make_ir(self, options, codegen_fns, module_map, context):
125-
module = ir.parse_mlir_module(self.path, context)
126-
module.context = context
127-
return module
122+
self.module.context = context
123+
return self.module
128124

129125
def parse_options(self):
130126
if self.ext == "ttgir":
131-
return {'num_warps': _get_num_warps_from_ir_str(self.src)}
127+
num_warps = self.module.get_int_attr("triton_gpu.num-warps")
128+
assert num_warps is not None, "Unable to parse triton_gpu.num-warps attribute"
129+
return {'num_warps': num_warps}
132130
return dict()
133131

134132

@@ -225,7 +223,9 @@ def compile(src, target=None, options=None):
225223
# create backend
226224
if ir_source:
227225
assert isinstance(src, str), "source must be either AST or a filepath"
228-
src = IRSource(src)
226+
context = ir.context()
227+
src = IRSource(src, context, backend)
228+
229229
extra_options = src.parse_options()
230230
options = backend.parse_options(dict(options or dict(), **extra_options))
231231
# create cache manager
@@ -266,9 +266,14 @@ def compile(src, target=None, options=None):
266266
# when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
267267
if ir_source:
268268
first_stage += 1
269-
context = ir.context()
270-
ir.load_dialects(context)
271-
backend.load_dialects(context)
269+
270+
# For IRSource, we have already grabbed the context + called both
271+
# ir.load_dialects and backend.load_dialects.
272+
if not isinstance(src, IRSource):
273+
context = ir.context()
274+
ir.load_dialects(context)
275+
backend.load_dialects(context)
276+
272277
codegen_fns = backend.get_codegen_implementation()
273278
module_map = backend.get_module_map()
274279
try:

0 commit comments

Comments
 (0)