Skip to content

Commit 46f6aa6

Browse files
authored
Prioritize outermost loop for warp spec (#1000)
1 parent 63e022f commit 46f6aa6

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

helion/autotuner/config_spec.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -245,20 +245,24 @@ def normalize(self, config: helion.Config | dict[str, object]) -> None:
245245
name, config.get(name, ()), block_ids=self.grid_block_ids
246246
)
247247

248-
# Only one range_warp_specializes is allowed, take the last one
249248
range_warp_specializes = cast(
250249
"list[bool | None]", config.get("range_warp_specializes", [])
251250
)
252251

253252
if range_warp_specializes and any(range_warp_specializes):
254-
for i in [j for j, val in enumerate(range_warp_specializes) if val][:-1]:
253+
# Only one range_warp_specializes is allowed, take the first one
254+
# Prefer warp specialize on outermost loop
255+
first_idx = range_warp_specializes.index(True)
256+
for i in range(first_idx + 1, len(range_warp_specializes)):
255257
range_warp_specializes[i] = None
256258

257259
range_unroll_factors = cast(
258260
"list[int]", config.get("range_unroll_factors", [])
259261
)
260-
if range_unroll_factors and range_unroll_factors[-1]:
261-
range_unroll_factors[-1] = 0
262+
if range_unroll_factors and range_unroll_factors[first_idx] > 1:
263+
if range_unroll_factors[first_idx]:
264+
range_unroll_factors[first_idx] = 0
265+
262266
config["range_unroll_factors"] = range_unroll_factors
263267

264268
config["range_warp_specializes"] = range_warp_specializes

test/test_autotuner.expected

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,8 @@ Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environmen
44
--- assertExpectedJournal(TestAutotuner.test_config_fragment0)
55
helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 0], range_unroll_factors=[0, 0], range_warp_specializes=[None, None])
66
helion.Config(block_sizes=[32, 128, 64], indexing='tensor_descriptor', l2_groupings=[8], load_eviction_policies=['', ''], loop_orders=[[1, 0]], num_stages=8, num_warps=8, pid_type='persistent_blocked', range_flattens=[None, True], range_multi_buffers=[False, True], range_num_stages=[3, 0], range_unroll_factors=[1, 0], range_warp_specializes=[None, True])
7-
helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[16], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[True, None], range_multi_buffers=[None, None], range_num_stages=[2, 0], range_unroll_factors=[2, 0], range_warp_specializes=[True, False])
8-
helion.Config(block_sizes=[16, 128, 64], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[2, 0], range_warp_specializes=[True, None])
7+
helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[16], load_eviction_policies=['', ''], loop_orders=[[0, 1]], num_stages=7, num_warps=4, pid_type='persistent_interleaved', range_flattens=[True, None], range_multi_buffers=[None, None], range_num_stages=[2, 0], range_unroll_factors=[0, 3], range_warp_specializes=[True, None])
8+
helion.Config(block_sizes=[16, 128, 64], indexing='pointer', l2_groupings=[64], load_eviction_policies=['first', ''], loop_orders=[[1, 0]], num_stages=2, num_warps=16, pid_type='persistent_interleaved', range_flattens=[True, True], range_multi_buffers=[False, None], range_num_stages=[2, 4], range_unroll_factors=[0, 3], range_warp_specializes=[True, None])
99
helion.Config(block_sizes=[64, 32, 16], indexing='tensor_descriptor', l2_groupings=[2], load_eviction_policies=['first', 'last'], loop_orders=[[1, 0]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None, True], range_multi_buffers=[None, True], range_num_stages=[0, 4], range_unroll_factors=[0, 1], range_warp_specializes=[None, None])
1010
helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[32], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=2, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True, False], range_multi_buffers=[True, None], range_num_stages=[3, 2], range_unroll_factors=[2, 2], range_warp_specializes=[False, False])
1111
helion.Config(block_sizes=[16, 16, 16], indexing='pointer', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[0, 1]], num_stages=5, num_warps=4, pid_type='persistent_interleaved', range_flattens=[None, True], range_multi_buffers=[False, False], range_num_stages=[3, 4], range_unroll_factors=[3, 2], range_warp_specializes=[None, None])
@@ -15,7 +15,7 @@ helion.Config(block_sizes=[16, 128, 16], indexing='pointer', l2_groupings=[8], l
1515

1616
--- assertExpectedJournal(TestAutotuner.test_config_fragment1)
1717
helion.Config(block_sizes=[8, 16, 16], flatten_loops=[False], indexing='pointer', l2_groupings=[1], load_eviction_policies=['', ''], loop_orders=[[0, 1, 2]], num_stages=2, num_warps=4, pid_type='flat', range_flattens=[None], range_multi_buffers=[None], range_num_stages=[], range_unroll_factors=[0], range_warp_specializes=[None])
18-
helion.Config(block_sizes=[1, 32, 32], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[0], range_warp_specializes=[True])
18+
helion.Config(block_sizes=[1, 32, 32], flatten_loops=[False], indexing='tensor_descriptor', l2_groupings=[4], load_eviction_policies=['first', 'first'], loop_orders=[[1, 2, 0]], num_stages=4, num_warps=8, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[False], range_unroll_factors=[1], range_warp_specializes=[True])
1919
helion.Config(block_sizes=[2, 512, 4], flatten_loops=[True], indexing='tensor_descriptor', l2_groupings=[16], load_eviction_policies=['last', ''], loop_orders=[[2, 1, 0]], num_stages=4, num_warps=1, pid_type='persistent_blocked', range_flattens=[None], range_multi_buffers=[None], range_unroll_factors=[3], range_warp_specializes=[False])
2020
helion.Config(block_sizes=[1, 2, 8], flatten_loops=[True], indexing='pointer', l2_groupings=[32], load_eviction_policies=['last', 'last'], loop_orders=[[1, 2, 0]], num_stages=7, num_warps=16, pid_type='persistent_interleaved', range_flattens=[False], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True])
2121
helion.Config(block_sizes=[1, 128, 4], flatten_loops=[True], indexing='pointer', l2_groupings=[2], load_eviction_policies=['', 'last'], loop_orders=[[0, 2, 1]], num_stages=6, num_warps=1, pid_type='persistent_interleaved', range_flattens=[True], range_multi_buffers=[None], range_unroll_factors=[0], range_warp_specializes=[True])

0 commit comments

Comments
 (0)