From 0602e0c6e544c21e91cd91f47d97ce809bda7121 Mon Sep 17 00:00:00 2001 From: Alex Kogan Date: Thu, 10 Jul 2025 11:30:14 -0400 Subject: [PATCH 1/3] add support for MoE models quantized with RTN Signed-off-by: Alex Kogan --- .../model_executor/layers/quantization/rtn.py | 232 +++++++++++++++--- 1 file changed, 196 insertions(+), 36 deletions(-) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 68309716cf90..05e536720ca1 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -3,18 +3,19 @@ # Copyright © 2025, Oracle and/or its affiliates. import os -from typing import Any, Optional +from typing import Any, Callable, Optional import torch import torch.nn.functional as F from torch.nn.parameter import Parameter from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) logger = init_logger(__name__) """By default, use 8 bit as target precision, but it can be @@ -71,9 +72,11 @@ def from_config(cls, config: dict[str, Any]) -> "RTNConfig": return cls(weight_bits, group_size) def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["RTNLinearMethod"]: + prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return RTNLinearMethod(self) + elif isinstance(layer, FusedMoE): + return RTNMoEMethod(self) return None @@ -94,11 +97,18 @@ def narrow(self, dim, start, length): self.data.narrow(dim, start // factor, length // factor), self.scale.narrow(dim, start, length), self.quant_config) + def __getitem__(self, key): + return RTNTensor(self.data[key], self.scale[key], self.quant_config) + @property def shape(self): shape = self.data.shape factor = 1 if self.quant_config.weight_bits == 8 else 2 - return torch.Size((shape[0] * factor, shape[1])) + batch_present = len(shape) == 3 + if batch_present: + return torch.Size((shape[0], shape[1] * factor, shape[2])) + else: + return torch.Size((shape[0] * factor, shape[1])) def copy_(self, loaded_weight: torch.Tensor) -> None: qweight, weight_scale = rtn_quantize(loaded_weight.cuda(), @@ -165,7 +175,7 @@ def create_weights( weight = RTNParameter(data=torch.empty(output_size_per_partition // factor, input_size_per_partition, - dtype=torch.int8), + dtype=torch.uint8), scale=scale, quant_config=self.quant_config) @@ -180,18 +190,7 @@ def create_weights( layer.output_size_per_partition = output_size_per_partition def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """torch.compile does not know how to deal with a Parameter subclass - (aka RTNParameter). As we don't really need RTNParameters for the - forward pass, we replace them with equivalent instances of Parameters. 
- """ - old_weight = layer.weight - assert isinstance(old_weight, RTNParameter) - data = old_weight.data.data - - delattr(layer, "weight") - - new_weight = Parameter(data=data, requires_grad=False) - layer.register_parameter("weight", new_weight) + fix_weights(layer, "weight") def apply(self, layer: torch.nn.Module, @@ -209,6 +208,128 @@ def apply(self, return out +class RTNMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: RTNConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + factor = 1 if self.quant_config.weight_bits == 8 else 2 + + # Fused gate_up_proj (column parallel) + num_groups_per_col = (hidden_size // self.quant_config.group_size + if self.quant_config.group_size != -1 else 1) + w13_scale = Parameter( + torch.empty(num_experts, + 2 * intermediate_size_per_partition, + num_groups_per_col, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_scale", w13_scale) + + w13_weight = RTNParameter(data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition // factor, + hidden_size, + dtype=torch.uint8), + scale=w13_scale, + quant_config=self.quant_config) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + num_groups_per_col = (intermediate_size_per_partition // + self.quant_config.group_size + if self.quant_config.group_size != -1 else 1) + w2_scale = Parameter(torch.zeros(num_experts, + hidden_size, + num_groups_per_col, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scale", w2_scale) + + w2_weight = RTNParameter(data=torch.empty( + num_experts, + hidden_size // factor, + intermediate_size_per_partition, + dtype=torch.uint8), + scale=w2_scale, + quant_config=self.quant_config) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight_bits = self.quant_config.weight_bits + fix_weights(layer, "w13_weight", weight_bits == 4) + fix_weights(layer, "w2_weight", weight_bits == 4) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `RTNMoEMethod` yet.") + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + 
e_score_correction_bias=e_score_correction_bias) + + weight_bits = self.quant_config.weight_bits + group_size = self.quant_config.group_size + + ret = fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=global_num_experts, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + block_shape=[0, group_size]) + + return ret + + def rtn_quantize(tensor: torch.Tensor, num_bits: int, group_size: int) -> tuple[torch.Tensor, torch.Tensor]: """Quantize a tensor using per-group static scaling factor. @@ -221,34 +342,44 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int, If equal to -1, each row in the input tensor is treated as one group. """ + batch_present = len(tensor.shape) == 3 + if not batch_present: + tensor = tensor.unsqueeze(0) q_range = 2**num_bits - num_groups = (tensor.shape[0] * tensor.shape[1] // - group_size if group_size != -1 else tensor.shape[0]) + num_groups = (tensor.shape[1] * tensor.shape[2] // + group_size if group_size != -1 else tensor.shape[1]) """Calculate a scaling factor per input group. """ - input_flat = tensor.reshape(num_groups, -1) - input_min = torch.min(input_flat, dim=1, keepdim=True)[0] - input_max = torch.max(input_flat, dim=1, keepdim=True)[0] + input_flat = tensor.reshape(tensor.shape[0], num_groups, -1) + input_min = torch.min(input_flat, dim=2, keepdim=True)[0] + input_max = torch.max(input_flat, dim=2, keepdim=True)[0] input_max_abs = torch.max(input_min.abs(), input_max.abs()) scale = (input_max_abs * 2.0 / (q_range - 1)) """Scale each input group, truncate and round to the nearest integer. """ scaled_input = input_flat / scale - scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1) scaled_input = scaled_input.round() - scale = scale.reshape(tensor.shape[0], -1).contiguous() - inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8) + scaled_input += q_range // 2 + scaled_input = scaled_input.clamp(0, q_range - 1) + + scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous() + inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8) inputs_q = inputs_q.contiguous() if num_bits == 4: """Pack two 4-bit values into each byte. """ - inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf) - inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1]) + inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xf) + inputs_q = inputs_q.reshape(tensor.shape[0], tensor.shape[1] // 2, + tensor.shape[2]) inputs_q = inputs_q.contiguous() + if not batch_present: + inputs_q = inputs_q.squeeze(0) + scale = scale.squeeze(0) + return inputs_q, scale @@ -259,31 +390,60 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: tensor: The input tensor. scale: The tensor with per-group scale factors. 
""" + batch_present = len(tensor.shape) == 3 + if not batch_present: + tensor = tensor.unsqueeze(0) + scale = scale.unsqueeze(0) - num_groups = scale.size(0) * scale.size(1) - input_dim, output_dim = tensor.shape + num_groups = scale.size(1) * scale.size(2) + batch, input_dim, output_dim = tensor.shape - num_bits = 8 if input_dim == scale.size(0) else 4 + num_bits = 8 if input_dim == scale.size(1) else 4 + q_range = 2**num_bits if num_bits == 4: input_dim *= 2 - data = torch.empty((input_dim, output_dim), + data = torch.empty((batch, input_dim, output_dim), dtype=scale.dtype, device=tensor.device) if num_bits == 8: data.copy_(tensor) + data -= q_range // 2 else: """Unpack two 4-bit values from each byte. """ - tensor = tensor.reshape(input_dim, output_dim // 2) + tensor = tensor.reshape(batch, input_dim, output_dim // 2) for i in range(2): - data[:, i::2] = (tensor << 4 * (1 - i)) >> 4 + data[:, :, i::2] = ((tensor << 4 * + (1 - i)) >> 4).to(torch.int8) - q_range // 2 """Scale each input group with its scaling factor. """ - scale = scale.reshape(num_groups, -1) - data = data.reshape(num_groups, -1) + scale = scale.reshape(batch, num_groups, -1) + data = data.reshape(batch, num_groups, -1) data = torch.mul(data, scale) - input_deq = data.reshape((input_dim, output_dim)).contiguous() + input_deq = data.reshape((batch, input_dim, output_dim)).contiguous() + if not batch_present: + input_deq = input_deq.squeeze(0) + return input_deq + + +def fix_weights(layer: torch.nn.Module, + param_name: str, + reshape: bool = False): + """torch.compile does not know how to deal with a Parameter subclass + (aka RTNParameter). As we don't really need RTNParameters for the + forward pass, we replace them with equivalent instances of Parameters. + """ + old_weight = getattr(layer, param_name) + assert isinstance(old_weight, RTNParameter) + data = old_weight.data.data + + delattr(layer, param_name) + + if reshape: + data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1) + new_weight = Parameter(data=data, requires_grad=False) + layer.register_parameter(param_name, new_weight) From 78dcad8246f710c2ef9fe712f67894c02cb7d4d3 Mon Sep 17 00:00:00 2001 From: Alex Kogan Date: Thu, 10 Jul 2025 11:30:37 -0400 Subject: [PATCH 2/3] expand test_rtn.py to include an MoE model Signed-off-by: Alex Kogan --- tests/quantization/test_rtn.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/quantization/test_rtn.py b/tests/quantization/test_rtn.py index 133b2d9e4df6..bc2b468f97d8 100644 --- a/tests/quantization/test_rtn.py +++ b/tests/quantization/test_rtn.py @@ -8,7 +8,10 @@ from tests.quantization.utils import is_quant_method_supported -MODELS = ["microsoft/Phi-3-mini-4k-instruct"] +MODELS = [ + "microsoft/Phi-3-mini-4k-instruct", # dense model + "ai21labs/Jamba-tiny-dev", # MoE model +] @pytest.mark.skipif(not is_quant_method_supported("rtn"), From 1645804754a54377f7b3983b71d4d7b2ebbd0907 Mon Sep 17 00:00:00 2001 From: Alex Kogan Date: Sat, 19 Jul 2025 09:01:22 -0400 Subject: [PATCH 3/3] comment fix Signed-off-by: Alex Kogan --- vllm/model_executor/layers/quantization/rtn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index 05e536720ca1..cceaf9857c40 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -356,11 +356,11 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int, input_max = 
torch.max(input_flat, dim=2, keepdim=True)[0] input_max_abs = torch.max(input_min.abs(), input_max.abs()) scale = (input_max_abs * 2.0 / (q_range - 1)) - """Scale each input group, truncate and round to the nearest integer. + """Scale each input group, round to the nearest integer, shift + the range and truncate. """ scaled_input = input_flat / scale scaled_input = scaled_input.round() - scaled_input += q_range // 2 scaled_input = scaled_input.clamp(0, q_range - 1)
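
Usage sketch (not part of the patches): a minimal round-trip through the batched, per-expert
path that rtn_quantize/rtn_dequantize gain in PATCH 1/3. Only the two helper functions and
their module path come from the patch; the tensor shapes, bit width and group size below are
illustrative assumptions.

    import torch

    from vllm.model_executor.layers.quantization.rtn import (rtn_dequantize,
                                                              rtn_quantize)

    # One (num_experts, rows, cols) weight tensor, laid out the way a fused MoE
    # layer stores its expert weights. Shapes/group size are arbitrary examples.
    w = torch.randn(4, 64, 128)
    qweight, scale = rtn_quantize(w, num_bits=8, group_size=128)
    # 8-bit quantization keeps the shape and stores shifted, unsigned bytes.
    assert qweight.dtype == torch.uint8 and qweight.shape == w.shape
    w_deq = rtn_dequantize(qweight, scale)
    # The dequantized copy has the original shape and approximates w
    # up to quantization error.
    assert w_deq.shape == w.shape

The MoE coverage added in PATCH 2/3 can then be exercised with
pytest tests/quantization/test_rtn.py on a machine where the rtn method is supported.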