vllm/model_executor/layers/utils.py: 32 changes (28 additions, 4 deletions)
@@ -8,6 +8,7 @@
from vllm import _custom_ops as ops
from vllm import envs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def get_token_bin_counts_and_mask(
@@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
    return torch.nn.functional.linear(x, weight, bias)


def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None):
def rocm_unquantized_gemm_impl(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    from vllm.platforms.rocm import on_gfx9
    k = weight.shape[1]
    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
    return torch.nn.functional.linear(x, weight, bias)


def rocm_unquantized_gemm_impl_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


zou3519 (Collaborator) commented on Jul 24, 2025:

I mentioned this offline, but it is generally better to put the minimal amount of things needed into the custom op. In this situation I think that's the following (+ some dependencies):

    if m > 8 and 0 < n <= 4:
        out = ops.wvSplitK(weight, x_view, cu_count)
        return out.view(*x.shape[:-1], weight.shape[0])
    elif m % 4 == 0 and n == 1 and k <= 8192:
        out = ops.LLMM1(weight, x_view, 4)
        return out.view(*x.shape[:-1], weight.shape[0])
    else:
        return torch.nn.functional.linear(x, weight, bias)

The reason is that torch.compile may be able to optimize the nn.Linear calls: for example, it can select different matmul kernels, or fuse nearby operations into them when fusable operations exist.

That being said, it's not clear to me how much torch.compile is able to do for matmuls on ROCm, so feel free to ship this as-is.
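For illustration only, not part of the diff: a minimal sketch of the split zou3519 suggests, reusing this module's imports. The names rocm_skinny_gemm_impl and rocm_unquantized_gemm_minimal are hypothetical, the use_skinny gating is abbreviated (the full check is folded away in the diff above), and current_platform.get_cu_count() is assumed to be the CU-count lookup used by the folded code. The point is structural: only the ROCm kernel calls live inside the registered op, while the nn.functional.linear fallback stays in the Python wrapper where torch.compile can still see it.

def rocm_skinny_gemm_impl(x: torch.Tensor,
                          weight: torch.Tensor,
                          cu_count: int) -> torch.Tensor:
    # Only the ROCm-specific kernels live inside the custom op.
    x_view = x.view(-1, x.shape[-1])
    n = x_view.shape[0]
    m = weight.shape[0]
    if m > 8 and 0 < n <= 4:
        out = ops.wvSplitK(weight, x_view, cu_count)
    else:  # wrapper only routes here when m % 4 == 0, n == 1 and k <= 8192
        out = ops.LLMM1(weight, x_view, 4)
    return out.view(*x.shape[:-1], weight.shape[0])


def rocm_skinny_gemm_impl_fake(x: torch.Tensor,
                               weight: torch.Tensor,
                               cu_count: int) -> torch.Tensor:
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


direct_register_custom_op(
    op_name="rocm_skinny_gemm_impl",
    op_func=rocm_skinny_gemm_impl,
    mutates_args=[],
    fake_impl=rocm_skinny_gemm_impl_fake,
    dispatch_key=current_platform.dispatch_key,
)


def rocm_unquantized_gemm_minimal(
        layer: torch.nn.Module,
        x: torch.Tensor,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    m, k = weight.shape
    n = x.view(-1, x.shape[-1]).shape[0]
    # Abbreviated gating; the real condition also checks on_gfx9(), k % 8
    # alignment, etc. (folded away in the diff above).
    use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and bias is None
                  and x.dtype in (torch.float16, torch.bfloat16))
    if use_skinny and ((m > 8 and 0 < n <= 4)
                       or (m % 4 == 0 and n == 1 and k <= 8192)):
        return torch.ops.vllm.rocm_skinny_gemm_impl(
            x, weight, current_platform.get_cu_count())
    # The fallback stays outside the op, so torch.compile can pick matmul
    # kernels or fuse neighbouring ops into this linear.
    return torch.nn.functional.linear(x, weight, bias)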


def rocm_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)


Contributor commented:

Nit: Let's add the return type here.

Contributor (Author) replied:

Done


direct_register_custom_op(
    op_name="rocm_unquantized_gemm_impl",
    op_func=rocm_unquantized_gemm_impl,
    mutates_args=[],
    fake_impl=rocm_unquantized_gemm_impl_fake,
    dispatch_key=current_platform.dispatch_key,
)
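A minimal sketch, not part of the diff, of what the fake impl enables: with the registration above in place, the op can be shape-propagated under FakeTensorMode (which torch.compile relies on during tracing) without running any ROCm kernel. The shapes are made up, and the snippet assumes vllm has already executed the registration on a ROCm build.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(4, 16, 4096, dtype=torch.float16)
    w = torch.empty(11008, 4096, dtype=torch.float16)
    out = torch.ops.vllm.rocm_unquantized_gemm_impl(x, w, None)
    # The shape comes from rocm_unquantized_gemm_impl_fake; no GEMM runs.
    assert out.shape == (4, 16, 11008)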


def cpu_unquantized_gemm(layer: torch.nn.Module,
                         x: torch.Tensor,
                         weight: torch.Tensor,