 # Copyright © 2025, Oracle and/or its affiliates.

 import os
-from typing import Any, Optional
+from typing import Any, Callable, Optional

 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter

 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)

 logger = init_logger(__name__)
 """By default, use 8 bit as target precision, but it can be
@@ -71,9 +72,11 @@ def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
         return cls(weight_bits, group_size)

     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["RTNLinearMethod"]:
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
             return RTNLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return RTNMoEMethod(self)
         return None


@@ -94,11 +97,18 @@ def narrow(self, dim, start, length):
             self.data.narrow(dim, start // factor, length // factor),
             self.scale.narrow(dim, start, length), self.quant_config)

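+    # Indexing returns a sub-view over both the packed data and the scales,
+    # presumably so MoE weight loaders can slice out a single expert's shard
+    # before copy_() quantizes into it.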
+    def __getitem__(self, key):
+        return RTNTensor(self.data[key], self.scale[key], self.quant_config)
+
     @property
     def shape(self):
         shape = self.data.shape
         factor = 1 if self.quant_config.weight_bits == 8 else 2
-        return torch.Size((shape[0] * factor, shape[1]))
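+        # A 3-D shape means a stacked per-expert MoE weight; in that case the
+        # int4 packing (factor == 2) is undone along dim 1 rather than dim 0.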
+        batch_present = len(shape) == 3
+        if batch_present:
+            return torch.Size((shape[0], shape[1] * factor, shape[2]))
+        else:
+            return torch.Size((shape[0] * factor, shape[1]))

     def copy_(self, loaded_weight: torch.Tensor) -> None:
         qweight, weight_scale = rtn_quantize(loaded_weight.cuda(),
@@ -165,7 +175,7 @@ def create_weights(
         weight = RTNParameter(data=torch.empty(output_size_per_partition //
                                                factor,
                                                input_size_per_partition,
-                                               dtype=torch.int8),
+                                               dtype=torch.uint8),
                               scale=scale,
                               quant_config=self.quant_config)

@@ -180,18 +190,7 @@ def create_weights(
         layer.output_size_per_partition = output_size_per_partition

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        """torch.compile does not know how to deal with a Parameter subclass
-        (aka RTNParameter). As we don't really need RTNParameters for the
-        forward pass, we replace them with equivalent instances of Parameters.
-        """
-        old_weight = layer.weight
-        assert isinstance(old_weight, RTNParameter)
-        data = old_weight.data.data
-
-        delattr(layer, "weight")
-
-        new_weight = Parameter(data=data, requires_grad=False)
-        layer.register_parameter("weight", new_weight)
+        fix_weights(layer, "weight")

     def apply(self,
               layer: torch.nn.Module,
@@ -209,6 +208,128 @@ def apply(self,
         return out


+class RTNMoEMethod(FusedMoEMethodBase):
+
+    def __init__(self, quant_config: RTNConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
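+        # With 4-bit weights two values are packed per byte, so each packed
+        # dimension below is stored at half its logical size (factor == 2).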
+        factor = 1 if self.quant_config.weight_bits == 8 else 2
+
+        # Fused gate_up_proj (column parallel)
+        num_groups_per_col = (hidden_size // self.quant_config.group_size
+                              if self.quant_config.group_size != -1 else 1)
+        w13_scale = Parameter(
+            torch.empty(num_experts,
+                        2 * intermediate_size_per_partition,
+                        num_groups_per_col,
+                        dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_scale", w13_scale)
+
+        w13_weight = RTNParameter(data=torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition // factor,
+            hidden_size,
+            dtype=torch.uint8),
+                                  scale=w13_scale,
+                                  quant_config=self.quant_config)
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        num_groups_per_col = (intermediate_size_per_partition //
+                              self.quant_config.group_size
+                              if self.quant_config.group_size != -1 else 1)
+        w2_scale = Parameter(torch.zeros(num_experts,
+                                         hidden_size,
+                                         num_groups_per_col,
+                                         dtype=params_dtype),
+                             requires_grad=False)
+        layer.register_parameter("w2_scale", w2_scale)
+
+        w2_weight = RTNParameter(data=torch.empty(
+            num_experts,
+            hidden_size // factor,
+            intermediate_size_per_partition,
+            dtype=torch.uint8),
+                                 scale=w2_scale,
+                                 quant_config=self.quant_config)
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        weight_bits = self.quant_config.weight_bits
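+        # For 4-bit weights, fix_weights also reshapes the packed buffer into
+        # the layout the int4 fused-MoE kernel expects (see fix_weights below).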
+        fix_weights(layer, "w13_weight", weight_bits == 4)
+        fix_weights(layer, "w2_weight", weight_bits == 4)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if enable_eplb:
+            raise NotImplementedError(
+                "EPLB not supported for `RTNMoEMethod` yet.")
+
+        from vllm.model_executor.layers.fused_moe import fused_experts
+
+        topk_weights, topk_ids = FusedMoE.select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
+
+        weight_bits = self.quant_config.weight_bits
+        group_size = self.quant_config.group_size
+
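+        # Weights stay packed here: weight_bits picks the weight-only int4 or
+        # int8 kernel path, and the RTN group size travels via block_shape.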
+        ret = fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            use_int4_w4a16=weight_bits == 4,
+            use_int8_w8a16=weight_bits == 8,
+            global_num_experts=global_num_experts,
+            w1_scale=layer.w13_scale,
+            w2_scale=layer.w2_scale,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            block_shape=[0, group_size])
+
+        return ret
+
+
 def rtn_quantize(tensor: torch.Tensor, num_bits: int,
                  group_size: int) -> tuple[torch.Tensor, torch.Tensor]:
     """Quantize a tensor using per-group static scaling factor.
@@ -221,34 +342,44 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int,
             If equal to -1, each row in the input tensor is treated
             as one group.
     """
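+    # Worked example (sketch): with num_bits=8 and group_size=128, a
+    # (256, 1024) tensor forms 256 * 1024 / 128 = 2048 groups and yields a
+    # uint8 qweight of shape (256, 1024) plus a (256, 8) scale tensor.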
+    batch_present = len(tensor.shape) == 3
+    if not batch_present:
+        tensor = tensor.unsqueeze(0)

     q_range = 2**num_bits
-    num_groups = (tensor.shape[0] * tensor.shape[1] //
-                  group_size if group_size != -1 else tensor.shape[0])
+    num_groups = (tensor.shape[1] * tensor.shape[2] //
+                  group_size if group_size != -1 else tensor.shape[1])
     """Calculate a scaling factor per input group.
     """
-    input_flat = tensor.reshape(num_groups, -1)
-    input_min = torch.min(input_flat, dim=1, keepdim=True)[0]
-    input_max = torch.max(input_flat, dim=1, keepdim=True)[0]
+    input_flat = tensor.reshape(tensor.shape[0], num_groups, -1)
+    input_min = torch.min(input_flat, dim=2, keepdim=True)[0]
+    input_max = torch.max(input_flat, dim=2, keepdim=True)[0]
     input_max_abs = torch.max(input_min.abs(), input_max.abs())
     scale = (input_max_abs * 2.0 / (q_range - 1))
-    """Scale each input group, truncate and round to the nearest integer.
+    """Scale each input group, round to the nearest integer, shift
+    the range and truncate.
     """
     scaled_input = input_flat / scale
-    scaled_input = scaled_input.clamp(-q_range // 2, q_range // 2 - 1)
     scaled_input = scaled_input.round()
+    scaled_input += q_range // 2
+    scaled_input = scaled_input.clamp(0, q_range - 1)

-    scale = scale.reshape(tensor.shape[0], -1).contiguous()
-    inputs_q = scaled_input.reshape(tensor.shape).to(torch.int8)
+    scale = scale.reshape(tensor.shape[0], tensor.shape[1], -1).contiguous()
+    inputs_q = scaled_input.reshape(tensor.shape).to(torch.uint8)
     inputs_q = inputs_q.contiguous()

     if num_bits == 4:
         """Pack two 4-bit values into each byte.
         """
-        inputs_q = (inputs_q[:, 1::2] << 4) | (inputs_q[:, ::2] & 0xf)
-        inputs_q = inputs_q.reshape(tensor.shape[0] // 2, tensor.shape[1])
+        inputs_q = (inputs_q[:, :, 1::2] << 4) | (inputs_q[:, :, ::2] & 0xf)
+        inputs_q = inputs_q.reshape(tensor.shape[0], tensor.shape[1] // 2,
+                                    tensor.shape[2])
         inputs_q = inputs_q.contiguous()

+    if not batch_present:
+        inputs_q = inputs_q.squeeze(0)
+        scale = scale.squeeze(0)
+
     return inputs_q, scale


@@ -259,31 +390,60 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         tensor: The input tensor.
         scale: The tensor with per-group scale factors.
     """
+    batch_present = len(tensor.shape) == 3
+    if not batch_present:
+        tensor = tensor.unsqueeze(0)
+        scale = scale.unsqueeze(0)

-    num_groups = scale.size(0) * scale.size(1)
-    input_dim, output_dim = tensor.shape
+    num_groups = scale.size(1) * scale.size(2)
+    batch, input_dim, output_dim = tensor.shape

-    num_bits = 8 if input_dim == scale.size(0) else 4
+    num_bits = 8 if input_dim == scale.size(1) else 4
+    q_range = 2**num_bits
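+    # The bit width is inferred from the shapes: with int4 packing the stored
+    # input dim is half of the logical one tracked by the scales.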
     if num_bits == 4:
         input_dim *= 2

-    data = torch.empty((input_dim, output_dim),
+    data = torch.empty((batch, input_dim, output_dim),
                        dtype=scale.dtype,
                        device=tensor.device)

     if num_bits == 8:
         data.copy_(tensor)
+        data -= q_range // 2
     else:
         """Unpack two 4-bit values from each byte.
         """
-        tensor = tensor.reshape(input_dim, output_dim // 2)
+        tensor = tensor.reshape(batch, input_dim, output_dim // 2)
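+        # Each nibble is zero-extended into [0, 15], then re-centered by
+        # subtracting q_range // 2, undoing the unsigned shift applied in
+        # rtn_quantize.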
         for i in range(2):
-            data[:, i::2] = (tensor << 4 * (1 - i)) >> 4
+            data[:, :, i::2] = ((tensor << 4 *
+                                 (1 - i)) >> 4).to(torch.int8) - q_range // 2
     """Scale each input group with its scaling factor.
     """
-    scale = scale.reshape(num_groups, -1)
-    data = data.reshape(num_groups, -1)
+    scale = scale.reshape(batch, num_groups, -1)
+    data = data.reshape(batch, num_groups, -1)
     data = torch.mul(data, scale)

-    input_deq = data.reshape((input_dim, output_dim)).contiguous()
+    input_deq = data.reshape((batch, input_dim, output_dim)).contiguous()
+    if not batch_present:
+        input_deq = input_deq.squeeze(0)
+
     return input_deq
+
+
+def fix_weights(layer: torch.nn.Module,
+                param_name: str,
+                reshape: bool = False):
+    """torch.compile does not know how to deal with a Parameter subclass
+    (aka RTNParameter). As we don't really need RTNParameters for the
+    forward pass, we replace them with equivalent instances of Parameters.
+    """
+    old_weight = getattr(layer, param_name)
+    assert isinstance(old_weight, RTNParameter)
+    data = old_weight.data.data
+
+    delattr(layer, param_name)
+
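+    # For packed int4 weights, reinterpret (E, N // 2, K) as (E, N, K // 2);
+    # the underlying bytes are unchanged, and this appears to match the
+    # layout the int4 fused-MoE kernel consumes.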
+    if reshape:
+        data = data.reshape(old_weight.shape[0], old_weight.shape[1] * 2, -1)
+    new_weight = Parameter(data=data, requires_grad=False)
+    layer.register_parameter(param_name, new_weight)