@@ -8,6 +8,7 @@
 from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 
 def get_token_bin_counts_and_mask(
@@ -70,10 +71,10 @@ def default_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
-def rocm_unquantized_gemm(layer: torch.nn.Module,
-                          x: torch.Tensor,
-                          weight: torch.Tensor,
-                          bias: Optional[torch.Tensor] = None):
+def rocm_unquantized_gemm_impl(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     from vllm.platforms.rocm import on_gfx9
     k = weight.shape[1]
     use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \
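
The renamed `rocm_unquantized_gemm_impl` drops the `layer` argument (a custom-op schema cannot express an `nn.Module`, and the impl never used it) and gains an explicit return annotation. Its output shape contract, which the fake implementation registered in the next hunk must reproduce, is the usual `F.linear` one. A minimal standalone sanity sketch in plain PyTorch, not part of the patch:

    import torch

    x = torch.randn(2, 5, 16)   # (*leading_dims, in_features)
    w = torch.randn(32, 16)     # (out_features, in_features)
    y = torch.nn.functional.linear(x, w)
    # The output keeps x's leading dims and ends with weight's
    # out_features dim, i.e. exactly (*x.shape[:-1], w.shape[0]).
    assert y.shape == (2, 5, 32)
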
@@ -97,6 +98,29 @@ def rocm_unquantized_gemm(layer: torch.nn.Module,
     return torch.nn.functional.linear(x, weight, bias)
 
 
+def rocm_unquantized_gemm_impl_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return x.new_empty((*x.shape[:-1], weight.shape[0]))
+
+
+def rocm_unquantized_gemm(layer: torch.nn.Module,
+                          x: torch.Tensor,
+                          weight: torch.Tensor,
+                          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias)
+
+
+direct_register_custom_op(
+    op_name="rocm_unquantized_gemm_impl",
+    op_func=rocm_unquantized_gemm_impl,
+    mutates_args=[],
+    fake_impl=rocm_unquantized_gemm_impl_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
+
 def cpu_unquantized_gemm(layer: torch.nn.Module,
                          x: torch.Tensor,
                          weight: torch.Tensor,
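
For context, `direct_register_custom_op` is vLLM's helper around `torch.library` registration: it exposes `rocm_unquantized_gemm_impl` under `torch.ops.vllm.*` and attaches the fake implementation so fake-tensor tracing and `torch.compile` can infer output shapes without executing the ROCm kernels. A minimal self-contained sketch of the same pattern in plain PyTorch (>= 2.4); the `demo` namespace and names below are illustrative only, not part of this patch:

    from typing import Optional

    import torch

    # Real implementation: what actually runs when the op is called.
    @torch.library.custom_op("demo::unquantized_gemm", mutates_args=())
    def demo_unquantized_gemm(
            x: torch.Tensor,
            weight: torch.Tensor,
            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return torch.nn.functional.linear(x, weight, bias)

    # Fake (meta) implementation: shape/dtype propagation only, mirroring
    # rocm_unquantized_gemm_impl_fake above.
    @demo_unquantized_gemm.register_fake
    def _(x, weight, bias=None):
        return x.new_empty((*x.shape[:-1], weight.shape[0]))

    # The op is now reachable through the dispatcher, as in the new wrapper:
    out = torch.ops.demo.unquantized_gemm(torch.randn(4, 16),
                                          torch.randn(32, 16))
    assert out.shape == (4, 32)

Routing the public `rocm_unquantized_gemm` through `torch.ops.vllm` keeps its Python signature intact (including `layer`, which the op schema could not carry) while making the data-dependent skinny-GEMM branching inside the impl opaque to the compiler, so it is treated as a single black-box op rather than traced through.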