3
3
# Copyright © 2025, Oracle and/or its affiliates.
4
4
5
5
import os
6
- from typing import Any , Optional
6
+ from typing import Any , Callable , Optional
7
7
8
8
import torch
9
9
import torch .nn .functional as F
10
10
from torch .nn .parameter import Parameter
11
11
12
12
from vllm .logger import init_logger
13
+ from vllm .model_executor .layers .fused_moe import FusedMoE , FusedMoEMethodBase
13
14
from vllm .model_executor .layers .linear import (LinearBase , LinearMethodBase ,
14
15
set_weight_attrs )
15
16
from vllm .model_executor .layers .quantization import QuantizationMethods
16
17
from vllm .model_executor .layers .quantization .base_config import (
17
- QuantizationConfig )
18
+ QuantizationConfig , QuantizeMethodBase )
18
19
19
20
logger = init_logger (__name__ )
20
21
"""By default, use 8 bit as target precision, but it can be
@@ -71,9 +72,11 @@ def from_config(cls, config: dict[str, Any]) -> "RTNConfig":
71
72
return cls (weight_bits , group_size )
72
73
73
74
def get_quant_method (self , layer : torch .nn .Module ,
74
- prefix : str ) -> Optional ["RTNLinearMethod " ]:
75
+ prefix : str ) -> Optional ["QuantizeMethodBase " ]:
75
76
if isinstance (layer , LinearBase ):
76
77
return RTNLinearMethod (self )
78
+ elif isinstance (layer , FusedMoE ):
79
+ return RTNMoEMethod (self )
77
80
return None
78
81
79
82
@@ -94,11 +97,18 @@ def narrow(self, dim, start, length):
94
97
self .data .narrow (dim , start // factor , length // factor ),
95
98
self .scale .narrow (dim , start , length ), self .quant_config )
96
99
100
+ def __getitem__ (self , key ):
101
+ return RTNTensor (self .data [key ], self .scale [key ], self .quant_config )
102
+
97
103
@property
98
104
def shape (self ):
99
105
shape = self .data .shape
100
106
factor = 1 if self .quant_config .weight_bits == 8 else 2
101
- return torch .Size ((shape [0 ] * factor , shape [1 ]))
107
+ batch_present = len (shape ) == 3
108
+ if batch_present :
109
+ return torch .Size ((shape [0 ], shape [1 ] * factor , shape [2 ]))
110
+ else :
111
+ return torch .Size ((shape [0 ] * factor , shape [1 ]))
102
112
103
113
def copy_ (self , loaded_weight : torch .Tensor ) -> None :
104
114
qweight , weight_scale = rtn_quantize (loaded_weight .cuda (),
@@ -165,7 +175,7 @@ def create_weights(
165
175
weight = RTNParameter (data = torch .empty (output_size_per_partition //
166
176
factor ,
167
177
input_size_per_partition ,
168
- dtype = torch .int8 ),
178
+ dtype = torch .uint8 ),
169
179
scale = scale ,
170
180
quant_config = self .quant_config )
171
181
@@ -180,18 +190,7 @@ def create_weights(
180
190
layer .output_size_per_partition = output_size_per_partition
181
191
182
192
def process_weights_after_loading (self , layer : torch .nn .Module ) -> None :
183
- """torch.compile does not know how to deal with a Parameter subclass
184
- (aka RTNParameter). As we don't really need RTNParameters for the
185
- forward pass, we replace them with equivalent instances of Parameters.
186
- """
187
- old_weight = layer .weight
188
- assert isinstance (old_weight , RTNParameter )
189
- data = old_weight .data .data
190
-
191
- delattr (layer , "weight" )
192
-
193
- new_weight = Parameter (data = data , requires_grad = False )
194
- layer .register_parameter ("weight" , new_weight )
193
+ fix_weights (layer , "weight" )
195
194
196
195
def apply (self ,
197
196
layer : torch .nn .Module ,
@@ -209,6 +208,128 @@ def apply(self,
209
208
return out
210
209
211
210
211
+ class RTNMoEMethod (FusedMoEMethodBase ):
212
+
213
+ def __init__ (self , quant_config : RTNConfig ):
214
+ self .quant_config = quant_config
215
+
216
+ def create_weights (self , layer : torch .nn .Module , num_experts : int ,
217
+ hidden_size : int , intermediate_size_per_partition : int ,
218
+ params_dtype : torch .dtype , ** extra_weight_attrs ):
219
+
220
+ factor = 1 if self .quant_config .weight_bits == 8 else 2
221
+
222
+ # Fused gate_up_proj (column parallel)
223
+ num_groups_per_col = (hidden_size // self .quant_config .group_size
224
+ if self .quant_config .group_size != - 1 else 1 )
225
+ w13_scale = Parameter (
226
+ torch .empty (num_experts ,
227
+ 2 * intermediate_size_per_partition ,
228
+ num_groups_per_col ,
229
+ dtype = params_dtype ),
230
+ requires_grad = False ,
231
+ )
232
+ layer .register_parameter ("w13_scale" , w13_scale )
233
+
234
+ w13_weight = RTNParameter (data = torch .empty (
235
+ num_experts ,
236
+ 2 * intermediate_size_per_partition // factor ,
237
+ hidden_size ,
238
+ dtype = torch .uint8 ),
239
+ scale = w13_scale ,
240
+ quant_config = self .quant_config )
241
+ layer .register_parameter ("w13_weight" , w13_weight )
242
+ set_weight_attrs (w13_weight , extra_weight_attrs )
243
+
244
+ # down_proj (row parallel)
245
+ num_groups_per_col = (intermediate_size_per_partition //
246
+ self .quant_config .group_size
247
+ if self .quant_config .group_size != - 1 else 1 )
248
+ w2_scale = Parameter (torch .zeros (num_experts ,
249
+ hidden_size ,
250
+ num_groups_per_col ,
251
+ dtype = params_dtype ),
252
+ requires_grad = False )
253
+ layer .register_parameter ("w2_scale" , w2_scale )
254
+
255
+ w2_weight = RTNParameter (data = torch .empty (
256
+ num_experts ,
257
+ hidden_size // factor ,
258
+ intermediate_size_per_partition ,
259
+ dtype = torch .uint8 ),
260
+ scale = w2_scale ,
261
+ quant_config = self .quant_config )
262
+ layer .register_parameter ("w2_weight" , w2_weight )
263
+ set_weight_attrs (w2_weight , extra_weight_attrs )
264
+
265
+ def process_weights_after_loading (self , layer : torch .nn .Module ) -> None :
266
+ weight_bits = self .quant_config .weight_bits
267
+ fix_weights (layer , "w13_weight" , weight_bits == 4 )
268
+ fix_weights (layer , "w2_weight" , weight_bits == 4 )
269
+
270
+ def apply (
271
+ self ,
272
+ layer : torch .nn .Module ,
273
+ x : torch .Tensor ,
274
+ router_logits : torch .Tensor ,
275
+ top_k : int ,
276
+ renormalize : bool ,
277
+ use_grouped_topk : bool = False ,
278
+ topk_group : Optional [int ] = None ,
279
+ num_expert_group : Optional [int ] = None ,
280
+ global_num_experts : int = - 1 ,
281
+ expert_map : Optional [torch .Tensor ] = None ,
282
+ custom_routing_function : Optional [Callable ] = None ,
283
+ scoring_func : str = "softmax" ,
284
+ e_score_correction_bias : Optional [torch .Tensor ] = None ,
285
+ apply_router_weight_on_input : bool = False ,
286
+ activation : str = "silu" ,
287
+ enable_eplb : bool = False ,
288
+ expert_load_view : Optional [torch .Tensor ] = None ,
289
+ logical_to_physical_map : Optional [torch .Tensor ] = None ,
290
+ logical_replica_count : Optional [torch .Tensor ] = None ,
291
+ ) -> torch .Tensor :
292
+ if enable_eplb :
293
+ raise NotImplementedError (
294
+ "EPLB not supported for `RTNMoEMethod` yet." )
295
+
296
+ from vllm .model_executor .layers .fused_moe import fused_experts
297
+
298
+ topk_weights , topk_ids = FusedMoE .select_experts (
299
+ hidden_states = x ,
300
+ router_logits = router_logits ,
301
+ use_grouped_topk = use_grouped_topk ,
302
+ top_k = top_k ,
303
+ renormalize = renormalize ,
304
+ topk_group = topk_group ,
305
+ num_expert_group = num_expert_group ,
306
+ custom_routing_function = custom_routing_function ,
307
+ scoring_func = scoring_func ,
308
+ e_score_correction_bias = e_score_correction_bias )
309
+
310
+ weight_bits = self .quant_config .weight_bits
311
+ group_size = self .quant_config .group_size
312
+
313
+ ret = fused_experts (
314
+ x ,
315
+ layer .w13_weight ,
316
+ layer .w2_weight ,
317
+ topk_weights = topk_weights ,
318
+ topk_ids = topk_ids ,
319
+ inplace = True ,
320
+ activation = activation ,
321
+ use_int4_w4a16 = weight_bits == 4 ,
322
+ use_int8_w8a16 = weight_bits == 8 ,
323
+ global_num_experts = global_num_experts ,
324
+ w1_scale = layer .w13_scale ,
325
+ w2_scale = layer .w2_scale ,
326
+ apply_router_weight_on_input = apply_router_weight_on_input ,
327
+ expert_map = expert_map ,
328
+ block_shape = [0 , group_size ])
329
+
330
+ return ret
331
+
332
+
212
333
def rtn_quantize (tensor : torch .Tensor , num_bits : int ,
213
334
group_size : int ) -> tuple [torch .Tensor , torch .Tensor ]:
214
335
"""Quantize a tensor using per-group static scaling factor.
@@ -221,34 +342,44 @@ def rtn_quantize(tensor: torch.Tensor, num_bits: int,
221
342
If equal to -1, each row in the input tensor is treated
222
343
as one group.
223
344
"""
345
+ batch_present = len (tensor .shape ) == 3
346
+ if not batch_present :
347
+ tensor = tensor .unsqueeze (0 )
224
348
225
349
q_range = 2 ** num_bits
226
- num_groups = (tensor .shape [0 ] * tensor .shape [1 ] //
227
- group_size if group_size != - 1 else tensor .shape [0 ])
350
+ num_groups = (tensor .shape [1 ] * tensor .shape [2 ] //
351
+ group_size if group_size != - 1 else tensor .shape [1 ])
228
352
"""Calculate a scaling factor per input group.
229
353
"""
230
- input_flat = tensor .reshape (num_groups , - 1 )
231
- input_min = torch .min (input_flat , dim = 1 , keepdim = True )[0 ]
232
- input_max = torch .max (input_flat , dim = 1 , keepdim = True )[0 ]
354
+ input_flat = tensor .reshape (tensor . shape [ 0 ], num_groups , - 1 )
355
+ input_min = torch .min (input_flat , dim = 2 , keepdim = True )[0 ]
356
+ input_max = torch .max (input_flat , dim = 2 , keepdim = True )[0 ]
233
357
input_max_abs = torch .max (input_min .abs (), input_max .abs ())
234
358
scale = (input_max_abs * 2.0 / (q_range - 1 ))
235
- """Scale each input group, truncate and round to the nearest integer.
359
+ """Scale each input group, round to the nearest integer, shift
360
+ the range and truncate.
236
361
"""
237
362
scaled_input = input_flat / scale
238
- scaled_input = scaled_input .clamp (- q_range // 2 , q_range // 2 - 1 )
239
363
scaled_input = scaled_input .round ()
364
+ scaled_input += q_range // 2
365
+ scaled_input = scaled_input .clamp (0 , q_range - 1 )
240
366
241
- scale = scale .reshape (tensor .shape [0 ], - 1 ).contiguous ()
242
- inputs_q = scaled_input .reshape (tensor .shape ).to (torch .int8 )
367
+ scale = scale .reshape (tensor .shape [0 ], tensor . shape [ 1 ], - 1 ).contiguous ()
368
+ inputs_q = scaled_input .reshape (tensor .shape ).to (torch .uint8 )
243
369
inputs_q = inputs_q .contiguous ()
244
370
245
371
if num_bits == 4 :
246
372
"""Pack two 4-bit values into each byte.
247
373
"""
248
- inputs_q = (inputs_q [:, 1 ::2 ] << 4 ) | (inputs_q [:, ::2 ] & 0xf )
249
- inputs_q = inputs_q .reshape (tensor .shape [0 ] // 2 , tensor .shape [1 ])
374
+ inputs_q = (inputs_q [:, :, 1 ::2 ] << 4 ) | (inputs_q [:, :, ::2 ] & 0xf )
375
+ inputs_q = inputs_q .reshape (tensor .shape [0 ], tensor .shape [1 ] // 2 ,
376
+ tensor .shape [2 ])
250
377
inputs_q = inputs_q .contiguous ()
251
378
379
+ if not batch_present :
380
+ inputs_q = inputs_q .squeeze (0 )
381
+ scale = scale .squeeze (0 )
382
+
252
383
return inputs_q , scale
253
384
254
385
@@ -259,31 +390,60 @@ def rtn_dequantize(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
259
390
tensor: The input tensor.
260
391
scale: The tensor with per-group scale factors.
261
392
"""
393
+ batch_present = len (tensor .shape ) == 3
394
+ if not batch_present :
395
+ tensor = tensor .unsqueeze (0 )
396
+ scale = scale .unsqueeze (0 )
262
397
263
- num_groups = scale .size (0 ) * scale .size (1 )
264
- input_dim , output_dim = tensor .shape
398
+ num_groups = scale .size (1 ) * scale .size (2 )
399
+ batch , input_dim , output_dim = tensor .shape
265
400
266
- num_bits = 8 if input_dim == scale .size (0 ) else 4
401
+ num_bits = 8 if input_dim == scale .size (1 ) else 4
402
+ q_range = 2 ** num_bits
267
403
if num_bits == 4 :
268
404
input_dim *= 2
269
405
270
- data = torch .empty ((input_dim , output_dim ),
406
+ data = torch .empty ((batch , input_dim , output_dim ),
271
407
dtype = scale .dtype ,
272
408
device = tensor .device )
273
409
274
410
if num_bits == 8 :
275
411
data .copy_ (tensor )
412
+ data -= q_range // 2
276
413
else :
277
414
"""Unpack two 4-bit values from each byte.
278
415
"""
279
- tensor = tensor .reshape (input_dim , output_dim // 2 )
416
+ tensor = tensor .reshape (batch , input_dim , output_dim // 2 )
280
417
for i in range (2 ):
281
- data [:, i ::2 ] = (tensor << 4 * (1 - i )) >> 4
418
+ data [:, :, i ::2 ] = ((tensor << 4 *
419
+ (1 - i )) >> 4 ).to (torch .int8 ) - q_range // 2
282
420
"""Scale each input group with its scaling factor.
283
421
"""
284
- scale = scale .reshape (num_groups , - 1 )
285
- data = data .reshape (num_groups , - 1 )
422
+ scale = scale .reshape (batch , num_groups , - 1 )
423
+ data = data .reshape (batch , num_groups , - 1 )
286
424
data = torch .mul (data , scale )
287
425
288
- input_deq = data .reshape ((input_dim , output_dim )).contiguous ()
426
+ input_deq = data .reshape ((batch , input_dim , output_dim )).contiguous ()
427
+ if not batch_present :
428
+ input_deq = input_deq .squeeze (0 )
429
+
289
430
return input_deq
431
+
432
+
433
+ def fix_weights (layer : torch .nn .Module ,
434
+ param_name : str ,
435
+ reshape : bool = False ):
436
+ """torch.compile does not know how to deal with a Parameter subclass
437
+ (aka RTNParameter). As we don't really need RTNParameters for the
438
+ forward pass, we replace them with equivalent instances of Parameters.
439
+ """
440
+ old_weight = getattr (layer , param_name )
441
+ assert isinstance (old_weight , RTNParameter )
442
+ data = old_weight .data .data
443
+
444
+ delattr (layer , param_name )
445
+
446
+ if reshape :
447
+ data = data .reshape (old_weight .shape [0 ], old_weight .shape [1 ] * 2 , - 1 )
448
+ new_weight = Parameter (data = data , requires_grad = False )
449
+ layer .register_parameter (param_name , new_weight )
0 commit comments