
Commit b361f14

[AMD][BugFix] Fix omission of wvSplitK kernel for small batch sizes (1-4) due to torch.compile (#21350)
Signed-off-by: Randall Smith <[email protected]>
Parent: 01c753e

1 file changed: +28 -4 lines changed

vllm/model_executor/layers/utils.py

Lines changed: 28 additions & 4 deletions
@@ -8,6 +8,7 @@
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 
 def get_token_bin_counts_and_mask(
@@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm(layer: torch.nn.Module,
-                          x: torch.Tensor,
-                          weight: torch.Tensor,
-                          bias: Optional[torch.Tensor] = None):
+def rocm_unquantized_gemm_impl(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
+def rocm_unquantized_gemm_impl_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return x.new_empty((*x.shape[:-1], weight.shape[0]))
+
+
+def rocm_unquantized_gemm(layer: torch.nn.Module,
+                          x: torch.Tensor,
+                          weight: torch.Tensor,
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+
+
+direct_register_custom_op(
+    op_name="rocm_unquantized_gemm_impl",
+    op_func=rocm_unquantized_gemm_impl,
+    mutates_args=[],
+    fake_impl=rocm_unquantized_gemm_impl_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
+
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
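
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of what the registration above accomplishes. It is not vLLM code: the op name demo::rocm_gemm, the stand-in linear calls, and the batch-size check are illustrative assumptions, and it uses PyTorch's public torch.library.custom_op decorator rather than vLLM's direct_register_custom_op helper. Because torch.compile treats a registered custom op as opaque and traces only its fake (shape-only) implementation, the runtime branch that selects the small-batch kernel stays live instead of being compiled away.

    from typing import Optional

    import torch


    # Register the eager implementation as an opaque custom op. torch.compile
    # does not inline the body, so the batch-size branch below runs on every
    # call instead of being specialized away during graph capture.
    @torch.library.custom_op("demo::rocm_gemm", mutates_args=())
    def rocm_gemm(x: torch.Tensor, weight: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        m = x.view(-1, x.shape[-1]).shape[0]
        if 0 < m <= 4:
            # vLLM would dispatch to the skinny-GEMM kernel (ops.wvSplitK)
            # here; a plain linear stands in for it in this sketch.
            return torch.nn.functional.linear(x, weight, bias)
        return torch.nn.functional.linear(x, weight, bias)


    # Shape-only "fake" implementation used while tracing/compiling,
    # mirroring rocm_unquantized_gemm_impl_fake in the diff.
    @rocm_gemm.register_fake
    def _(x: torch.Tensor, weight: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return x.new_empty((*x.shape[:-1], weight.shape[0]))


    # Call through torch.ops, as the new rocm_unquantized_gemm wrapper does.
    compiled = torch.compile(lambda x, w: torch.ops.demo.rocm_gemm(x, w))
    out = compiled(torch.randn(2, 16), torch.randn(8, 16))
    print(out.shape)  # torch.Size([2, 8])

Under the hood, vLLM's direct_register_custom_op helper performs an equivalent registration through torch.library, which is why the new rocm_unquantized_gemm wrapper can simply call torch.ops.vllm.rocm_unquantized_gemm_impl.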
