
Conversation

rasmith
Contributor

@rasmith rasmith commented Jul 22, 2025

It turns out that torch.compile was omitting calls to wvSplitK, our skinny gemm: after compilation with torch.compile, the kernel was never being called.

Several attempts were made to fix this, such as lifting the conditional logic in rocm_unquantized_gemm into its caller, converting wvSplitK itself into a custom op, and adjusting the logic in rocm_unquantized_gemm.

Converting rocm_unquantized_gemm into a custom op via direct_register_custom_op with register_fake fixes the problem. The fix was confirmed with profiler runs.

Note: this applies to small batch sizes (1-4).
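
For readers unfamiliar with the pattern, a minimal sketch of a custom op with a registered fake (meta) implementation is below. It uses the plain torch.library API rather than vLLM's direct_register_custom_op helper, and the op name, signature, and body are illustrative placeholders rather than the code in this PR.

    import torch
    from typing import Optional

    # Illustrative only: plain torch.library registration, not vLLM's
    # direct_register_custom_op helper, and a stand-in op body.
    @torch.library.custom_op("myops::rocm_unquantized_gemm", mutates_args=())
    def rocm_unquantized_gemm(x: torch.Tensor, weight: torch.Tensor,
                              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # The real op dispatches to wvSplitK / LLMM1 / F.linear based on shapes;
        # a plain linear stands in for that here.
        return torch.nn.functional.linear(x, weight, bias)

    @rocm_unquantized_gemm.register_fake
    def _(x: torch.Tensor, weight: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Shape-only implementation used during torch.compile tracing: keep the
        # leading dims of x and replace the last dim with weight's output dim.
        return x.new_empty((*x.shape[:-1], weight.shape[0]))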

Pertinent information from the profiler runs is shown below; the call count is the final number on each line.

**** Without direct_register_custom_op ****

void wvSplitK_hf_sml_<__hip_bfloat16, 64, 2, 16, 8, ... 0.00% 0.000us 0.00% 0.000us 0.000us 57.828ms 4.10% 57.828ms 225.892us 256

*** With direct_register_custom_op ***

void wvSplitK_hf_sml_<__hip_bfloat16, 64, 2, 16, 8, ... 0.00% 0.000us 0.00% 0.000us 0.000us 948.566ms 70.23% 948.566ms 28.835us 32896

rasmith added 4 commits July 18, 2025 16:52

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the rocm Related to AMD ROCm label Jul 22, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses an issue with torch.compile on ROCm platforms where a specific kernel call (wvSplitK) was being omitted. The fix involves refactoring rocm_unquantized_gemm into a custom PyTorch operator. This is a sound approach to work around compiler issues. My review identified a critical issue in the 'fake' implementation of this new custom operator, which could lead to incorrect shape inference for input tensors with more than two dimensions. I've provided a suggestion to fix this.
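
For context, the shape rule a fake implementation has to preserve is roughly this: for an input of shape (..., K) and a weight of shape (N, K), the output must be (..., N). The sketch below illustrates that rule and is not the PR's exact code.

    from typing import Optional
    import torch

    def rocm_unquantized_gemm_impl_fake(
            x: torch.Tensor,
            weight: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Sketch of a fake (meta) implementation: preserve all leading dims of x
        # and replace the last one with weight's output-feature count, so 3D+
        # inputs (batch, seq, hidden) infer the correct output shape.
        return x.new_empty((*x.shape[:-1], weight.shape[0]))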

Signed-off-by: Randall Smith <[email protected]>
Contributor

@SageMoore SageMoore left a comment


How does this bug manifest on main? It sounds like we just get a wrong answer when the dispatching logic decides to run skinny_gemm in V1? Do you have a minimal reproducible example?

def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None):
Contributor


Nit: Let's add the return type here.

Contributor Author


Done

Signed-off-by: Randall Smith <[email protected]>
@rasmith rasmith changed the title [AMD][BugFix] Fix omission of wvSplitK kernel due to torch.compile [AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile Jul 22, 2025
@rasmith
Contributor Author

rasmith commented Jul 22, 2025

How does this bug manifest on main? It sounds like we just get a wrong answer when the dispatching logic decides to run skinny_gemm in V1? Do you have a minimal reproducible example?

I'm not sure of the root cause, whether it lies in how torch.compile or vLLM's internal compilation logic is implemented, but you can see the bug by running almost any model; I used Llama-3.1-8B-Instruct for testing. Just run the profiler and check how many times the function is called.
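
For reference, a rough sketch of that kind of check (the model name and prompt are placeholders): wrap a short generation in the PyTorch profiler and look for wvSplitK entries and their call counts in the kernel table.

    import torch
    from vllm import LLM, SamplingParams

    # Placeholder model and prompt; the point is just to count wvSplitK calls.
    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
    params = SamplingParams(max_tokens=64)

    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA]) as prof:
        llm.generate(["Hello, my name is"], params)

    # Search this table for wvSplitK rows; the "# of Calls" column is the count.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))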

@SageMoore
Contributor

OK, so if I'm understanding correctly, the problem is that even if all of the conditions for skinny gemm are true, rocm_unquantized_gemm will still use torch.nn.functional.linear when run through torch.compile? I ran a quick ShareGPT serving benchmark and didn't see any slowdowns from the custom op registration, so I think this is fine.

Before I accept, can you quickly verify that this issue persists even after you clear out both torch.compile caches, ~/.cache/vllm/torch_compile_cache/ and /tmp/torchinductor_$USER/?
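
If it helps, a small sketch of clearing both caches from Python, using the paths above:

    import getpass
    import os
    import shutil

    # Remove both torch.compile caches mentioned above before re-profiling.
    for path in (os.path.expanduser("~/.cache/vllm/torch_compile_cache"),
                 f"/tmp/torchinductor_{getpass.getuser()}"):
        shutil.rmtree(path, ignore_errors=True)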

@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 24, 2025
Contributor

@SageMoore SageMoore left a comment


I think this is a reasonable fix for now, but it would be good to understand what's going on here.

@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
return torch.nn.functional.linear(x, weight, bias)


def rocm_unquantized_gemm_impl_fake(
Collaborator

@zou3519 zou3519 Jul 24, 2025


I mentioned this offline, but it is generally better to put the minimal amount of things needed into the custom op. In this situation I think that's the following (+ some dependencies)

    if m > 8 and 0 < n <= 4:
        out = ops.wvSplitK(weight, x_view, cu_count)
        return out.view(*x.shape[:-1], weight.shape[0])
    elif m % 4 == 0 and n == 1 and k <= 8192:
        out = ops.LLMM1(weight, x_view, 4)
        return out.view(*x.shape[:-1], weight.shape[0])
    else:
        return torch.nn.functional.linear(x, weight, bias)

The reason is that torch.compile may be able to optimize the nn.Linear calls. For example, it can select different matmul kernels, or fuse operations into them when there are fusable operations nearby.

That being said, it's not clear to me how much torch.compile is able to do for matmuls on ROCm, so feel free to ship this as-is.

@zou3519
Collaborator

zou3519 commented Jul 24, 2025

To answer the overall question, @SageMoore:

The way vLLM uses torch.compile is to capture one single graph that works for all batch sizes. If the graph contains branches on shapes, graph capture will pick one of the branches (the branch taken for the batch size vLLM uses to perform the initial graph capture), and the resulting graph will only be correct under the conditions of that branch.

In this case here, the condition is batch_size > 4.

The workarounds are generally:

  • don't do the branching
  • hide the logic inside a custom operator

In this case here the branching is for kernel selection, so hiding the logic in a custom operator seems reasonable.
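
A toy sketch of the behavior described above (not vLLM code): a Python-level branch on the batch size gets baked in at capture time, so a single captured graph only ever runs one side of the branch.

    import torch

    def gemm_dispatch(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        # Shape-dependent Python branch, analogous to the batch_size <= 4 check.
        if x.shape[0] <= 4:
            return x @ w.t()                      # stand-in for the skinny-gemm path
        return torch.nn.functional.linear(x, w)   # fallback path

    compiled = torch.compile(gemm_dispatch)
    w = torch.randn(16, 8)

    # Capturing with batch size 32 records only the fallback branch. Plain
    # torch.compile installs a shape guard and recompiles when batch size 2
    # arrives; a single captured graph reused for all batch sizes (as in vLLM)
    # would keep taking the fallback branch instead.
    compiled(torch.randn(32, 8))
    compiled(torch.randn(2, 8))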

@gshtras gshtras merged commit b361f14 into vllm-project:main Jul 28, 2025
64 checks passed
liuyumoye pushed a commit to liuyumoye/vllm that referenced this pull request Jul 31, 2025
wenscarl pushed a commit to wenscarl/vllm that referenced this pull request Aug 4, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
taneem-ibrahim pushed a commit to taneem-ibrahim/vllm that referenced this pull request Aug 14, 2025
BoyuanFeng pushed a commit to BoyuanFeng/vllm that referenced this pull request Aug 14, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
googlercolin pushed a commit to googlercolin/vllm that referenced this pull request Aug 29, 2025