Reduce block sizes to not run out of shared memory for flex + fp32 #23853
Conversation
Signed-off-by: drisspg <[email protected]>
Force-pushed from 69eb2b6 to 8c682cb
Code Review
This pull request aims to fix an out-of-shared-memory error when using FlexAttention with fp32 on GPUs with limited shared memory. The change replaces a hardcoded block size with a relative reduction, which is a good improvement. My review includes a suggestion to make this logic more robust by dynamically calculating the shared memory requirement instead of relying on a magic number. This will prevent future errors and improve performance across different models.
kernel_options["BLOCK_M"] = kernel_options["BLOCK_M"] // 2 | ||
kernel_options["BLOCK_N"] = kernel_options["BLOCK_N"] // 2 |
While halving the block size is a good improvement over the previous hardcoded value, the condition that triggers this reduction is based on a magic number (144 * 1024). This appears to be a heuristic for a worst-case scenario (e.g., head_size=512), which makes the logic brittle.

This can lead to two problems:
- Correctness: If a model with a head_size larger than anticipated is used, it could still lead to an out-of-shared-memory error.
- Performance: For models with smaller head_size, the block size might be reduced unnecessarily, leading to suboptimal performance.

A more robust approach would be to dynamically calculate the estimated shared memory requirement based on the actual head_size and dtype of the query tensor. This would make the logic more resilient and performant across different model architectures.

Here is a suggested implementation that would replace lines 790-792:
if torch.cuda.is_available():
    device_props = torch.cuda.get_device_properties(query.device)
    max_shared_memory = device_props.shared_memory_per_block_optin
    head_size = query.shape[-1]
    dtype_size = query.element_size()
    block_m = kernel_options["BLOCK_M"]
    block_n = kernel_options["BLOCK_N"]
    # Estimate shared memory for Q, K, and softmax stats (m, l).
    # This is based on common Triton flash attention implementations.
    required_smem = (block_m + block_n) * head_size * dtype_size + (block_m * block_n * 4)
    if required_smem > max_shared_memory:
        kernel_options["BLOCK_M"] //= 2
        kernel_options["BLOCK_N"] //= 2
Signed-off-by: Huy Do <[email protected]>
Please fix pre-commit
Thanks @drisspg for the fix here. This has been merged via #20358 after @youkaichao merged the PR earlier. #20358 has your change to test it out with 2.8.0. I guess we can close this PR.
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
Seeing out-of-shared-memory (smem) errors with fp32 on L4 CI.
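For context, here is a minimal sketch (not part of this PR) for checking the opt-in shared memory available per block on the current GPU, which is the budget the fp32 FlexAttention tiles must fit into; the roughly 100 KB figure mentioned for L4 is an assumption based on its sm_89 architecture, not a value taken from this PR:

import torch

# Minimal sketch: report the opt-in shared-memory-per-block limit of the
# current CUDA device. The ~100 KB figure for L4 is an assumption, not a
# value taken from this PR.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    smem_kib = props.shared_memory_per_block_optin / 1024
    print(f"{props.name}: {smem_kib:.0f} KiB shared memory per block (opt-in)")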
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
- supported_models.md and examples for a new model.