Skip to content

Commit c6ad34f

Browse files
authored
Use correct types for constexpr arguments (#8248)
A previous commit improved str_to_ty to support constexprs, but its use in ast_to_ttir did not correctly handle nested tuples: fixing this properly requires recursion. Instead of invoking str_to_ty, it's easier to just fix up the types ourselves. Fortunately, we have the arguments passed into the kernel and can use this to go in and correct the parameters before code generation sees it.
1 parent c733bf7 commit c6ad34f

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

python/test/unit/language/test_tuple.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,19 @@ def m_to_the_n(X, shape: tl.constexpr, strides, m_n):
256256
torch.testing.assert_close(x, expected_x, rtol=0, atol=0)
257257

258258

259+
def test_passing_nested_tuple_with_constexpr(device):
260+
261+
@triton.jit
262+
def test(x):
263+
# This creates a new scope, which will force a copy of liveins. It's
264+
# important for this to happen as it forces IR flattening/unflattening,
265+
# which relies on the types being correct for the roundtrip to succeed.
266+
for _ in range(1):
267+
tl.static_assert(x[1][0] == 2)
268+
269+
test[(1, )](((1, ), (tl.constexpr(2), )))
270+
271+
259272
def test_passing_tuple_to_make_tensor_descriptor(device, with_allocator):
260273

261274
@triton.jit

python/triton/compiler/code_generator.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,16 +1589,24 @@ def ret(self, node: ast.Call):
15891589

15901590
def ast_to_ttir(fn, src, context, options, codegen_fns, module_map, module=None):
15911591
arg_types = [None] * len(fn.arg_names)
1592-
const_iter = iter(src.constants.items())
1593-
kc, vc = next(const_iter, (None, None))
1594-
1595-
for i, (ks, v) in enumerate(src.signature.items()):
1596-
idx = fn.arg_names.index(ks)
1597-
cexpr = None
1598-
if kc is not None and kc[0] == i:
1599-
cexpr = vc
1600-
kc, vc = next(const_iter, (None, None))
1601-
arg_types[idx] = str_to_ty(v, cexpr)
1592+
1593+
for k, v in src.signature.items():
1594+
idx = fn.arg_names.index(k)
1595+
arg_types[idx] = str_to_ty(v, None)
1596+
1597+
def apply_constexpr_types(argument, indices, value):
1598+
index = indices.pop()
1599+
if len(indices) == 0:
1600+
if isinstance(argument, list):
1601+
argument[index] = constexpr(value).type
1602+
else:
1603+
argument.types[index] = constexpr(value).type
1604+
else:
1605+
apply_constexpr_types(argument[index], indices, value)
1606+
1607+
for path, value in src.constants.items():
1608+
apply_constexpr_types(arg_types, list(path)[::-1], value)
1609+
16021610
prototype = ASTFunction([], arg_types, src.constants, src.attrs)
16031611
file_name, begin_line = get_jit_fn_file_line(fn)
16041612
# query function representation

0 commit comments

Comments
 (0)