Skip to content

Commit 23bd139

Browse files
committed
Fix attn_out size for weight scale
Signed-off-by: Po-Han Huang <[email protected]>
1 parent 707db88 commit 23bd139

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

vllm/model_executor/models/llama4.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def permute_qk_weight_for_rotary(
585585
) -> tuple[str, torch.Tensor]:
586586

587587
# Helper function to permute the weight's channels
588-
def permute(w: torch.Tensor, n_heads: int):
588+
def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
589589
attn_in = self.config.head_dim * n_heads
590590
attn_out = self.config.hidden_size
591591

@@ -594,6 +594,11 @@ def permute(w: torch.Tensor, n_heads: int):
594594
if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
595595
attn_out = attn_out // 2
596596

597+
# If the weight is a weight scale, we need to divide attn_out by
598+
# block size, which is currently 16.
599+
elif w.dtype == torch.float8_e4m3fn and is_weight_scale:
600+
attn_out = attn_out // 16
601+
597602
return w.view(n_heads, attn_in // n_heads // 2, 2,
598603
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
599604

@@ -607,9 +612,11 @@ def permute(w: torch.Tensor, n_heads: int):
607612
if is_weight or is_nvfp4_weight_scale:
608613
if ("wk" in modules or "k_proj" in modules):
609614
loaded_weight = permute(loaded_weight,
610-
self.config.num_key_value_heads)
615+
self.config.num_key_value_heads,
616+
is_nvfp4_weight_scale)
611617
elif ("wq" in modules or "q_proj" in modules):
612618
loaded_weight = permute(loaded_weight,
613-
self.config.num_attention_heads)
619+
self.config.num_attention_heads,
620+
is_nvfp4_weight_scale)
614621

615622
return name, loaded_weight

0 commit comments

Comments (0)