Skip to content

Commit ccc9640

Browse files
committed
Generalize tile_with_offset pass
1 parent 0a8d647 commit ccc9640

File tree

4 files changed

+150
-60
lines changed

4 files changed

+150
-60
lines changed

helion/_compiler/device_ir.py

Lines changed: 66 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import contextlib
66
import dataclasses
77
import functools
8-
import itertools
98
import operator
109
import re
1110
import textwrap
@@ -1216,55 +1215,82 @@ def add_tile_with_offset_metadata(graph_info: GraphInfo) -> None:
12161215
"""
12171216
graph = graph_info.graph
12181217
env = CompileEnvironment.current()
1219-
1220-
for node in itertools.chain(
1221-
graph.find_nodes(op="call_function", target=operator.add),
1222-
graph.find_nodes(op="call_function", target=torch.ops.aten.add.Tensor),
1223-
):
1224-
# Check if this is tile.index + offset pattern
1225-
# args[0] should be tile_index result, args[1] should be int/SymInt
1226-
if len(node.args) != 2 and not node.kwargs:
1227-
continue
1228-
left_arg, right_arg = node.args
1229-
1230-
# Check if left argument is a tile_index call
1218+
add_targets = (operator.add, torch.ops.aten.add.Tensor)
1219+
offset_types = (int, torch.SymInt)
1220+
for node in graph.nodes:
12311221
if (
1232-
not isinstance(left_arg, torch.fx.Node)
1233-
or left_arg.op != "call_function"
1234-
or left_arg.target != hl.tile_index
1222+
node.op != "call_function"
1223+
or node.target not in add_targets
1224+
or node.kwargs
1225+
or len(node.args) != 2
12351226
):
12361227
continue
12371228

1238-
# Check if right argument is an integer offset
1239-
# It could be a constant, SymInt node, or another value
1240-
# We accept int, SymInt, or nodes that represent them
1241-
offset = None
1242-
if isinstance(right_arg, (int, torch.SymInt)):
1243-
offset = right_arg
1244-
elif isinstance(right_arg, torch.fx.Node):
1245-
# Check the node's metadata for the value
1246-
val = right_arg.meta.get("val")
1247-
if isinstance(val, (int, torch.SymInt)):
1248-
offset = val
1249-
1250-
if offset is None:
1251-
continue
1229+
block_id: int | None = None
1230+
total_offset: int | torch.SymInt = 0
1231+
valid = True
12521232

1253-
# Extract the block_id from the tile_index call
1254-
tile_arg = left_arg.args[0]
1255-
block_id = None
1256-
if isinstance(tile_arg, torch.fx.Node) and isinstance(
1257-
tile_arg.meta["val"], torch.SymInt
1258-
):
1259-
block_id = env.get_block_id(tile_arg.meta["val"])
1233+
for arg in node.args:
1234+
tile_offset_value: int | torch.SymInt | None = None
1235+
arg_block_id: int | None = None
1236+
1237+
if isinstance(arg, torch.fx.Node):
1238+
meta_tile = arg.meta.get("tile_with_offset")
1239+
if meta_tile is not None:
1240+
arg_block_id = meta_tile.get("block_id")
1241+
if arg_block_id is None:
1242+
valid = False
1243+
break
1244+
tile_offset_value = meta_tile.get("offset", 0)
1245+
elif (
1246+
arg.op == "call_function"
1247+
and arg.target == hl.tile_index
1248+
and arg.args
1249+
and isinstance(arg.args[0], torch.fx.Node)
1250+
):
1251+
tile_val = arg.args[0].meta.get("val")
1252+
if isinstance(tile_val, torch.SymInt):
1253+
arg_block_id = env.get_block_id(tile_val)
1254+
if arg_block_id is None:
1255+
valid = False
1256+
break
1257+
tile_offset_value = 0
1258+
else:
1259+
val = arg.meta.get("val")
1260+
if isinstance(val, offset_types):
1261+
total_offset = total_offset + val
1262+
continue
1263+
1264+
if arg_block_id is not None:
1265+
if block_id is not None:
1266+
valid = False
1267+
break
1268+
if tile_offset_value is None:
1269+
tile_offset_value = 0
1270+
block_id = arg_block_id
1271+
total_offset = total_offset + tile_offset_value
1272+
continue
1273+
1274+
val = arg.meta.get("val")
1275+
if isinstance(val, offset_types):
1276+
total_offset = total_offset + val
1277+
continue
1278+
1279+
valid = False
1280+
break
1281+
1282+
if isinstance(arg, offset_types):
1283+
total_offset = total_offset + arg
1284+
continue
1285+
valid = False
1286+
break
12601287

1261-
if block_id is None:
1288+
if not valid or block_id is None:
12621289
continue
12631290

1264-
# Add metadata to mark this as a tile+offset node
12651291
node.meta["tile_with_offset"] = {
12661292
"block_id": block_id,
1267-
"offset": offset,
1293+
"offset": total_offset,
12681294
}
12691295

12701296

test/test_autotuner.expected

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,28 @@ This file is automatically generated by assertExpectedJournal calls in test_auto
22
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
33

44
--- assertExpectedJournal(TestAutotuner.test_config_fragment0)
5-
helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None])
6-
helion.Config(block_sizes=[32, 128, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 2], range_warp_specializes=[None, True])
7-
helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 3], range_warp_specializes=[None, False])
8-
helion.Config(block_sizes=[16, 32, 256], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[2, 3], range_warp_specializes=[True, None])
9-
helion.Config(block_sizes=[64, 32, 16], indexing='block_ptr', l2_groupings=[2], load_eviction_policies=['first', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[0, 4], range_unroll_factors=[0, 1], range_warp_specializes=[None, None])
10-
helion.Config(block_sizes=[16, 16, 16], indexing='tensor_descriptor', l2_groupings=[32], load_eviction_policies=['last', 'first'], loop_orders=[[0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None, False], range_multi_buffers=[None, None], range_num_stages=[0, 2], range_unroll_factors=[0, 2], range_warp_specializes=[None, False])
11-
helion.Config(block_sizes=[16, 32, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[1, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, False], range_num_stages=[0, 3], range_unroll_factors=[0, 3], range_warp_specializes=[None, None])
12-
helion.Config(block_sizes=[16, 32, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=8, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False, None], range_multi_buffers=[False, None], range_num_stages=[3, 3], range_unroll_factors=[2, 3], range_warp_specializes=[False, True])
13-
helion.Config(block_sizes=[256, 16, 16], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=5, num_warps=32, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, False], range_num_stages=[0, 1], range_unroll_factors=[0, 2], range_warp_specializes=[None, True])
14-
helion.Config(block_sizes=[16, 64, 16], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=3, num_warps=32, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[False, None], range_num_stages=[3, 0], range_unroll_factors=[3, 4], range_warp_specializes=[False, True])
5+
helion.Config(advanced_compiler_configuration=0, block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None])
6+
helion.Config(advanced_compiler_configuration=10, block_sizes=[32, 128, 64], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 2], range_warp_specializes=[None, True])
7+
helion.Config(advanced_compiler_configuration=5, block_sizes=[256, 16, 256], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=32, pid_type='flat', range_flattens=[None, False], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 3], range_warp_specializes=[None, None])
8+
helion.Config(advanced_compiler_configuration=3, block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[64], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0]], num_stages=5, num_warps=2, pid_type='persistent_interleaved', range_flattens=[True, None], range_multi_buffers=[True, True], range_num_stages=[3, 1], range_unroll_factors=[4, 1], range_warp_specializes=[False, True])
9+
helion.Config(advanced_compiler_configuration=0, block_sizes=[64, 16, 128], indexing='block_ptr', l2_groupings=[4], load_eviction_policies=['', 'last'], loop_orders=[[0, 1]], num_stages=3, num_warps=16, pid_type='flat', range_flattens=[None, False], range_multi_buffers=[None, False], range_num_stages=[0, 2], range_unroll_factors=[0, 0], range_warp_specializes=[None, None])
10+
helion.Config(advanced_compiler_configuration=9, block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[4], load_eviction_policies=['first', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=8, pid_type='persistent_blocked', range_flattens=[True, False], range_multi_buffers=[True, None], range_num_stages=[0, 0], range_unroll_factors=[4, 1], range_warp_specializes=[True, False])
11+
helion.Config(advanced_compiler_configuration=5, block_sizes=[16, 16, 16], indexing='block_ptr', l2_groupings=[1], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_interleaved', range_flattens=[None, True], range_multi_buffers=[False, False], range_num_stages=[0, 4], range_unroll_factors=[3, 4], range_warp_specializes=[False, False])
12+
helion.Config(advanced_compiler_configuration=3, block_sizes=[64, 16, 16], indexing='block_ptr', l2_groupings=[32], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=5, num_warps=4, pid_type='persistent_interleaved', range_flattens=[None, True], range_multi_buffers=[True, False], range_num_stages=[1, 4], range_unroll_factors=[3, 3], range_warp_specializes=[None, False])
13+
helion.Config(advanced_compiler_configuration=7, block_sizes=[32, 64, 64], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['last', ''], loop_orders=[[0, 1]], num_stages=6, num_warps=2, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 2], range_warp_specializes=[None, None])
14+
helion.Config(advanced_compiler_configuration=2, block_sizes=[16, 16, 16], indexing='block_ptr', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[1, 0]], num_stages=4, num_warps=1, pid_type='persistent_interleaved', range_flattens=[None, False], range_multi_buffers=[True, True], range_num_stages=[1, 4], range_unroll_factors=[3, 1], range_warp_specializes=[True, False])
1515

1616
--- assertExpectedJournal(TestAutotuner.test_config_fragment1)
17-
helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
18-
helion.Config(block_sizes=[1, 64, 64], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True])
19-
helion.Config(block_sizes=[2, 8, 512], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['first', 'first'], loop_orders=[[2, 0, 1]], num_stages=2, num_warps=1, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
20-
helion.Config(block_sizes=[1, 512, 1], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[1], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[2], range_warp_specializes=[True])
21-
helion.Config(block_sizes=[1, 4, 256], flatten_loops=[True], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'last'], loop_orders=[[1, 0, 2]], num_stages=2, num_warps=32, pid_type='persistent_interleaved', range_flattens=[None], range_multi_buffers=[True], range_unroll_factors=[1], range_warp_specializes=[True])
22-
helion.Config(block_sizes=[1, 128, 16], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1, 2]], num_stages=1, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[4], range_warp_specializes=[None])
23-
helion.Config(block_sizes=[8, 32, 256], flatten_loops=[False], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', 'last'], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=8, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[4], range_warp_specializes=[None])
24-
helion.Config(block_sizes=[2, 64, 32], flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', 'first'], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=16, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
25-
helion.Config(block_sizes=[4, 32, 1], flatten_loops=[True], indexing='pointer', l2_groupings=[8], load_eviction_policies=['', 'last'], loop_orders=[[2, 1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[False], range_unroll_factors=[3], range_warp_specializes=[True])
26-
helion.Config(block_sizes=[4, 2, 128], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['', 'first'], loop_orders=[[1, 2, 0]], num_stages=2, num_warps=4, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[1], range_warp_specializes=[False])
17+
helion.Config(advanced_compiler_configuration=0, block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
18+
helion.Config(advanced_compiler_configuration=6, block_sizes=[1, 32, 32], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True])
19+
helion.Config(advanced_compiler_configuration=10, block_sizes=[1, 32, 1], flatten_loops=[True], indexing='block_ptr', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[2, 1, 0]], num_stages=4, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[None], range_unroll_factors=[3], range_warp_specializes=[False])
20+
helion.Config(advanced_compiler_configuration=2, block_sizes=[1, 8, 16], flatten_loops=[True], indexing='pointer', l2_groupings=[32], load_eviction_policies=['last', 'last'], loop_orders=[[2, 1, 0]], num_stages=7, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[3], range_warp_specializes=[True])
21+
helion.Config(advanced_compiler_configuration=10, block_sizes=[1, 1, 64], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['last', ''], loop_orders=[[2, 0, 1]], num_stages=6, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[None], range_unroll_factors=[4], range_warp_specializes=[True])
22+
helion.Config(advanced_compiler_configuration=6, block_sizes=[4, 2, 128], flatten_loops=[True], indexing='block_ptr', l2_groupings=[1], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1, 2]], num_stages=6, num_warps=1, pid_type='persistent_blocked', range_flattens=[True], range_multi_buffers=[True], range_unroll_factors=[0], range_warp_specializes=[True])
23+
helion.Config(advanced_compiler_configuration=3, block_sizes=[2, 16, 2], flatten_loops=[True], indexing='block_ptr', l2_groupings=[64], load_eviction_policies=['first', 'first'], loop_orders=[[0, 2, 1]], num_stages=4, num_warps=16, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[None])
24+
helion.Config(advanced_compiler_configuration=5, block_sizes=[4, 128, 16], flatten_loops=[True], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0, 2]], num_stages=6, num_warps=4, pid_type='persistent_interleaved', range_flattens=[False], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True])
25+
helion.Config(advanced_compiler_configuration=4, block_sizes=[4, 256, 32], flatten_loops=[False], indexing='block_ptr', l2_groupings=[8], load_eviction_policies=['last', ''], loop_orders=[[2, 1, 0]], num_stages=5, num_warps=2, pid_type='persistent_blocked', range_flattens=[False], range_multi_buffers=[True], range_unroll_factors=[1], range_warp_specializes=[True])
26+
helion.Config(advanced_compiler_configuration=9, block_sizes=[2, 128, 8], flatten_loops=[True], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', ''], loop_orders=[[1, 2, 0]], num_stages=5, num_warps=32, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
2727

2828
--- assertExpectedJournal(TestAutotuner.test_save_load_config)
2929
{

0 commit comments

Comments
 (0)