@@ -585,7 +585,7 @@ def permute_qk_weight_for_rotary(
585
585
) -> tuple [str , torch .Tensor ]:
586
586
587
587
# Helper function to permute the weight's channels
588
- def permute (w : torch .Tensor , n_heads : int ):
588
+ def permute (w : torch .Tensor , n_heads : int , is_weight_scale : bool ):
589
589
attn_in = self .config .head_dim * n_heads
590
590
attn_out = self .config .hidden_size
591
591
@@ -594,6 +594,11 @@ def permute(w: torch.Tensor, n_heads: int):
594
594
if w .dtype == torch .uint8 and w .shape [1 ] * 2 == attn_out :
595
595
attn_out = attn_out // 2
596
596
597
+ # If the weight is a weight scale, we need to divide attn_out by
598
+ # block size, which is currently 16.
599
+ elif w .dtype == torch .float8_e4m3fn and is_weight_scale :
600
+ attn_out = attn_out // 16
601
+
597
602
return w .view (n_heads , attn_in // n_heads // 2 , 2 ,
598
603
attn_out ).transpose (1 , 2 ).reshape (attn_in , attn_out )
599
604
@@ -607,9 +612,11 @@ def permute(w: torch.Tensor, n_heads: int):
607
612
if is_weight or is_nvfp4_weight_scale :
608
613
if ("wk" in modules or "k_proj" in modules ):
609
614
loaded_weight = permute (loaded_weight ,
610
- self .config .num_key_value_heads )
615
+ self .config .num_key_value_heads ,
616
+ is_nvfp4_weight_scale )
611
617
elif ("wq" in modules or "q_proj" in modules ):
612
618
loaded_weight = permute (loaded_weight ,
613
- self .config .num_attention_heads )
619
+ self .config .num_attention_heads ,
620
+ is_nvfp4_weight_scale )
614
621
615
622
return name , loaded_weight
0 commit comments