[Feature] Add support for MoE models in the calibration-free RTN-based quantization #20766
Changes from all commits
@@ -3,18 +3,19 @@
 # Copyright © 2025, Oracle and/or its affiliates.
 
 import os
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 
 logger = init_logger(__name__)
 """By default, use 8 bit as target precision, but it can be
@@ -71,9 +72,11 @@ def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
         return cls(weight_bits, group_size)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["RTNLinearMethod"]:
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
             return RTNLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return RTNMoEMethod(self)
         return None
 
 
@@ -94,11 +97,18 @@ def narrow(self, dim, start, length):
             self.data.narrow(dim, start // factor, length // factor),
             self.scale.narrow(dim, start, length), self.quant_config)
 
+    def __getitem__(self, key):
+        return RTNTensor(self.data[key], self.scale[key], self.quant_config)
+
     @property
     def shape(self):
         shape = self.data.shape
         factor = 1 if self.quant_config.weight_bits == 8 else 2
-        return torch.Size((shape[0] * factor, shape[1]))
+        batch_present = len(shape) == 3
+        if batch_present:
+            return torch.Size((shape[0], shape[1] * factor, shape[2]))
+        else:
+            return torch.Size((shape[0] * factor, shape[1]))
 
     def copy_(self, loaded_weight: torch.Tensor) -> None:
         qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
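The new 3-D branch of the `shape` property only changes how the logical (unpacked) shape is derived from the packed storage. A minimal sketch of that arithmetic, using made-up sizes that are not taken from this PR:

import torch

# Packed 4-bit storage for a stack of experts: two values per byte along dim 1.
packed = torch.empty(8, 1024, 4096, dtype=torch.uint8)  # [num_experts, out // 2, in]
factor = 2  # weight_bits == 4
logical = torch.Size((packed.shape[0], packed.shape[1] * factor, packed.shape[2]))
print(logical)  # torch.Size([8, 2048, 4096])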
@@ -165,7 +175,7 @@ def create_weights(
         weight = RTNParameter(data=torch.empty(output_size_per_partition //
                                                factor,
                                                input_size_per_partition,
-                                               dtype=torch.int8),
+                                               dtype=torch.uint8),
                               scale=scale,
                               quant_config=self.quant_config)
 
@@ -180,18 +190,7 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        """torch.compile does not know how to deal with a Parameter subclass
-        (aka RTNParameter). As we don't really need RTNParameters for the
-        forward pass, we replace them with equivalent instances of Parameters.
-        """
-        old_weight = layer.weight
-        assert isinstance(old_weight, RTNParameter)
-        data = old_weight.data.data
-
-        delattr(layer, "weight")
-
-        new_weight = Parameter(data=data, requires_grad=False)
-        layer.register_parameter("weight", new_weight)
+        fix_weights(layer, "weight")
 
     def apply(self,
               layer: torch.nn.Module,
@@ -209,6 +208,128 @@ def apply(self,
         return out
 
 
+class RTNMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: RTNConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        factor = 1 if self.quant_config.weight_bits == 8 else 2
+
+        # Fused gate_up_proj (column parallel)
+        num_groups_per_col = (hidden_size // self.quant_config.group_size
+                              if self.quant_config.group_size != -1 else 1)
+        w13_scale = Parameter(
+            torch.empty(num_experts,
+                        2 * intermediate_size_per_partition,
+                        num_groups_per_col,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w13_weight = RTNParameter(data=torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition // factor,
+            hidden_size,
+            dtype=torch.uint8),
+                                  scale=w13_scale,
+                                  quant_config=self.quant_config)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        num_groups_per_col = (intermediate_size_per_partition //
+                              self.quant_config.group_size
+                              if self.quant_config.group_size != -1 else 1)
+        w2_scale = Parameter(torch.zeros(num_experts,
+                                         hidden_size,
+                                         num_groups_per_col,
+                                         dtype=params_dtype),
+                             requires_grad=False)
+        layer.register_parameter("w2_scale", w2_scale)
+
+        w2_weight = RTNParameter(data=torch.empty(
+            num_experts,
+            hidden_size // factor,
+            intermediate_size_per_partition,
+            dtype=torch.uint8),
+                                 scale=w2_scale,
+                                 quant_config=self.quant_config)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight_bits = self.quant_config.weight_bits
+        fix_weights(layer, "w13_weight", weight_bits == 4)
+        fix_weights(layer, "w2_weight", weight_bits == 4)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if enable_eplb:
+            raise NotImplementedError(
+                "EPLB not supported for `RTNMoEMethod` yet.")
+
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        weight_bits = self.quant_config.weight_bits
+        group_size = self.quant_config.group_size
+
+        ret = fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            use_int4_w4a16=weight_bits == 4,
+            use_int8_w8a16=weight_bits == 8,
+            global_num_experts=global_num_experts,
+            w1_scale=layer.w13_scale,
+            w2_scale=layer.w2_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            block_shape=[0, group_size])
+
+        return ret
+
+
 def rtn_quantize(tensor: torch.Tensor, num_bits: int,
                  group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize a tensor using per-group static scaling factor.
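A quick way to sanity-check the buffer layout that `RTNMoEMethod.create_weights` registers: for 4-bit weights the packed weight tensors are half-sized along the output dimension, while the scales keep one entry per quantization group. The following sketch only reproduces that shape arithmetic with arbitrary sizes; it is illustrative and not code from the PR.

import torch

num_experts, hidden_size, intermediate = 8, 1024, 4096
weight_bits, group_size = 4, 128
factor = 1 if weight_bits == 8 else 2
num_groups_per_col = hidden_size // group_size

# Mirrors the w13 (fused gate_up_proj) buffers created above.
w13_scale = torch.empty(num_experts, 2 * intermediate, num_groups_per_col)
w13_weight = torch.empty(num_experts, 2 * intermediate // factor, hidden_size,
                         dtype=torch.uint8)
print(w13_scale.shape)   # torch.Size([8, 8192, 8])
print(w13_weight.shape)  # torch.Size([8, 4096, 1024])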
@@ -221,34 +342,44 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int,
             If equal to -1, each row in the input tensor is treated
             as one group.
     """
+    batch_present = len(tensor.shape) == 3
+    if not batch_present:
+        tensor = tensor.unsqueeze(0)
+
     q_range = 2**num_bits
-    num_groups = (tensor.shape[0] * tensor.shape[1] //
-                  group_size if group_size != -1 else tensor.shape[0])
+    num_groups = (tensor.shape[1] * tensor.shape[2] //
+                  group_size if group_size != -1 else tensor.shape[1])
     """Calculate a scaling factor per input group.
     """
-    input_flat = tensor.reshape(num_groups, -1)
-    input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
-    input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
+    input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
+    input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
+    input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
     input_max_abs = torch.max(input_min.abs(), input_max.abs())
     scale = (input_max_abs * 2.0 / (q_range - 1))
-    """Scale each input group, truncate and round to the nearest integer.
+    """Scale each input group, round to the nearest integer, shift
+    the range and truncate.
     """
     scaled_input = input_flat / scale
-    scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
     scaled_input = scaled_input.round()
+    scaled_input += q_range // 2
+    scaled_input = scaled_input.clamp(0, q_range - 1)
 
-    scale = scale.reshape(tensor.shape[0], -1).contiguous()
-    inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
+    scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
+    inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8)
     inputs_q = inputs_q.contiguous()
 
     if num_bits == 4:
         """Pack two 4-bit values into each byte.
         """
-        inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
-        inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
+        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xf)
+        inputs_q = inputs_q.reshape(tensor.shape[0], tensor.shape[1] // 2,
+                                    tensor.shape[2])
         inputs_q = inputs_q.contiguous()
 
+    if not batch_present:
+        inputs_q = inputs_q.squeeze(0)
+        scale = scale.squeeze(0)
+
     return inputs_q, scale
 
 
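For readers who want to see the batched path in action, here is a rough usage sketch of `rtn_quantize` on a stacked per-expert weight tensor. Sizes are arbitrary, and the function is assumed to be imported from this module.

import torch

num_experts, out_features, in_features = 4, 64, 128
w = torch.randn(num_experts, out_features, in_features)

# 4-bit mode packs two values per byte along the output dimension (dim 1).
qweight, scale = rtn_quantize(w, num_bits=4, group_size=64)
print(qweight.dtype)  # torch.uint8
print(qweight.shape)  # torch.Size([4, 32, 128]), out_features halved by packing
print(scale.shape)    # torch.Size([4, 64, 2]), one scale per 64-wide group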
@@ -259,31 +390,60 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         tensor: The input tensor.
         scale: The tensor with per-group scale factors.
     """
+    batch_present = len(tensor.shape) == 3
+    if not batch_present:
+        tensor = tensor.unsqueeze(0)
+        scale = scale.unsqueeze(0)
+
-    num_groups = scale.size(0) * scale.size(1)
-    input_dim, output_dim = tensor.shape
+    num_groups = scale.size(1) * scale.size(2)
+    batch, input_dim, output_dim = tensor.shape
 
-    num_bits = 8 if input_dim == scale.size(0) else 4
+    num_bits = 8 if input_dim == scale.size(1) else 4
     q_range = 2**num_bits
     if num_bits == 4:
         input_dim *= 2
 
-    data = torch.empty((input_dim, output_dim),
+    data = torch.empty((batch, input_dim, output_dim),
                        dtype=scale.dtype,
                        device=tensor.device)
 
     if num_bits == 8:
         data.copy_(tensor)
+        data -= q_range // 2
Comment on lines 411 to 412: That's incorrect. The shift (subtraction) happens also when …
     else:
         """Unpack two 4-bit values from each byte.
         """
-        tensor = tensor.reshape(input_dim, output_dim // 2)
+        tensor = tensor.reshape(batch, input_dim, output_dim // 2)
         for i in range(2):
-            data[:, i::2] = (tensor << 4 * (1 - i)) >> 4
+            data[:, :, i::2] = ((tensor << 4 *
+                                 (1 - i)) >> 4).to(torch.int8) - q_range // 2
     """Scale each input group with its scaling factor.
     """
-    scale = scale.reshape(num_groups, -1)
-    data = data.reshape(num_groups, -1)
+    scale = scale.reshape(batch, num_groups, -1)
+    data = data.reshape(batch, num_groups, -1)
     data = torch.mul(data, scale)
 
-    input_deq = data.reshape((input_dim, output_dim)).contiguous()
+    input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
+    if not batch_present:
+        input_deq = input_deq.squeeze(0)
 
     return input_deq
 
 
+def fix_weights(layer: torch.nn.Module,
+                param_name: str,
+                reshape: bool = False):
+    """torch.compile does not know how to deal with a Parameter subclass
+    (aka RTNParameter). As we don't really need RTNParameters for the
+    forward pass, we replace them with equivalent instances of Parameters.
+    """
+    old_weight = getattr(layer, param_name)
+    assert isinstance(old_weight, RTNParameter)
+    data = old_weight.data.data
+
+    delattr(layer, param_name)
+
+    if reshape:
+        data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
+
+    new_weight = Parameter(data=data, requires_grad=False)
+    layer.register_parameter(param_name, new_weight)
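Finally, a rough round-trip sketch (not part of the PR) showing how `rtn_quantize` and `rtn_dequantize` are meant to compose for a plain 2-D weight; the reconstruction error stays on the order of one quantization step per group. Both functions are assumed to be imported from this module.

import torch

w = torch.randn(256, 512)
qweight, scale = rtn_quantize(w, num_bits=8, group_size=128)
w_deq = rtn_dequantize(qweight, scale)
print((w - w_deq).abs().max())  # small for 8-bit; grows for num_bits=4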