@@ -16,6 +16,7 @@
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
@@ -226,7 +227,7 @@ def apply(self,
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
 
 
-class LinearBase(torch.nn.Module):
+class LinearBase(CustomOp):
     """Base linear layer.
 
     Args:
@@ -269,12 +270,8 @@ def __init__(
                          prefix=prefix)
         self.return_bias = return_bias
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
-        raise NotImplementedError
-
 
+@CustomOp.register("replicated_linear")
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
 
@@ -443,6 +440,7 @@ def weight_loader(self,
             param[shard_offset:shard_offset + shard_size] = loaded_weight
 
 
+@CustomOp.register("column_parallel_linear")
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
 
@@ -1229,6 +1227,7 @@ def weight_loader(self,
         param_data.copy_(loaded_weight)
 
 
+@CustomOp.register("row_parallel_linear")
 class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
 
@@ -1405,6 +1404,7 @@ def extra_repr(self) -> str:
         return s
 
 
+@CustomOp.register("qkv_cross_parallel_linear")
 class QKVCrossParallelLinear(LinearBase):
     """Linear layers for efficient cross-attention's QKV transformation.
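Taken together, these hunks move the linear layers onto vLLM's CustomOp base class: LinearBase no longer needs its own abstract forward(), because CustomOp already provides a forward() that dispatches to a per-platform implementation, and each concrete layer is registered in CustomOp's registry under a string name. The following is a minimal, self-contained sketch of that pattern, not vLLM's actual custom_op.py; the registry layout, the CUDA-availability dispatch rule, and the MyLinear op are simplified assumptions for illustration.

# Sketch of the CustomOp pattern assumed by this diff (simplified; not
# vLLM's real implementation).
from typing import Callable

import torch
from torch import nn


class CustomOp(nn.Module):
    """A registrable op whose forward() dispatches to a per-platform impl."""

    # Global name -> class registry populated by @CustomOp.register(...).
    op_registry: dict[str, type] = {}

    def __init__(self) -> None:
        super().__init__()
        # Choose the implementation once, at construction time.
        # (Hypothetical rule: prefer the CUDA path when CUDA is available.)
        self._forward_method: Callable = (self.forward_cuda
                                          if torch.cuda.is_available() else
                                          self.forward_native)

    def forward(self, *args, **kwargs):
        # Subclasses no longer override forward(); they supply
        # forward_native / forward_cuda instead.
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Default: fall back to the portable implementation.
        return self.forward_native(*args, **kwargs)

    @classmethod
    def register(cls, name: str):
        def decorator(op_cls: type) -> type:
            cls.op_registry[name] = op_cls
            return op_cls
        return decorator


@CustomOp.register("my_linear")  # hypothetical op name
class MyLinear(CustomOp):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.linear(x, self.weight)


layer = MyLinear(8, 16)
y = layer(torch.randn(2, 8))  # routes through CustomOp.forward()
assert CustomOp.op_registry["my_linear"] is MyLinear
assert y.shape == (2, 16)

Registering layers by name like this lets the rest of the stack refer to ops as strings (for example, to enable or disable a specific op via configuration) without importing the classes at the call site, and it is why the abstract forward() on LinearBase became dead code in this diff.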