Fix ValueError: numel (2097152) exceeds triton maximum tensor numel (1048576) #625
Conversation
@mengluy0125 has exported this pull request. If you are a Meta employee, you can view the originating diff in D82742888.
…1048576) (pytorch#625) Summary: We add block_size heuristics for autotune since it can have ValueError for tensors with large shape. Differential Revision: D82742888
Force-pushed from 66de86d to ad6ac1e.
IMO capping the size on a per-dimension basis isn't ideal: with multiple dimensions, you need the product of the block sizes to stay below the limit, so the cap is not dimension-independent.
Maybe we could improve the heuristics in:
helion/helion/autotuner/config_generation.py, lines 90 to 113 in 0a33998:
def shrink_config(
    self, flat_config: FlatConfig, max_elements_per_thread: int
) -> None:
    """
    Fully random configs tend to run out of resources and take a long time to compile.
    Here we shrink the config to a reasonable size.

    Args:
        flat_config: config to mutate in place
        max_elements_per_thread: maximum number of elements per thread
    """
    num_threads = warps_to_threads(cast("int", flat_config[self.num_warps_index]))
    max_elements = max_elements_per_thread * num_threads
    while self.block_numel(flat_config) > max_elements:
        changes = 0
        for i in self.block_size_indices:
            val = flat_config[i]
            assert isinstance(val, int)
            threshold = max(self.flat_spec[i].get_minimum(), self.min_block_size)
            if val // 2 >= threshold:
                flat_config[i] = val // 2
                changes += 1
        if changes == 0:
            break
to address this problem?
Oh, makes sense. Let me do the heuristics improvement inside the function.
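To make the per-dimension point above concrete, here is a small editorial sketch. The constant's value and the resulting numel come from the PR title; the per-dimension check is a hypothetical alternative, not code from this PR:

```python
# Triton's per-tensor element limit referenced in this PR (2**20).
TRITON_MAX_TENSOR_NUMEL = 1_048_576

# A hypothetical per-dimension cap: each block size passes on its own...
block_sizes = [2048, 1024]
assert all(size <= TRITON_MAX_TENSOR_NUMEL for size in block_sizes)

# ...but the emitted tile holds the product of the block sizes.
numel = 1
for size in block_sizes:
    numel *= size
print(numel)  # 2097152, which exceeds 1048576 and triggers the ValueError in the title
```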
…1048576) (pytorch#625) Summary: We add block_size heuristics for autotune since it can have ValueError for tensors with large shape. Differential Revision: D82742888
Force-pushed from ad6ac1e to 0945338.
# Respect Triton's maximum tensor element limit
triton_limit = TRITON_MAX_TENSOR_NUMEL
theoretical_max_elements = max_elements_per_thread * num_threads
max_elements = min(theoretical_max_elements, triton_limit)
So I think this is overcounting, the exact same problem I had on #485:
self.block_numel(flat_config) > 2 ** 20
does not actually mean that the Triton code we emit will end up with a tensor of numel 2 ** 20.
I agree, though maybe this is better than what we have now.
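For reference, here is a standalone editorial sketch of the combined heuristic being discussed: cap the block-size product by both the per-thread budget and Triton's limit, then halve block sizes until the product fits. The helper name, defaults, and simplified minimum-size handling are hypothetical, condensed from the quoted shrink_config and the hunk under review; as noted above, the product is a conservative overestimate of any single tensor the emitted Triton code creates.

```python
from math import prod

TRITON_MAX_TENSOR_NUMEL = 1_048_576  # 2**20, the limit from the PR title


def shrink_block_sizes(
    block_sizes: list[int],
    max_elements_per_thread: int,
    num_threads: int,
    min_block_size: int = 1,  # simplified stand-in for the per-dimension minimums
) -> list[int]:
    # New cap from the hunk under review: respect Triton's per-tensor limit
    # in addition to the existing per-thread budget.
    theoretical_max_elements = max_elements_per_thread * num_threads
    max_elements = min(theoretical_max_elements, TRITON_MAX_TENSOR_NUMEL)

    # Existing shrink_config behavior: halve every shrinkable dimension each
    # pass until the product fits, stopping if nothing can shrink further.
    block_sizes = list(block_sizes)
    while prod(block_sizes) > max_elements:
        changes = 0
        for i, val in enumerate(block_sizes):
            if val // 2 >= min_block_size:
                block_sizes[i] = val // 2
                changes += 1
        if changes == 0:
            break
    return block_sizes


# The failing case from the title: 2048 * 1024 = 2097152 exceeds 1048576,
# so both dimensions get halved once, giving 1024 * 512 = 524288.
print(shrink_block_sizes([2048, 1024], max_elements_per_thread=8192, num_threads=256))
```

Without the min() against TRITON_MAX_TENSOR_NUMEL, the same inputs would pass the per-thread budget of 2097152 elements and hit the ValueError from the title.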
Merging it to fix benchmark CI issues.
Summary: We add block_size heuristics for autotuning, since autotuning can otherwise hit a ValueError for tensors with large shapes.
Differential Revision: D82742888