@@ -291,6 +291,40 @@ def pad_proj_weight(self, data):
                 self.hidden_size, -1)
         return out_weight
 
+    def pad_qkv_weight_scale_offset(self, data):
+        reshaped_data = data.reshape(
+            -1, 3, self.origin_hidden_size_per_attention_head, 1)
+        data1 = reshaped_data[:, :, :self.
+                              half_origin_hidden_size_per_attention_head, :]
+        data2 = reshaped_data[:, :, self.
+                              half_origin_hidden_size_per_attention_head:, :]
+        data1_paded = torch.nn.functional.pad(
+            data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
+                    0, 0, 0))
+        data2_paded = torch.nn.functional.pad(
+            data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
+                    0, 0, 0))
+        res = torch.cat([data1_paded, data2_paded], dim=2)
+        res = res.reshape(-1, 1)
+        return res
+
+    def pad_qkv_deq_scale_quant_bias(self, data):
+        reshaped_data = data.reshape(
+            -1, 3, self.origin_hidden_size_per_attention_head)
+        data1 = reshaped_data[:, :, :self.
+                              half_origin_hidden_size_per_attention_head]
+        data2 = reshaped_data[:, :,
+                              self.half_origin_hidden_size_per_attention_head:]
+
+        data1_paded = torch.nn.functional.pad(
+            data1, (0, self.half_pad_hidden_size_per_attention_head))
+        data2_paded = torch.nn.functional.pad(
+            data2, (0, self.half_pad_hidden_size_per_attention_head))
+
+        res = torch.cat([data1_paded, data2_paded], dim=2)
+        res = res.reshape(-1)
+        return res
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
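Both new helpers follow the same split-pad-concat pattern as the existing `pad_qkv_weight`: each head's channel dimension is split into two halves, each half is zero-padded at its end, and the halves are re-concatenated, presumably so the padding lines up with the first-half/second-half layout the padded attention path uses for the head dimension. This keeps per-channel quantization metadata (weight scales/offsets, dequant scales, quant biases) aligned with the padded QKV weights. A minimal standalone sketch of the scale/offset variant follows; the reshape/pad/cat steps mirror the diff, but all sizes are invented for illustration:

```python
import torch

# Invented sizes for illustration only (not taken from the PR):
num_heads = 2                    # heads in the fused QKV projection
origin_head_dim = 6              # origin_hidden_size_per_attention_head
half_origin = origin_head_dim // 2   # half_origin_hidden_size_per_attention_head
half_pad = 1                     # half_pad_hidden_size_per_attention_head
padded_head_dim = origin_head_dim + 2 * half_pad

# A per-channel scale tensor for a fused QKV projection, stored flat as (-1, 1).
scale = torch.arange(num_heads * 3 * origin_head_dim,
                     dtype=torch.float32).reshape(-1, 1)

# Expose the (head, q/k/v, channel) structure assumed by the diff's reshape.
reshaped = scale.reshape(-1, 3, origin_head_dim, 1)

# Split each head's channels in half and zero-pad each half separately,
# so padding lands at the end of each half rather than only at the tail.
first_half = reshaped[:, :, :half_origin, :]
second_half = reshaped[:, :, half_origin:, :]
first_padded = torch.nn.functional.pad(first_half, (0, 0, 0, half_pad))
second_padded = torch.nn.functional.pad(second_half, (0, 0, 0, half_pad))

# Re-concatenate along the channel dim and flatten back to checkpoint layout.
padded = torch.cat([first_padded, second_padded], dim=2).reshape(-1, 1)
assert padded.shape == (num_heads * 3 * padded_head_dim, 1)
```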
@@ -318,11 +352,23 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if ("attn.proj.weight" in name) and self.enable_pad:
+                if ("attn.proj.weight_scale" in name or
+                        "attn.proj.weight_offset" in name) and self.enable_pad:
+                    continue
+                elif ("attn.proj.deq_scale" in name
+                      or "attn.proj.quant_bias" in name) and self.enable_pad:
+                    continue
+                elif ("attn.qkv.weight_scale" in name
+                      or "attn.qkv.weight_offset" in name) and self.enable_pad:
+                    param.data = self.pad_qkv_weight_scale_offset(param.data)
+                elif ("attn.qkv.deq_scale" in name
+                      or "attn.qkv.quant_bias" in name) and self.enable_pad:
+                    param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
+                elif ("attn.proj.weight" in name) and self.enable_pad:
                     param.data = self.pad_proj_weight(param.data)
-                if ("attn.qkv.weight" in name) and self.enable_pad:
+                elif ("attn.qkv.weight" in name) and self.enable_pad:
                     param.data = self.pad_qkv_weight(param.data)
-                if ("attn.qkv.bias" in name) and self.enable_pad:
+                elif ("attn.qkv.bias" in name) and self.enable_pad:
                     param.data = self.pad_qkv_bias(param.data)
             loaded_params.add(name)
         return loaded_params
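The switch from independent `if` statements to an `if`/`elif` chain is not cosmetic: Python's `in` does plain substring matching, so the broad pattern `"attn.qkv.weight"` also matches the quantization tensors `attn.qkv.weight_scale` and `attn.qkv.weight_offset`. Checking the more specific names first ensures each parameter is handled by exactly one padding helper; the `continue` branches skip the proj-side quantization tensors, presumably because `pad_proj_weight` only pads the input (contraction) dimension, leaving per-output-channel metadata unchanged. A small illustration of the pitfall, using a made-up parameter name:

```python
# Made-up parameter name, for demonstration only.
name = "visual.blocks.0.attn.qkv.weight_scale"

assert "attn.qkv.weight" in name        # the broad pattern also matches...
assert "attn.qkv.weight_scale" in name  # ...the quantization tensor

# With independent `if` checks, a weight_scale tensor would fall through to
# the "attn.qkv.weight" branch and be padded as if it were the weight itself.
# The elif chain, ordered most-specific-first, runs exactly one branch per name.
```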
@@ -445,6 +491,17 @@ def forward(
                         dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class AscendQwen2_5_VLForConditionalGeneration(
         Qwen2_5_VLForConditionalGeneration):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
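`packed_modules_mapping` tells vLLM which per-shard checkpoint modules are fused into a single module on the model (q/k/v into `qkv_proj`, gate/up into `gate_up_proj`), which quantization support uses to relate per-shard metadata to the fused parameters. A hedged sketch of the name translation such a mapping enables; the helper below is illustrative, not vLLM's actual loader code:

```python
# Illustrative only: how a loader could map checkpoint shard names onto the
# fused modules declared in packed_modules_mapping.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def to_fused_name(checkpoint_name: str) -> str:
    """Map a per-shard checkpoint module name to its fused counterpart."""
    for fused, shards in packed_modules_mapping.items():
        for shard in shards:
            if shard in checkpoint_name:
                return checkpoint_name.replace(shard, fused)
    return checkpoint_name

# The checkpoint stores q/k/v separately; the model only has the fused module.
assert to_fused_name("model.layers.0.self_attn.q_proj") == \
    "model.layers.0.self_attn.qkv_proj"
```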