Conversation

mengluy0125 (Contributor)

Summary: We add block_size heuristics for autotune, since it can hit a ValueError for tensors with large shapes.

Differential Revision: D82742888

@meta-cla bot added the CLA Signed label on Sep 18, 2025
@facebook-github-bot

@mengluy0125 has exported this pull request. If you are a Meta employee, you can view the originating diff in D82742888.

mengluy0125 added a commit to mengluy0125/helion that referenced this pull request Sep 18, 2025
…1048576) (pytorch#625)

Summary:

We add block_size heuristics for autotune, since it can hit a ValueError for tensors with large shapes.

Differential Revision: D82742888
@facebook-github-bot

@mengluy0125 has exported this pull request. If you are a Meta employee, you can view the originating diff in D82742888.

@jansel (Contributor) left a comment

IMO capping the size on a per-dimension basis isn't ideal: if you have multiple dimensions, you need the product of the block sizes to stay below the limit, so the cap is not independent per dimension.

Maybe we could improve the heuristics in:

def shrink_config(
    self, flat_config: FlatConfig, max_elements_per_thread: int
) -> None:
    """
    Fully random configs tend to run out of resources and take a long time to compile.
    Here we shrink the config to a reasonable size.

    Args:
        flat_config: config to mutate in place
        max_elements_per_thread: maximum number of elements per thread
    """
    num_threads = warps_to_threads(cast("int", flat_config[self.num_warps_index]))
    max_elements = max_elements_per_thread * num_threads
    while self.block_numel(flat_config) > max_elements:
        changes = 0
        for i in self.block_size_indices:
            val = flat_config[i]
            assert isinstance(val, int)
            threshold = max(self.flat_spec[i].get_minimum(), self.min_block_size)
            if val // 2 >= threshold:
                flat_config[i] = val // 2
                changes += 1
        if changes == 0:
            break

to address this problem?
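
For illustration, a minimal, self-contained sketch of the concern (the per-dimension cap and block sizes below are hypothetical; only the 1048576 limit comes from the commit title): each dimension can sit under its own cap while the product of the block sizes, which is what block_numel measures, still exceeds Triton's tensor limit.

# Hypothetical numbers: each block size respects a per-dimension cap,
# but their product still exceeds the 1048576-element Triton limit.
TRITON_MAX_TENSOR_NUMEL = 1048576  # value per the commit title
PER_DIM_CAP = 65536                # made-up per-dimension cap

block_sizes = [2048, 1024]
assert all(b <= PER_DIM_CAP for b in block_sizes)

block_numel = 1
for b in block_sizes:
    block_numel *= b

print(block_numel)                             # 2097152
print(block_numel <= TRITON_MAX_TENSOR_NUMEL)  # False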

@mengluy0125 (Contributor, Author)

Oh, that makes sense. Let me do the heuristics improvement inside that function.
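
As a toy, self-contained sketch of that direction (the numbers and the standalone helper below are made up; only the halving-loop shape and the 1048576 cap come from this thread): shrink the block sizes until their product fits under min(per-thread budget, Triton limit).

TRITON_MAX_TENSOR_NUMEL = 1048576  # per the commit title
MIN_BLOCK_SIZE = 16                # hypothetical floor, standing in for get_minimum()

def shrink(block_sizes, max_elements_per_thread, num_threads):
    # Cap the element budget by both the per-thread heuristic and Triton's hard limit.
    max_elements = min(max_elements_per_thread * num_threads, TRITON_MAX_TENSOR_NUMEL)
    sizes = list(block_sizes)

    def numel():
        n = 1
        for s in sizes:
            n *= s
        return n

    # Same shape as shrink_config's loop: halve every dimension that can still shrink.
    while numel() > max_elements:
        changed = False
        for i, s in enumerate(sizes):
            if s // 2 >= MIN_BLOCK_SIZE:
                sizes[i] = s // 2
                changed = True
        if not changed:
            break
    return sizes

print(shrink([4096, 1024], max_elements_per_thread=8192, num_threads=256))  # [2048, 512]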

…1048576) (pytorch#625)

Summary:

We add block_size heuristics for autotune, since it can hit a ValueError for tensors with large shapes.

Differential Revision: D82742888
@facebook-github-bot

@mengluy0125 has exported this pull request. If you are a Meta employee, you can view the originating diff in D82742888.

Comment on lines +105 to +108
# Respect Triton's maximum tensor element limit
triton_limit = TRITON_MAX_TENSOR_NUMEL
theoretical_max_elements = max_elements_per_thread * num_threads
max_elements = min(theoretical_max_elements, triton_limit)
Contributor

So I think this is overcounting, the exact same problem I had on #485:

self.block_numel(flat_config) > 2 ** 20

doesn't actually mean that the Triton code we emit will end up with a tensor of numel 2 ** 20.
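
One hypothetical way to picture the overcounting, using a generic matmul-style tiling rather than anything Helion-specific: the config's block_numel is the product of all block sizes, but the emitted kernel may only ever materialize 2-D tiles.

# Made-up matmul-style tiling: three block sizes in the config,
# but each tensor the kernel materializes is only a 2-D tile.
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64

config_block_numel = BLOCK_M * BLOCK_N * BLOCK_K  # 1048576, what the shrink heuristic sees
largest_emitted_tile = max(
    BLOCK_M * BLOCK_K,  # tile of the first operand
    BLOCK_K * BLOCK_N,  # tile of the second operand
    BLOCK_M * BLOCK_N,  # accumulator tile
)

print(config_block_numel)    # 1048576
print(largest_emitted_tile)  # 16384, far below the limit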

Contributor

I agree, though maybe this is better than what we have now.

@yf225 (Contributor) commented Sep 19, 2025

Merging it to fix benchmark CI issues.

@yf225 merged commit 3d5b641 into pytorch:main on Sep 19, 2025
15 checks passed