From 7c05081f0a65175b5bc30371b5354560375adbe0 Mon Sep 17 00:00:00 2001 From: Alex Kogan Date: Wed, 1 Oct 2025 17:17:39 -0400 Subject: [PATCH 1/5] utilize Marlin GEMM kernels for RTN Signed-off-by: Alex Kogan --- .../model_executor/layers/quantization/rtn.py | 240 ++++++++++++++---- .../layers/quantization/utils/marlin_utils.py | 40 +++ 2 files changed, 226 insertions(+), 54 deletions(-) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 015dc136bb82..b146ebeb947b 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -5,21 +5,24 @@ import os from typing import Any, Callable, Optional, Union +import numpy as np import torch -import torch.nn.functional as F from torch.nn.parameter import Parameter from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, - FusedMoEMethodBase) -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEQuantConfig, int4_w4a16_moe_quant_config, - int8_w8a16_moe_quant_config) +from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, + FusedMoEQuantConfig) +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_rtn_marlin_linear, marlin_make_workspace_new) +from vllm.scalar_type import scalar_types logger = init_logger(__name__) """By default, use 8 bit as target precision, but it can be @@ -30,6 +33,9 @@ overridden by setting the RTN_GROUP_SIZE envvar """ GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128") +"""Global Marlin workspace shared by all modules +""" +workspace = None class RTNConfig(QuantizationConfig): @@ -49,6 +55,9 @@ def __init__( "Currently, only 4-bit or 8-bit weight quantization is " f"supported for RTN, but got {self.weight_bits} bits.") + self.quant_type = (scalar_types.uint8b128 + if self.weight_bits == 8 else scalar_types.uint4b8) + def __repr__(self) -> str: return (f"RTNConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size})") @@ -194,22 +203,30 @@ def create_weights( layer.output_size_per_partition = output_size_per_partition def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - fix_weights(layer, "weight") + """ Repack weights and scales for Marlin kernels. 
+ """ + weight_bits = self.quant_config.weight_bits + + weight, scale = repack_weights(layer.weight, layer.scale, weight_bits) + + replace_parameter(layer, "weight", weight) + replace_parameter(layer, "scale", scale) + + init_workspace(layer.weight.device) def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = layer.weight - scale = layer.scale - - weight = rtn_dequantize(qweight, scale) - out = F.linear(x, weight) - del weight - if bias is not None: - out.add_(bias) - - return out + return apply_rtn_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.scale, + workspace=workspace, + quant_type=self.quant_config.quant_type, + output_size_per_partition=layer.output_size_per_partition, + input_size_per_partition=layer.input_size_per_partition, + bias=bias) class RTNMoEMethod(FusedMoEMethodBase): @@ -268,24 +285,25 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, set_weight_attrs(w2_weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """ Repack weights and scales for Marlin kernels. + """ weight_bits = self.quant_config.weight_bits - fix_weights(layer, "w13_weight", weight_bits == 4) - fix_weights(layer, "w2_weight", weight_bits == 4) + + w13_weight, w13_scale = repack_weights(layer.w13_weight, + layer.w13_scale, weight_bits) + replace_parameter(layer, "w13_weight", w13_weight) + replace_parameter(layer, "w13_scale", w13_scale) + + w2_weight, w2_scale = repack_weights(layer.w2_weight, layer.w2_scale, + weight_bits) + replace_parameter(layer, "w2_weight", w2_weight) + replace_parameter(layer, "w2_scale", w2_scale) + + init_workspace(layer.w13_weight.device) def get_fused_moe_quant_config( self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: - weight_bits = self.quant_config.weight_bits - group_size = self.quant_config.group_size - assert weight_bits == 4 or weight_bits == 8 - config_builder = (int4_w4a16_moe_quant_config - if weight_bits == 4 else int8_w8a16_moe_quant_config) - return config_builder( - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - w1_zp=None, - w2_zp=None, - block_shape=[0, group_size], - ) + return None def apply( self, @@ -316,8 +334,6 @@ def apply( raise NotImplementedError( "EPLB not supported for `RTNMoEMethod` yet.") - from vllm.model_executor.layers.fused_moe import fused_experts - topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -332,19 +348,22 @@ def apply( e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) - return fused_experts( + return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, + getattr(layer, "w13_bias", None), + getattr(layer, "w2_bias", None), + layer.w13_scale, + layer.w2_scale, + router_logits, + topk_weights, + topk_ids, + quant_type_id=self.quant_config.quant_type.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, - quant_config=self.moe_quant_config, - ) + workspace=workspace) def rtn_quantize(tensor: torch.Tensor, num_bits: int, @@ -447,20 +466,133 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return input_deq -def fix_weights(layer: torch.nn.Module, - param_name: str, - reshape: bool = False): - """torch.compile does not know how to deal with a Parameter subclass - 
(aka RTNParameter). As we don't really need RTNParameters for the - forward pass, we replace them with equivalent instances of Parameters. +def _get_perms(): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1 + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm.extend([p + 256 * j for p in perm1]) + + perm_arr = np.array(perm) + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + perm_arr = perm_arr.reshape((-1, 8))[:, interleave].ravel() + perm_tensor = torch.from_numpy(perm_arr) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm_tensor, scale_perm, scale_perm_single + + +_perm, _scale_perm, _scale_perm_single = _get_perms() + + +def pack_for_marlin(weight, scale, qbits): + batch = weight.shape[0] + + n = weight.size(1) + k = weight.size(2) + groupsize = k // scale.size(2) + + tile = 16 + s = scale.permute(0, 2, 1) # transpose + w = weight.permute(0, 2, 1) # transpose + if groupsize != k: + w = w.reshape((batch, -1, groupsize, n)) + w = w.permute(0, 2, 1, 3) + w = w.reshape((batch, groupsize, -1)) + s = s.reshape((batch, 1, -1)) + + if groupsize != k: + w = w.reshape((batch, groupsize, -1, n)) + w = w.permute(0, 2, 1, 3) + w = w.reshape((batch, k, n)).contiguous() + s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm] + else: + s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, + _scale_perm_single] + s = s.reshape((batch, -1, n)).contiguous() + w = w.reshape((batch, k // tile, tile, n // tile, tile)) + w = w.permute((0, 1, 3, 2, 4)) + w = w.reshape((batch, k // tile, n * tile)) + res = w + res = res.reshape((batch, -1, _perm.numel()))[:, :, + _perm].reshape(res.shape) + if qbits == 4: + q = torch.zeros((batch, res.shape[1], res.shape[2] // 2), + dtype=torch.int8, + device=w.device) + for i in range(2): + q |= res[:, :, i::2] << 4 * i + q = q.reshape(batch, -1, n).contiguous() + else: + q = res.clone() + q[:, :, 2::8] = res[:, :, 4::8] + q[:, :, 3::8] = res[:, :, 5::8] + q[:, :, 4::8] = res[:, :, 2::8] + q[:, :, 5::8] = res[:, :, 3::8] + q = q.reshape(batch, -1, n).to(torch.int8).contiguous() + + return q, s + + +def repack_8bit_into_32bit(input): + output = torch.zeros((input.shape[0], input.shape[1], input.shape[2] // 4), + dtype=torch.int32, + device=input.device) + for i in range(4): + output |= (input[:, :, i::4] & 0xff).to(torch.int32) << 8 * i + + return output + + +def repack_weights(qweight, scale, weight_bits): + batch_present = len(qweight.shape) == 3 + if not batch_present: + qweight = qweight.unsqueeze(0) + scale = scale.unsqueeze(0) + + if weight_bits == 4: + """Unpack two 4-bit values from each byte. 
+ """ + qweight_unpacked = torch.empty( + (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]), + dtype=torch.uint8, + device=qweight.device) + for i in range(2): + qweight_unpacked[:, :, i::2] = \ + ((qweight << 4 * (1 - i)) >> 4).reshape( + qweight.shape[0], + qweight.shape[1] * 2, + qweight.shape[2] // 2) + else: + qweight_unpacked = qweight + + qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, + weight_bits) + """Marlin kernels expect tensors in int32 format in a certain shape """ - old_weight = getattr(layer, param_name) - assert isinstance(old_weight, RTNParameter) - data = old_weight.data.data + qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8)) + qweight_reshaped = qweight_repacked.reshape(qweight.shape[0], + qweight.shape[2] // 16, -1) + if not batch_present: + qweight_reshaped = qweight_reshaped.squeeze(0) + scale_packed = scale_packed.squeeze(0) + + return qweight_reshaped, scale_packed - delattr(layer, param_name) - if reshape: - data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1) - new_weight = Parameter(data=data, requires_grad=False) - layer.register_parameter(param_name, new_weight) +def init_workspace(device): + global workspace + if workspace is None: + workspace = marlin_make_workspace_new(device, 4) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 317ad079b392..d1fbce73ac6b 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -477,3 +477,43 @@ def apply_awq_marlin_linear( is_zp_float=False) return output.reshape(out_shape) + + +def apply_rtn_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + bias, + weight_scale, + None, + None, + None, + None, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) + + return output.reshape(out_shape) From e7abb9a5a2ff8a0e8c1756ef99c2cb2403ca9106 Mon Sep 17 00:00:00 2001 From: Alex Kogan Date: Wed, 8 Oct 2025 09:43:37 -0400 Subject: [PATCH 2/5] fix style errors flagged by pre-commit Signed-off-by: Alex Kogan --- vllm/model_executor/layers/quantization/rtn.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index d897a9cbf1ea..ccc20f799cd2 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -12,26 +12,23 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, - FusedMoEQuantConfig -) -from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, - 
FusedMoEMethodBase + FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, - set_weight_attrs + set_weight_attrs, ) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase + QuantizationConfig, + QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_rtn_marlin_linear, - marlin_make_workspace_new + apply_rtn_marlin_linear, + marlin_make_workspace_new, ) from vllm.scalar_type import scalar_types From 5cda8490b3616f4a398f22f40bd390e3e925c681 Mon Sep 17 00:00:00 2001 From: Alex Kogan Date: Wed, 8 Oct 2025 10:05:38 -0400 Subject: [PATCH 3/5] fix style errors flagged by pre-commit Signed-off-by: Alex Kogan --- .../model_executor/layers/quantization/rtn.py | 92 ++++++++++--------- .../layers/quantization/utils/marlin_utils.py | 73 ++++++++------- 2 files changed, 88 insertions(+), 77 deletions(-) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index ccc20f799cd2..d7ce7de23b64 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -40,7 +40,7 @@ """By default, use group size of 128 parameters, but it can be overridden by setting the RTN_GROUP_SIZE envvar """ -GROUP_SIZE = os.getenv('RTN_GROUP_SIZE', "128") +GROUP_SIZE = os.getenv("RTN_GROUP_SIZE", "128") """Global Marlin workspace shared by all modules """ workspace = None @@ -63,8 +63,9 @@ def __init__( f"supported for RTN, but got {self.weight_bits} bits." ) - self.quant_type = (scalar_types.uint8b128 - if self.weight_bits == 8 else scalar_types.uint4b8) + self.quant_type = ( + scalar_types.uint8b128 if self.weight_bits == 8 else scalar_types.uint4b8 + ) def __repr__(self) -> str: return ( @@ -227,8 +228,7 @@ def create_weights( layer.output_size_per_partition = output_size_per_partition def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """ Repack weights and scales for Marlin kernels. - """ + """ Repack weights and scales for Marlin kernels.""" weight_bits = self.quant_config.weight_bits weight, scale = repack_weights(layer.weight, layer.scale, weight_bits) @@ -238,10 +238,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: init_workspace(layer.weight.device) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: return apply_rtn_marlin_linear( input=x, weight=layer.weight, @@ -250,7 +252,8 @@ def apply(self, quant_type=self.quant_config.quant_type, output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, - bias=bias) + bias=bias, + ) class RTNMoEMethod(FusedMoEMethodBase): @@ -327,24 +330,26 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """ Repack weights and scales for Marlin kernels. 
- """ + """ Repack weights and scales for Marlin kernels.""" weight_bits = self.quant_config.weight_bits - w13_weight, w13_scale = repack_weights(layer.w13_weight, - layer.w13_scale, weight_bits) + w13_weight, w13_scale = repack_weights( + layer.w13_weight, layer.w13_scale, weight_bits + ) replace_parameter(layer, "w13_weight", w13_weight) replace_parameter(layer, "w13_scale", w13_scale) - w2_weight, w2_scale = repack_weights(layer.w2_weight, layer.w2_scale, - weight_bits) + w2_weight, w2_scale = repack_weights( + layer.w2_weight, layer.w2_scale, weight_bits + ) replace_parameter(layer, "w2_weight", w2_weight) replace_parameter(layer, "w2_scale", w2_scale) init_workspace(layer.w13_weight.device) def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + self, layer: torch.nn.Module + ) -> Optional[FusedMoEQuantConfig]: return None def apply( @@ -405,7 +410,8 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, - workspace=workspace) + workspace=workspace, + ) def rtn_quantize( @@ -521,8 +527,10 @@ def _get_perms(): col = i // 4 for block in [0, 1]: for row in [ - 2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1 + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1 ]: perm1.append(16 * row + col + 8 * block) for j in range(4): @@ -537,8 +545,7 @@ def _get_perms(): scale_perm.extend([i + 8 * j for j in range(8)]) scale_perm_single = [] for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return perm_tensor, scale_perm, scale_perm_single @@ -567,19 +574,17 @@ def pack_for_marlin(weight, scale, qbits): w = w.reshape((batch, k, n)).contiguous() s = s.reshape((batch, -1, len(_scale_perm)))[:, :, _scale_perm] else: - s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, - _scale_perm_single] + s = s.reshape((batch, -1, len(_scale_perm_single)))[:, :, _scale_perm_single] s = s.reshape((batch, -1, n)).contiguous() w = w.reshape((batch, k // tile, tile, n // tile, tile)) w = w.permute((0, 1, 3, 2, 4)) w = w.reshape((batch, k // tile, n * tile)) res = w - res = res.reshape((batch, -1, _perm.numel()))[:, :, - _perm].reshape(res.shape) + res = res.reshape((batch, -1, _perm.numel()))[:, :, _perm].reshape(res.shape) if qbits == 4: - q = torch.zeros((batch, res.shape[1], res.shape[2] // 2), - dtype=torch.int8, - device=w.device) + q = torch.zeros( + (batch, res.shape[1], res.shape[2] // 2), dtype=torch.int8, device=w.device + ) for i in range(2): q |= res[:, :, i::2] << 4 * i q = q.reshape(batch, -1, n).contiguous() @@ -595,11 +600,13 @@ def pack_for_marlin(weight, scale, qbits): def repack_8bit_into_32bit(input): - output = torch.zeros((input.shape[0], input.shape[1], input.shape[2] // 4), - dtype=torch.int32, - device=input.device) + output = torch.zeros( + (input.shape[0], input.shape[1], input.shape[2] // 4), + dtype=torch.int32, + device=input.device, + ) for i in range(4): - output |= (input[:, :, i::4] & 0xff).to(torch.int32) << 8 * i + output |= (input[:, :, i::4] & 0xFF).to(torch.int32) << 8 * i return output @@ -616,23 +623,22 @@ def repack_weights(qweight, scale, weight_bits): qweight_unpacked = torch.empty( (qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2]), dtype=torch.uint8, - device=qweight.device) + device=qweight.device, + ) for i in range(2): - qweight_unpacked[:, :, i::2] = \ - ((qweight << 
4 * (1 - i)) >> 4).reshape( - qweight.shape[0], - qweight.shape[1] * 2, - qweight.shape[2] // 2) + qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape( + qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2 + ) else: qweight_unpacked = qweight - qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, - weight_bits) + qweight_packed, scale_packed = pack_for_marlin(qweight_unpacked, scale, weight_bits) """Marlin kernels expect tensors in int32 format in a certain shape """ qweight_repacked = repack_8bit_into_32bit(qweight_packed.to(torch.uint8)) - qweight_reshaped = qweight_repacked.reshape(qweight.shape[0], - qweight.shape[2] // 16, -1) + qweight_reshaped = qweight_repacked.reshape( + qweight.shape[0], qweight.shape[2] // 16, -1 + ) if not batch_present: qweight_reshaped = qweight_reshaped.squeeze(0) scale_packed = scale_packed.squeeze(0) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 64c4b05c4e1a..b8f110e6620f 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -532,40 +532,45 @@ def apply_awq_marlin_linear( def apply_rtn_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT +) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition, ) - - use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), - n=output_size_per_partition, - k=reshaped_x.size(1), - device=input.device, - dtype=input.dtype) - - output = ops.gptq_marlin_gemm(reshaped_x, - None, - weight, - bias, - weight_scale, - None, - None, - None, - None, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - use_atomic_add=use_atomic_add, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + use_atomic_add = should_use_atomic_add_reduce( + m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype + ) + + output = ops.gptq_marlin_gemm( + reshaped_x, + None, + weight, + bias, + weight_scale, + None, + None, + None, + None, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) return output.reshape(out_shape) From 07b0b195eb0496cd8f9a8a6530dd694f1914e131 Mon Sep 17 00:00:00 2001 From: Alex Kogan Date: Wed, 8 Oct 2025 10:12:31 -0400 Subject: [PATCH 4/5] fix style errors flagged by pre-commit Signed-off-by: Alex Kogan --- vllm/model_executor/layers/quantization/rtn.py | 10 +++++----- .../layers/quantization/utils/marlin_utils.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git 
a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index d7ce7de23b64..421bb0f86179 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -228,7 +228,7 @@ def create_weights( layer.output_size_per_partition = output_size_per_partition def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """ Repack weights and scales for Marlin kernels.""" + """Repack weights and scales for Marlin kernels.""" weight_bits = self.quant_config.weight_bits weight, scale = repack_weights(layer.weight, layer.scale, weight_bits) @@ -242,7 +242,7 @@ def apply( self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor] = None + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return apply_rtn_marlin_linear( input=x, @@ -330,7 +330,7 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """ Repack weights and scales for Marlin kernels.""" + """Repack weights and scales for Marlin kernels.""" weight_bits = self.quant_config.weight_bits w13_weight, w13_scale = repack_weights( @@ -530,7 +530,7 @@ def _get_perms(): 2 * (i % 4), 2 * (i % 4) + 1, 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1 + 2 * (i % 4 + 4) + 1, ]: perm1.append(16 * row + col + 8 * block) for j in range(4): @@ -627,7 +627,7 @@ def repack_weights(qweight, scale, weight_bits): ) for i in range(2): qweight_unpacked[:, :, i::2] = ((qweight << 4 * (1 - i)) >> 4).reshape( - qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2 + qweight.shape[0], qweight.shape[1] * 2, qweight.shape[2] // 2 ) else: qweight_unpacked = qweight diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index b8f110e6620f..55d433e1ba4a 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -540,7 +540,7 @@ def apply_rtn_marlin_linear( output_size_per_partition: int, input_size_per_partition: int, bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition,) @@ -550,7 +550,7 @@ def apply_rtn_marlin_linear( n=output_size_per_partition, k=reshaped_x.size(1), device=input.device, - dtype=input.dtype + dtype=input.dtype, ) output = ops.gptq_marlin_gemm( From b99ebd25d1a7382ddc7af4acc38874eda649de9a Mon Sep 17 00:00:00 2001 From: Alex Kogan Date: Mon, 13 Oct 2025 11:19:18 -0400 Subject: [PATCH 5/5] fix pre-commit error Signed-off-by: Alex Kogan --- vllm/model_executor/layers/quantization/utils/marlin_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 83e94146be0f..071fb4ba1686 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -538,7 +538,7 @@ def apply_rtn_marlin_linear( quant_type: ScalarType, output_size_per_partition: int, input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, + bias: torch.Tensor | None = None, use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, ) -> torch.Tensor: reshaped_x = input.reshape(-1, 
input.shape[-1])
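
For reference, the repacking path added in patch 1 performs two little-endian packing steps: pack_for_marlin folds pairs of 4-bit values into bytes (in the 4-bit case), and repack_8bit_into_32bit then folds four consecutive bytes into each int32 word, since the Marlin GEMM kernels expect weights as int32-packed tiles. The standalone sketch below is not taken from the patch; all names are local to the example, and the tile/thread permutations produced by _get_perms are deliberately omitted. It only reproduces the bit layout implied by the shift patterns in pack_for_marlin and repack_8bit_into_32bit, so the expected word format can be checked independently:

    # Illustrative sketch only -- names are local to this example.
    import torch

    def pack_nibbles(vals: torch.Tensor) -> torch.Tensor:
        # vals: uint8 tensor of 4-bit values (0..15), last dim even.
        # Element 2*i goes to the low nibble, 2*i+1 to the high nibble,
        # mirroring `q |= res[:, :, i::2] << 4 * i` in pack_for_marlin.
        out = torch.zeros((*vals.shape[:-1], vals.shape[-1] // 2),
                          dtype=torch.uint8)
        for i in range(2):
            out |= vals[..., i::2] << (4 * i)
        return out

    def pack_bytes_into_int32(b: torch.Tensor) -> torch.Tensor:
        # b: uint8 tensor, last dim divisible by 4.
        # Byte j of each group of four lands in bits [8*j, 8*j + 8),
        # mirroring the loop in repack_8bit_into_32bit.
        out = torch.zeros((*b.shape[:-1], b.shape[-1] // 4),
                          dtype=torch.int32)
        for j in range(4):
            out |= (b[..., j::4] & 0xFF).to(torch.int32) << (8 * j)
        return out

    nibbles = torch.tensor([1, 2, 3, 4, 5, 6, 7, 0], dtype=torch.uint8)
    word = pack_bytes_into_int32(pack_nibbles(nibbles))
    print(hex(word.item()))  # 0x7654321: the i-th value sits in bits [4*i, 4*i + 4)

Under these assumptions, the i-th 4-bit input value occupies bits [4*i, 4*i + 4) of the packed 32-bit word, which matches the shift pattern used in the patch before the Marlin permutations are applied.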