-
-
Notifications
You must be signed in to change notification settings - Fork 10.8k
[AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile #21350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
gshtras
merged 9 commits into
vllm-project:main
from
rasmith:ransmith_fix_rocm_unquantized_gemm
Jul 28, 2025
Merged
[AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile #21350
Changes from 5 commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
25cefed
working register_fake, but need direct_register_custom_op
rasmith a35bcac
remove debug code
rasmith 0044390
formatting
rasmith d22d7c7
revert __init__
rasmith e355d2c
fake tensor
rasmith 1537e59
add return type
rasmith 5c7668b
Merge branch 'vllm-project:main' into ransmith_fix_rocm_unquantized_gemm
rasmith 504720c
Merge branch 'vllm-project:main' into ransmith_fix_rocm_unquantized_gemm
rasmith a860838
Merge branch 'vllm-project:main' into ransmith_fix_rocm_unquantized_gemm
rasmith File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| from vllm import _custom_ops as ops | ||
| from vllm import envs | ||
| from vllm.platforms import current_platform | ||
| from vllm.utils import direct_register_custom_op | ||
|
|
||
|
|
||
| def get_token_bin_counts_and_mask( | ||
|
|
@@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module, | |
| return torch.nn.functional.linear(x, weight, bias) | ||
|
|
||
|
|
||
| def rocm_unquantized_gemm(layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| bias: Optional[torch.Tensor] = None): | ||
| def rocm_unquantized_gemm_impl( | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| bias: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
| from vllm.platforms.rocm import on_gfx9 | ||
| k = weight.shape[1] | ||
| use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ | ||
|
|
@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module, | |
| return torch.nn.functional.linear(x, weight, bias) | ||
|
|
||
|
|
||
| def rocm_unquantized_gemm_impl_fake( | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| bias: Optional[torch.Tensor] = None) -> torch.Tensor: | ||
| return x.new_empty((*x.shape[:-1], weight.shape[0])) | ||
|
|
||
|
|
||
| def rocm_unquantized_gemm(layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| bias: Optional[torch.Tensor] = None): | ||
|
||
| return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias) | ||
|
|
||
|
|
||
| direct_register_custom_op( | ||
| op_name="rocm_unquantized_gemm_impl", | ||
| op_func=rocm_unquantized_gemm_impl, | ||
| mutates_args=[], | ||
| fake_impl=rocm_unquantized_gemm_impl_fake, | ||
| dispatch_key=current_platform.dispatch_key, | ||
| ) | ||
|
|
||
|
|
||
| def cpu_unquantized_gemm(layer: torch.nn.Module, | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mentioned this offline, but it is generally better to put the minimal amount of things needed into the custom op. In this situation I think that's the following (+ some dependencies)
the reason being is that torch.compile may be able to optimize the nn.Linears. For example, it is able to select different matmul kernels, or fuse operations into it (if there are fusable operations nearby).
That being said, it's not clear to me how much torch.compile is able to do for matmuls on ROCM, so, feel free to ship this as-is