
Commit d8ed53f

sakogan authored and paulpak58 committed
[Feature] Add support for MoE models in the calibration-free RTN-based quantization (vllm-project#20766)
Signed-off-by: Alex Kogan <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
1 parent f8f84fd · commit d8ed53f

File tree

2 files changed: +201 -38

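Before the diffs, a minimal usage sketch for orientation. This is not part of the commit; it assumes a vLLM build containing this change, and that on-the-fly RTN quantization of an MoE checkpoint is selected the same way as for dense models, via quantization="rtn" (the model name comes from the updated test below):

    from vllm import LLM

    # RTN is calibration-free: weights are rounded to the nearest quantized
    # value at load time, so no calibration dataset is involved.
    llm = LLM(model="ai21labs/Jamba-tiny-dev", quantization="rtn")
    outputs = llm.generate("The capital of France is")
    print(outputs[0].outputs[0].text)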

tests/quantization/test_rtn.py

Lines changed: 4 additions & 1 deletion
@@ -8,7 +8,10 @@
 
 from tests.quantization.utils import is_quant_method_supported
 
-MODELS = ["microsoft/Phi-3-mini-4k-instruct"]
+MODELS = [
+    "microsoft/Phi-3-mini-4k-instruct",  # dense model
+    "ai21labs/Jamba-tiny-dev",  # MoE model
+]
 
 
 @pytest.mark.skipif(not is_quant_method_supported("rtn"),

vllm/model_executor/layers/quantization/rtn.py

Lines changed: 197 additions & 37 deletions
@@ -3,18 +3,19 @@
 # Copyright © 2025, Oracle and/or its affiliates.
 
 import os
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 
 logger = init_logger(__name__)
 """By default, use 8 bit as target precision, but it can be
@@ -71,9 +72,11 @@ def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
         return cls(weight_bits, group_size)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["RTNLinearMethod"]:
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
             return RTNLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return RTNMoEMethod(self)
         return None
 
 
@@ -94,11 +97,18 @@ def narrow(self, dim, start, length):
             self.data.narrow(dim, start // factor, length // factor),
             self.scale.narrow(dim, start, length), self.quant_config)
 
+    def __getitem__(self, key):
+        return RTNTensor(self.data[key], self.scale[key], self.quant_config)
+
     @property
     def shape(self):
         shape = self.data.shape
         factor = 1 if self.quant_config.weight_bits == 8 else 2
-        return torch.Size((shape[0] * factor, shape[1]))
+        batch_present = len(shape) == 3
+        if batch_present:
+            return torch.Size((shape[0], shape[1] * factor, shape[2]))
+        else:
+            return torch.Size((shape[0] * factor, shape[1]))
 
     def copy_(self, loaded_weight: torch.Tensor) -> None:
         qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
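The new 3-D branch in shape exists because MoE weights carry a leading expert dimension, and for 4-bit weights two values share one byte, so the stored tensor is half the logical size along one dimension. A small standalone sketch of the nibble packing (toy values, mirroring the pack/unpack code further down in this file; not part of the commit):

    import torch

    # Pack pairs of 4-bit codes along the last dim: odd indices go to the
    # high nibble, even indices to the low nibble.
    q = torch.tensor([[1, 2, 3, 4]], dtype=torch.uint8)
    packed = (q[:, 1::2] << 4) | (q[:, ::2] & 0xF)   # tensor([[33, 67]])
    low = packed & 0xF                               # recovers [[1, 3]]
    high = packed >> 4                               # recovers [[2, 4]]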
@@ -165,7 +175,7 @@ def create_weights(
         weight = RTNParameter(data=torch.empty(output_size_per_partition //
                                                factor,
                                                input_size_per_partition,
-                                               dtype=torch.int8),
+                                               dtype=torch.uint8),
                               scale=scale,
                               quant_config=self.quant_config)
 
@@ -180,18 +190,7 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        """torch.compile does not know how to deal with a Parameter subclass
-        (aka RTNParameter). As we don't really need RTNParameters for the
-        forward pass, we replace them with equivalent instances of Parameters.
-        """
-        old_weight = layer.weight
-        assert isinstance(old_weight, RTNParameter)
-        data = old_weight.data.data
-
-        delattr(layer, "weight")
-
-        new_weight = Parameter(data=data, requires_grad=False)
-        layer.register_parameter("weight", new_weight)
+        fix_weights(layer, "weight")
 
     def apply(self,
               layer: torch.nn.Module,
@@ -209,6 +208,128 @@ def apply(self,
         return out
 
 
+class RTNMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: RTNConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        factor = 1 if self.quant_config.weight_bits == 8 else 2
+
+        # Fused gate_up_proj (column parallel)
+        num_groups_per_col = (hidden_size // self.quant_config.group_size
+                              if self.quant_config.group_size != -1 else 1)
+        w13_scale = Parameter(
+            torch.empty(num_experts,
+                        2 * intermediate_size_per_partition,
+                        num_groups_per_col,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w13_weight = RTNParameter(data=torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition // factor,
+            hidden_size,
+            dtype=torch.uint8),
+                                  scale=w13_scale,
+                                  quant_config=self.quant_config)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        num_groups_per_col = (intermediate_size_per_partition //
+                              self.quant_config.group_size
+                              if self.quant_config.group_size != -1 else 1)
+        w2_scale = Parameter(torch.zeros(num_experts,
+                                         hidden_size,
+                                         num_groups_per_col,
+                                         dtype=params_dtype),
+                             requires_grad=False)
+        layer.register_parameter("w2_scale", w2_scale)
+
+        w2_weight = RTNParameter(data=torch.empty(
+            num_experts,
+            hidden_size // factor,
+            intermediate_size_per_partition,
+            dtype=torch.uint8),
+                                 scale=w2_scale,
+                                 quant_config=self.quant_config)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight_bits = self.quant_config.weight_bits
+        fix_weights(layer, "w13_weight", weight_bits == 4)
+        fix_weights(layer, "w2_weight", weight_bits == 4)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if enable_eplb:
+            raise NotImplementedError(
+                "EPLB not supported for `RTNMoEMethod` yet.")
+
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        weight_bits = self.quant_config.weight_bits
+        group_size = self.quant_config.group_size
+
+        ret = fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            use_int4_w4a16=weight_bits == 4,
+            use_int8_w8a16=weight_bits == 8,
+            global_num_experts=global_num_experts,
+            w1_scale=layer.w13_scale,
+            w2_scale=layer.w2_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            block_shape=[0, group_size])
+
+        return ret
+
+
 def rtn_quantize(tensor: torch.Tensor, num_bits: int,
                  group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize a tensor using per-group static scaling factor.
@@ -221,34 +342,44 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int,
         If equal to -1, each row in the input tensor is treated
         as one group.
     """
+    batch_present = len(tensor.shape) == 3
+    if not batch_present:
+        tensor = tensor.unsqueeze(0)
 
     q_range = 2**num_bits
-    num_groups = (tensor.shape[0] * tensor.shape[1] //
-                  group_size if group_size != -1 else tensor.shape[0])
+    num_groups = (tensor.shape[1] * tensor.shape[2] //
+                  group_size if group_size != -1 else tensor.shape[1])
     """Calculate a scaling factor per input group.
     """
-    input_flat = tensor.reshape(num_groups, -1)
-    input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
-    input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
+    input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
+    input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
+    input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
     input_max_abs = torch.max(input_min.abs(), input_max.abs())
     scale = (input_max_abs * 2.0 / (q_range - 1))
-    """Scale each input group, truncate and round to the nearest integer.
+    """Scale each input group, round to the nearest integer, shift
+    the range and truncate.
     """
     scaled_input = input_flat / scale
-    scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
     scaled_input = scaled_input.round()
+    scaled_input += q_range // 2
+    scaled_input = scaled_input.clamp(0, q_range - 1)
 
-    scale = scale.reshape(tensor.shape[0], -1).contiguous()
-    inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
+    scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
+    inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8)
     inputs_q = inputs_q.contiguous()
 
     if num_bits == 4:
         """Pack two 4-bit values into each byte.
         """
-        inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
-        inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
+        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xf)
+        inputs_q = inputs_q.reshape(tensor.shape[0], tensor.shape[1] // 2,
+                                    tensor.shape[2])
         inputs_q = inputs_q.contiguous()
 
+    if not batch_present:
+        inputs_q = inputs_q.squeeze(0)
+        scale = scale.squeeze(0)
+
     return inputs_q, scale
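The hunk above replaces the signed clamp-then-round with round, shift to an unsigned range, then clamp, matching the move from int8 to uint8 storage. A worked toy example of the same arithmetic for num_bits=4, a single group (illustration only, not vLLM code):

    import torch

    x = torch.tensor([-1.0, -0.25, 0.5, 1.0])
    q_range = 2 ** 4                               # 16 levels
    scale = x.abs().max() * 2.0 / (q_range - 1)    # symmetric scale from |max|
    q = (x / scale).round() + q_range // 2         # shift signed -> unsigned
    q = q.clamp(0, q_range - 1).to(torch.uint8)    # codes in [0, 15]
    x_hat = (q.float() - q_range // 2) * scale     # dequantize
    # max |x - x_hat| is about scale / 2 (plus clipping at the range edge)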

@@ -259,31 +390,60 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
259390
tensor: The input tensor.
260391
scale: The tensor with per-group scale factors.
261392
"""
393+
batch_present = len(tensor.shape) == 3
394+
if not batch_present:
395+
tensor = tensor.unsqueeze(0)
396+
scale = scale.unsqueeze(0)
262397

263-
num_groups = scale.size(0) * scale.size(1)
264-
input_dim, output_dim = tensor.shape
398+
num_groups = scale.size(1) * scale.size(2)
399+
batch, input_dim, output_dim = tensor.shape
265400

266-
num_bits = 8 if input_dim == scale.size(0) else 4
401+
num_bits = 8 if input_dim == scale.size(1) else 4
402+
q_range = 2**num_bits
267403
if num_bits == 4:
268404
input_dim *= 2
269405

270-
data = torch.empty((input_dim, output_dim),
406+
data = torch.empty((batch, input_dim, output_dim),
271407
dtype=scale.dtype,
272408
device=tensor.device)
273409

274410
if num_bits == 8:
275411
data.copy_(tensor)
412+
data -= q_range // 2
276413
else:
277414
"""Unpack two 4-bit values from each byte.
278415
"""
279-
tensor = tensor.reshape(input_dim, output_dim // 2)
416+
tensor = tensor.reshape(batch, input_dim, output_dim // 2)
280417
for i in range(2):
281-
data[:, i::2] = (tensor << 4 * (1 - i)) >> 4
418+
data[:, :, i::2] = ((tensor << 4 *
419+
(1 - i)) >> 4).to(torch.int8) - q_range // 2
282420
"""Scale each input group with its scaling factor.
283421
"""
284-
scale = scale.reshape(num_groups, -1)
285-
data = data.reshape(num_groups, -1)
422+
scale = scale.reshape(batch, num_groups, -1)
423+
data = data.reshape(batch, num_groups, -1)
286424
data = torch.mul(data, scale)
287425

288-
input_deq = data.reshape((input_dim, output_dim)).contiguous()
426+
input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
427+
if not batch_present:
428+
input_deq = input_deq.squeeze(0)
429+
289430
return input_deq
431+
432+
433+
def fix_weights(layer: torch.nn.Module,
434+
param_name: str,
435+
reshape: bool = False):
436+
"""torch.compile does not know how to deal with a Parameter subclass
437+
(aka RTNParameter). As we don't really need RTNParameters for the
438+
forward pass, we replace them with equivalent instances of Parameters.
439+
"""
440+
old_weight = getattr(layer, param_name)
441+
assert isinstance(old_weight, RTNParameter)
442+
data = old_weight.data.data
443+
444+
delattr(layer, param_name)
445+
446+
if reshape:
447+
data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
448+
new_weight = Parameter(data=data, requires_grad=False)
449+
layer.register_parameter(param_name, new_weight)
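A round-trip check of the two module-level helpers, runnable on CPU assuming a build containing this commit (a 2-D input exercises the unsqueeze/squeeze path added above; group_size=-1 gives one group per output row):

    import torch
    from vllm.model_executor.layers.quantization.rtn import (rtn_quantize,
                                                             rtn_dequantize)

    w = torch.randn(128, 256)                        # toy 2-D weight
    qw, scale = rtn_quantize(w, num_bits=8, group_size=-1)
    print(qw.dtype, qw.shape, scale.shape)           # torch.uint8 (128, 256) (128, 1)
    w_hat = rtn_dequantize(qw, scale)
    print((w - w_hat).abs().max())                   # bounded by ~half a step per row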
