[AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile #21350

Conversation
Signed-off-by: Randall Smith <[email protected]>
Code Review

This pull request addresses an issue with torch.compile on ROCm platforms where a specific kernel call (wvSplitK) was being omitted. The fix involves refactoring rocm_unquantized_gemm into a custom PyTorch operator, which is a sound approach to working around compiler issues. My review identified a critical issue in the 'fake' implementation of this new custom operator, which could lead to incorrect shape inference for input tensors with more than two dimensions. I've provided a suggestion to fix this.
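For reference, a minimal sketch of a 'fake' (meta) implementation that keeps shape inference correct for inputs with any number of leading dimensions; the name and signature here mirror the PR's helper but are illustrative, not the merged code:

from typing import Optional
import torch

def rocm_unquantized_gemm_impl_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Preserve every leading (batch) dim of x; only the last dim changes
    # from in_features to out_features (weight is [out_features, in_features]).
    return x.new_empty((*x.shape[:-1], weight.shape[0]))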
Signed-off-by: Randall Smith <[email protected]>
How does this bug manifest on main? It sounds like we just get a wrong answer when the dispatching logic decides to run skinny_gemm in V1? Do you have a minimal reproducible example?
vllm/model_executor/layers/utils.py (outdated)

def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None):
Nit: Let's add the return type here.
Done
Signed-off-by: Randall Smith <[email protected]>
I'm not sure what the root cause is, or whether it relates to how torch.compile or vLLM's internal compilation logic is implemented, but you can see the bug by running almost any model; I used Llama-3.1-8B-Instruct for testing. Then just run the profiler and see how many times the function is called.
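A minimal profiling sketch along these lines shows the kernel call counts; run_decode_steps() is a hypothetical stand-in for driving generation at batch sizes 1-4 with the model under test:

from torch.profiler import profile, ProfilerActivity

# Count wvSplitK launches in the kernel summary; on ROCm, GPU (HIP)
# kernels are reported under the CUDA activity type.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    run_decode_steps()  # hypothetical: generate tokens at batch sizes 1-4
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))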
Ok, so if I'm understanding correctly, the problem is that even when all of the conditions for skinny gemm are true, the compiled graph never actually calls it. Before I accept, can you just quickly verify that this issue persists even after you clear out both torch.compile caches?
I think this is a reasonable fix for now, but it would be good to understand what's going on here.
@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)


def rocm_unquantized_gemm_impl_fake(
I mentioned this offline, but it is generally better to put only the minimal logic needed into the custom op. In this situation I think that's the following (plus some dependencies):
if m > 8 and 0 < n <= 4:
    out = ops.wvSplitK(weight, x_view, cu_count)
    return out.view(*x.shape[:-1], weight.shape[0])
elif m % 4 == 0 and n == 1 and k <= 8192:
    out = ops.LLMM1(weight, x_view, 4)
    return out.view(*x.shape[:-1], weight.shape[0])
else:
    return torch.nn.functional.linear(x, weight, bias)
The reason is that torch.compile may be able to optimize the nn.Linear calls. For example, it can select different matmul kernels, or fuse nearby operations into them (if there are fusable operations nearby).
That being said, it's not clear to me how much torch.compile is able to do for matmuls on ROCm, so feel free to ship this as-is.
To answer the overall question, @SageMoore: the way vLLM uses torch.compile is to capture one single graph that works for all batch sizes. If there are branches on the shape in the graph, then graph capture will pick one of the branches (the one taken for the batch size vLLM uses to perform the initial graph capture), and the resulting graph will only be correct under that branch's condition. In this case the condition is batch_size > 4. The workarounds generally involve keeping shape-dependent branching out of the captured graph.
In this case the branching is for kernel selection, so hiding the logic in a custom operator seems reasonable (see the sketch below).
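To illustrate the workaround, here is a minimal, hypothetical sketch using torch.library.custom_op (PyTorch 2.4+); the PR itself uses vLLM's direct_register_custom_op helper instead, and the op name and bodies below are assumptions. Because the dispatch lives inside the op, graph capture records a single opaque call and the branch is evaluated on every invocation:

import torch

@torch.library.custom_op("mylib::dispatching_gemm", mutates_args=())
def dispatching_gemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # The branch runs at call time; torch.compile sees one opaque op,
    # so neither side gets baked into the captured graph.
    if x.shape[0] <= 4:
        # small-batch path; the real code would call the wvSplitK kernel here
        return torch.matmul(x, weight.t())
    return torch.nn.functional.linear(x, weight)

@dispatching_gemm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Shape-only implementation; must be valid for every runtime branch.
    return x.new_empty((*x.shape[:-1], weight.shape[0]))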
It turns out that torch.compile was omitting calls to wvSplitK, our skinny GEMM kernel: after compilation, the kernel was never invoked at all.
Several attempts were made to fix this, such as lifting the conditional logic in rocm_unquantized_gemm into its caller, converting wvSplitK itself into a custom op, and adjusting the logic inside rocm_unquantized_gemm.
What fixes the problem is converting rocm_unquantized_gemm into a custom op via direct_register_custom_op with register_fake. The fix was confirmed with profiler runs.
Note: this affects small batch sizes only (batch sizes 1-4).
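For readers unfamiliar with the helper, the registration follows roughly this shape. This is a hedged sketch: the argument names of vLLM's direct_register_custom_op and the op name here are assumptions, not copied from the diff.

from vllm.utils import direct_register_custom_op

# Sketch: register the eager ROCm dispatch logic together with a fake
# (shape-only) implementation, so torch.compile traces a single opaque
# op instead of baking one branch into the captured graph.
direct_register_custom_op(
    op_name="rocm_unquantized_gemm_impl",       # assumed op name
    op_func=rocm_unquantized_gemm_impl,         # eager impl with the branching
    mutates_args=[],
    fake_impl=rocm_unquantized_gemm_impl_fake,  # shape inference for tracing
)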
Pertinent output from the profiler runs is below; the call count is the last field on each line.

*** Without direct_register_custom_op ***
void wvSplitK_hf_sml_<__hip_bfloat16, 64, 2, 16, 8, ...  0.00%  0.000us  0.00%  0.000us  0.000us  57.828ms   4.10%  57.828ms  225.892us  256

*** With direct_register_custom_op ***
void wvSplitK_hf_sml_<__hip_bfloat16, 64, 2, 16, 8, ...  0.00%  0.000us  0.00%  0.000us  0.000us  948.566ms  70.23%  948.566ms  28.835us   32896