@@ -363,11 +363,23 @@ def load_moe_expert_weights(
363
363
continue
364
364
param = params_dict [full_param_name ]
365
365
weight_loader = param .weight_loader
366
+
367
+ # Helper function to check if the weight is FP4.
368
+ # We use uint8 to store FP4 weights for now.
369
+ def is_fp4_weight (weight ):
370
+ return weight .dtype == torch .uint8
371
+
366
372
if fused :
367
373
if "w13" in full_param_name :
368
374
shard_idx = 0 if shard_id == "w1" else 1
369
375
new_loaded_weight = new_loaded_weight [shard_idx ]
370
- new_loaded_weight = new_loaded_weight .transpose (- 1 , - 2 )
376
+
377
+ # Only transpose for non-FP4 weights
378
+ # FP4 weights are already in the correct format and
379
+ # shouldn't be transposed here.
380
+ if not is_fp4_weight (new_loaded_weight ):
381
+ new_loaded_weight = new_loaded_weight .transpose (- 1 , - 2 )
382
+
371
383
layer_idx = extract_layer_index (name )
372
384
# EP mapping
373
385
expert_map = self .layers [
@@ -382,6 +394,11 @@ def load_moe_expert_weights(
382
394
else :
383
395
# TODO: add EP support for non fused weights
384
396
pass
397
+
398
+ # Only transpose for FP4 weights
399
+ if is_fp4_weight (new_loaded_weight ):
400
+ new_loaded_weight = new_loaded_weight .transpose (- 1 , - 2 )
401
+
385
402
weight_loader (param ,
386
403
new_loaded_weight ,
387
404
full_param_name ,
@@ -447,6 +464,7 @@ def load_weights(self, weights: Iterable[tuple[str,
447
464
param = params_dict [name ]
448
465
weight_loader = getattr (param , "weight_loader" ,
449
466
default_weight_loader )
467
+
450
468
if weight_loader == default_weight_loader :
451
469
weight_loader (param , loaded_weight )
452
470
else :
@@ -491,11 +509,17 @@ def load_weights(self, weights: Iterable[tuple[str,
491
509
else :
492
510
shard_id = "w1"
493
511
512
+ # Transpose if the weights are FP8 or FP4.
513
+ if loaded_weight .dtype == torch .uint8 \
514
+ or loaded_weight .dtype == torch .float8_e4m3fn :
515
+ loaded_weight = loaded_weight .transpose (- 1 , - 2 )
516
+
494
517
weight_loader (param ,
495
518
loaded_weight ,
496
519
name ,
497
520
shard_id = shard_id ,
498
521
expert_id = 0 )
522
+
499
523
else :
500
524
# Regular weight loader (handles both
501
525
# param.weight_loader and default_weight_loader)
@@ -560,23 +584,32 @@ def permute_qk_weight_for_rotary(
560
584
loaded_weight : torch .Tensor ,
561
585
) -> tuple [str , torch .Tensor ]:
562
586
587
+ # Helper function to permute the weight's channels
563
588
def permute (w : torch .Tensor , n_heads : int ):
564
589
attn_in = self .config .head_dim * n_heads
565
590
attn_out = self .config .hidden_size
566
591
592
+ # If the weight is FP4 packed as uint8, we need to divide attn_out
593
+ # by 2.
594
+ if w .dtype == torch .uint8 and w .shape [1 ] * 2 == attn_out :
595
+ attn_out = attn_out // 2
596
+
567
597
return w .view (n_heads , attn_in // n_heads // 2 , 2 ,
568
598
attn_out ).transpose (1 , 2 ).reshape (attn_in , attn_out )
569
599
570
600
modules = name .split ("." )
571
601
572
- # rotary embeds should be sliced
573
- if ("wk" in modules or "k_proj" in modules ) \
574
- and modules [- 1 ] == "weight" :
575
- loaded_weight = permute (loaded_weight ,
576
- self .config .num_key_value_heads )
577
- elif ("wq" in modules or "q_proj" in modules ) \
578
- and modules [- 1 ] == "weight" :
579
- loaded_weight = permute (loaded_weight ,
580
- self .config .num_attention_heads )
602
+ # Permute Q/K weights and weight block scales for rotary embedding
603
+ is_weight = modules [- 1 ] == "weight"
604
+ is_nvfp4_weight_scale = (modules [- 1 ] == "weight_scale" and
605
+ loaded_weight .dtype == torch .float8_e4m3fn )
606
+
607
+ if is_weight or is_nvfp4_weight_scale :
608
+ if ("wk" in modules or "k_proj" in modules ):
609
+ loaded_weight = permute (loaded_weight ,
610
+ self .config .num_key_value_heads )
611
+ elif ("wq" in modules or "q_proj" in modules ):
612
+ loaded_weight = permute (loaded_weight ,
613
+ self .config .num_attention_heads )
581
614
582
615
return name , loaded_weight
0 commit comments