@@ -363,11 +363,22 @@ def load_moe_expert_weights(
363
363
continue
364
364
param = params_dict [full_param_name ]
365
365
weight_loader = param .weight_loader
366
+
367
+ # Helper function to check if the weight is FP4.
368
+ # FP4 weights are currently stored in uint8 tensors, so a uint8 dtype is treated as FP4.
369
+ def is_fp4_weight (weight ):
370
+ return weight .dtype == torch .uint8
371
+
366
372
if fused :
367
373
if "w13" in full_param_name :
368
374
shard_idx = 0 if shard_id == "w1" else 1
369
375
new_loaded_weight = new_loaded_weight [shard_idx ]
370
- new_loaded_weight = new_loaded_weight .transpose (- 1 , - 2 )
376
+
377
+ # Only transpose for non-FP4 weights
378
+ # FP4 weights are already in the correct format and shouldn't be transposed here.
379
+ if not is_fp4_weight (new_loaded_weight ):
380
+ new_loaded_weight = new_loaded_weight .transpose (- 1 , - 2 )
381
+
371
382
layer_idx = extract_layer_index (name )
372
383
# EP mapping
373
384
expert_map = self .layers [
@@ -382,6 +393,11 @@ def load_moe_expert_weights(
382
393
else :
383
394
# TODO: add EP support for non fused weights
384
395
pass
396
+
397
+ # Only transpose for FP4 weights
398
+ if is_fp4_weight (new_loaded_weight ):
399
+ new_loaded_weight = new_loaded_weight .transpose (- 1 , - 2 )
400
+
385
401
weight_loader (param ,
386
402
new_loaded_weight ,
387
403
full_param_name ,
@@ -491,11 +507,17 @@ def load_weights(self, weights: Iterable[tuple[str,
491
507
else :
492
508
shard_id = "w1"
493
509
510
+ # Transpose if the weights are FP8 (float8_e4m3fn) or FP4 (stored as uint8).
511
+ if loaded_weight .dtype == torch .uint8 \
512
+ or loaded_weight .dtype == torch .float8_e4m3fn :
513
+ loaded_weight = loaded_weight .transpose (- 1 , - 2 )
514
+
494
515
weight_loader (param ,
495
516
loaded_weight ,
496
517
name ,
497
518
shard_id = shard_id ,
498
519
expert_id = 0 )
520
+
499
521
else :
500
522
# Regular weight loader (handles both
501
523
# param.weight_loader and default_weight_loader)
@@ -560,23 +582,28 @@ def permute_qk_weight_for_rotary(
560
582
loaded_weight : torch .Tensor ,
561
583
) -> tuple [str , torch .Tensor ]:
562
584
585
+ # Helper that regroups each head's rows from interleaved pairs (0, 1, 2, 3, ...) into even rows followed by odd rows (0, 2, ..., 1, 3, ...)
563
586
def permute (w : torch .Tensor , n_heads : int ):
564
- attn_in = self .config .head_dim * n_heads
565
- attn_out = self .config .hidden_size
566
-
567
- return w .view (n_heads , attn_in // n_heads // 2 , 2 ,
568
- attn_out ).transpose (1 , 2 ).reshape (attn_in , attn_out )
587
+ head_dim = w .shape [0 ] // n_heads
588
+ return (
589
+ w .view (n_heads , head_dim // 2 , 2 , w .shape [1 ])
590
+ .transpose (1 , 2 )
591
+ .reshape (w .shape [0 ], w .shape [1 ])
592
+ )
569
593
570
594
modules = name .split ("." )
571
595
572
- # rotary embeds should be sliced
573
- if ("wk" in modules or "k_proj" in modules ) \
574
- and modules [- 1 ] == "weight" :
575
- loaded_weight = permute (loaded_weight ,
576
- self .config .num_key_value_heads )
577
- elif ("wq" in modules or "q_proj" in modules ) \
578
- and modules [- 1 ] == "weight" :
579
- loaded_weight = permute (loaded_weight ,
580
- self .config .num_attention_heads )
596
+ # Permute Q/K weights and weight block scales for rotary embedding
597
+ is_weight = modules [- 1 ] == "weight"
598
+ is_nvfp4_weight_scale = (modules [- 1 ] == "weight_scale"
599
+ and loaded_weight .dtype == torch .float8_e4m3fn )
600
+
601
+ if is_weight or is_nvfp4_weight_scale :
602
+ if ("wk" in modules or "k_proj" in modules ):
603
+ loaded_weight = permute (loaded_weight ,
604
+ self .config .num_key_value_heads )
605
+ elif ("wq" in modules or "q_proj" in modules ):
606
+ loaded_weight = permute (loaded_weight ,
607
+ self .config .num_attention_heads )
581
608
582
609
return name , loaded_weight
0 commit comments