@@ -8,6 +8,7 @@
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op


 def get_token_bin_counts_and_mask(
@@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)


-def rocm_unquantized_gemm(layer: torch.nn.Module,
-                          x: torch.Tensor,
-                          weight: torch.Tensor,
-                          bias: Optional[torch.Tensor] = None):
+def rocm_unquantized_gemm_impl(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)


+def rocm_unquantized_gemm_impl_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return x.new_empty((*x.shape[:-1], weight.shape[0]))
+
+
+def rocm_unquantized_gemm(layer: torch.nn.Module,
+                          x: torch.Tensor,
+                          weight: torch.Tensor,
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+
+
+direct_register_custom_op(
+    op_name="rocm_unquantized_gemm_impl",
+    op_func=rocm_unquantized_gemm_impl,
+    mutates_args=[],
+    fake_impl=rocm_unquantized_gemm_impl_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
+
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
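The change above wraps the ROCm GEMM path in a registered custom op: torch.compile then treats torch.ops.vllm.rocm_unquantized_gemm_impl as an opaque call, with rocm_unquantized_gemm_impl_fake supplying output shapes and dtypes during tracing. As a rough sketch of the same register-real-plus-fake pattern, here is the equivalent using stock torch.library APIs (PyTorch >= 2.4) rather than vLLM's direct_register_custom_op helper; the demo::my_gemm op name and the plain linear body are illustrative stand-ins, not the PR's kernel:

import torch
from typing import Optional


@torch.library.custom_op("demo::my_gemm", mutates_args=())
def my_gemm(x: torch.Tensor, weight: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Real implementation (illustrative): an opaque call the compiler
    # does not trace into.
    return torch.nn.functional.linear(x, weight, bias)


@my_gemm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor,
      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Fake (meta) implementation: shape/dtype inference only, mirroring
    # rocm_unquantized_gemm_impl_fake above.
    return x.new_empty((*x.shape[:-1], weight.shape[0]))


x = torch.randn(4, 16)
w = torch.randn(8, 16)
assert torch.ops.demo.my_gemm(x, w).shape == (4, 8)  # eager dispatch
compiled = torch.compile(lambda a, b: torch.ops.demo.my_gemm(a, b))
assert compiled(x, w).shape == (4, 8)  # fake impl supplies shapes at trace time

Splitting rocm_unquantized_gemm into a thin wrapper plus a registered _impl also keeps the layer argument out of the op signature, which matters because custom-op schemas accept tensors and primitives, not an nn.Module; this is a plausible reading of why the PR drops layer from the impl, not something the diff states.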