@@ -363,11 +363,22 @@ def load_moe_expert_weights(
363
363
continue
364
364
param = params_dict [full_param_name ]
365
365
weight_loader = param .weight_loader
366
+
367
+ # Helper function to check if the weight is FP4.
368
+ # We use uint8 to store FP4 weights for now.
369
+ def is_fp4_weight (weight ):
370
+ return weight .dtype == torch .uint8
371
+
366
372
if fused :
367
373
if "w13" in full_param_name :
368
374
shard_idx = 0 if shard_id == "w1" else 1
369
375
new_loaded_weight = new_loaded_weight [shard_idx ]
370
- new_loaded_weight = new_loaded_weight .transpose (- 1 , - 2 )
376
+
377
+ # Only transpose for non-FP4 weights
378
+ # FP4 weights are already in the correct format and shouldn't be transposed here.
379
+ if not is_fp4_weight (new_loaded_weight ):
380
+ new_loaded_weight = new_loaded_weight .transpose (- 1 , - 2 )
381
+
371
382
layer_idx = extract_layer_index (name )
372
383
# EP mapping
373
384
expert_map = self .layers [
@@ -382,6 +393,11 @@ def load_moe_expert_weights(
382
393
else :
383
394
# TODO: add EP support for non fused weights
384
395
pass
396
+
397
+ # Only transpose for FP4 weights
398
+ if is_fp4_weight (new_loaded_weight ):
399
+ new_loaded_weight = new_loaded_weight .transpose (- 1 , - 2 )
400
+
385
401
weight_loader (param ,
386
402
new_loaded_weight ,
387
403
full_param_name ,
@@ -402,6 +418,12 @@ def load_weights(self, weights: Iterable[tuple[str,
402
418
(".gate_up_proj" , ".gate_proj" , 0 ),
403
419
(".gate_up_proj" , ".up_proj" , 1 ),
404
420
]
421
+ expert_scale_params_mapping = [
422
+ # (expert_name, expert_id, shard_id)
423
+ ("w13_" , 0 , 'w1' ),
424
+ ("w13_" , 0 , 'w3' ),
425
+ ("w2_" , 0 , 'w2' )
426
+ ]
405
427
fused_experts_params = False
406
428
expert_params_mapping = FusedMoE .make_expert_params_mapping (
407
429
ckpt_gate_proj_name = "gate_proj" ,
@@ -483,19 +505,19 @@ def load_weights(self, weights: Iterable[tuple[str,
483
505
'supports_moe_loading' , False )
484
506
485
507
if supports_moe :
486
- # This is a MoE weight loader
487
- if "w13_" in name :
488
- shard_id = "w1"
489
- elif "w2_" in name :
490
- shard_id = "w2"
491
- else :
492
- shard_id = "w1"
493
-
494
- weight_loader ( param ,
495
- loaded_weight ,
496
- name ,
497
- shard_id = shard_id ,
498
- expert_id = 0 )
508
+ # Transpose if the weights are FP8 or FP4.
509
+ if loaded_weight . dtype == torch . uint8 or loaded_weight . dtype == torch . float8_e4m3fn :
510
+ loaded_weight = loaded_weight . transpose ( - 1 , - 2 )
511
+ param . data . fill_ ( 0 )
512
+
513
+ for ( expert_name , expert_id , shard_id ) in expert_scale_params_mapping :
514
+ if expert_name in name :
515
+ weight_loader ( param ,
516
+ loaded_weight ,
517
+ name ,
518
+ shard_id = shard_id ,
519
+ expert_id = expert_id )
520
+
499
521
else :
500
522
# Regular weight loader (handles both
501
523
# param.weight_loader and default_weight_loader)
@@ -560,23 +582,28 @@ def permute_qk_weight_for_rotary(
560
582
loaded_weight : torch .Tensor ,
561
583
) -> tuple [str , torch .Tensor ]:
562
584
585
+ # Helper function to permute the weight's channels
563
586
def permute (w : torch .Tensor , n_heads : int ):
564
- attn_in = self .config .head_dim * n_heads
565
- attn_out = self .config .hidden_size
566
-
567
- return w .view (n_heads , attn_in // n_heads // 2 , 2 ,
568
- attn_out ).transpose (1 , 2 ).reshape (attn_in , attn_out )
587
+ head_dim = w .shape [0 ] // n_heads
588
+ return (
589
+ w .view (n_heads , head_dim // 2 , 2 , w .shape [1 ])
590
+ .transpose (1 , 2 )
591
+ .reshape (w .shape [0 ], w .shape [1 ])
592
+ )
569
593
570
594
modules = name .split ("." )
571
595
572
- # rotary embeds should be sliced
573
- if ("wk" in modules or "k_proj" in modules ) \
574
- and modules [- 1 ] == "weight" :
575
- loaded_weight = permute (loaded_weight ,
576
- self .config .num_key_value_heads )
577
- elif ("wq" in modules or "q_proj" in modules ) \
578
- and modules [- 1 ] == "weight" :
579
- loaded_weight = permute (loaded_weight ,
580
- self .config .num_attention_heads )
596
+ # Permute Q/K weights and weight block scales for rotary embedding
597
+ is_weight = modules [- 1 ] == "weight"
598
+ is_nvfp4_weight_scale = (modules [- 1 ] == "weight_scale"
599
+ and loaded_weight .dtype == torch .float8_e4m3fn )
600
+
601
+ if is_weight or is_nvfp4_weight_scale :
602
+ if ("wk" in modules or "k_proj" in modules ):
603
+ loaded_weight = permute (loaded_weight ,
604
+ self .config .num_key_value_heads )
605
+ elif ("wq" in modules or "q_proj" in modules ):
606
+ loaded_weight = permute (loaded_weight ,
607
+ self .config .num_attention_heads )
581
608
582
609
return name , loaded_weight
0 commit comments