Skip to content

Commit 707db88

Browse files
committed
[NVIDIA] Fix Llama4 Scout FP4 functionality issues
Fix the weight loading issues and accuracy issues when using the NVIDIA ModelOpt Llama4 Scout FP4 model. Signed-off-by: Po-Han Huang <[email protected]>
1 parent d128d0d commit 707db88

File tree

3 files changed

+57
-13
lines changed

3 files changed

+57
-13
lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -874,6 +874,14 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
874874
elif shard_id == "w2":
875875
param_data[expert_id] = loaded_weight
876876

877+
def _load_w13_weight_scale(self, shard_dim: int,
878+
loaded_weight: torch.Tensor,
879+
param: torch.Tensor, tp_rank: int):
880+
shard_size = param.shape[shard_dim]
881+
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
882+
shard_size)
883+
param.copy_(loaded_weight)
884+
877885
def _load_model_weight_or_group_weight_scale(self,
878886
shard_dim: int,
879887
expert_data: torch.Tensor,
@@ -1123,7 +1131,12 @@ def weight_loader(self,
11231131
"weight_scale_2" in weight_name if uses_weight_scale_2 else
11241132
"weight_scale" in weight_name) or "input_scale" in weight_name
11251133

1126-
if per_tensor_conditions:
1134+
if "w13_weight_scale" in weight_name:
1135+
self._load_w13_weight_scale(shard_dim=shard_dim,
1136+
loaded_weight=loaded_weight,
1137+
param=param,
1138+
tp_rank=self.tp_rank)
1139+
elif per_tensor_conditions:
11271140
self._load_per_tensor_weight_scale(
11281141
shard_id=shard_id,
11291142
param=param,

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -778,8 +778,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
778778
# Swizzle the weight blockscale.
779779
# contracting dimension is input dimension
780780
# block_size = 16;
781-
assert (layer.weight_scale.shape[1] % 16 == 0), (
782-
"Expected weight_scale.dim(1) to be divisible by 16")
783781
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
784782
"Weight Block scale must be represented as FP8-E4M3")
785783
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)

vllm/model_executor/models/llama4.py

Lines changed: 43 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -363,11 +363,23 @@ def load_moe_expert_weights(
363363
continue
364364
param = params_dict[full_param_name]
365365
weight_loader = param.weight_loader
366+
367+
# Helper function to check if the weight is FP4.
368+
# We use uint8 to store FP4 weights for now.
369+
def is_fp4_weight(weight):
370+
return weight.dtype == torch.uint8
371+
366372
if fused:
367373
if "w13" in full_param_name:
368374
shard_idx = 0 if shard_id == "w1" else 1
369375
new_loaded_weight = new_loaded_weight[shard_idx]
370-
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
376+
377+
# Only transpose for non-FP4 weights
378+
# FP4 weights are already in the correct format and
379+
# shouldn't be transposed here.
380+
if not is_fp4_weight(new_loaded_weight):
381+
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
382+
371383
layer_idx = extract_layer_index(name)
372384
# EP mapping
373385
expert_map = self.layers[
@@ -382,6 +394,11 @@ def load_moe_expert_weights(
382394
else:
383395
# TODO: add EP support for non fused weights
384396
pass
397+
398+
# Only transpose for FP4 weights
399+
if is_fp4_weight(new_loaded_weight):
400+
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
401+
385402
weight_loader(param,
386403
new_loaded_weight,
387404
full_param_name,
@@ -447,6 +464,7 @@ def load_weights(self, weights: Iterable[tuple[str,
447464
param = params_dict[name]
448465
weight_loader = getattr(param, "weight_loader",
449466
default_weight_loader)
467+
450468
if weight_loader == default_weight_loader:
451469
weight_loader(param, loaded_weight)
452470
else:
@@ -491,11 +509,17 @@ def load_weights(self, weights: Iterable[tuple[str,
491509
else:
492510
shard_id = "w1"
493511

512+
# Transpose if the weights are FP8 or FP4.
513+
if loaded_weight.dtype == torch.uint8 \
514+
or loaded_weight.dtype == torch.float8_e4m3fn:
515+
loaded_weight = loaded_weight.transpose(-1, -2)
516+
494517
weight_loader(param,
495518
loaded_weight,
496519
name,
497520
shard_id=shard_id,
498521
expert_id=0)
522+
499523
else:
500524
# Regular weight loader (handles both
501525
# param.weight_loader and default_weight_loader)
@@ -560,23 +584,32 @@ def permute_qk_weight_for_rotary(
560584
loaded_weight: torch.Tensor,
561585
) -> tuple[str, torch.Tensor]:
562586

587+
# Helper function to permute the weight's channels
563588
def permute(w: torch.Tensor, n_heads: int):
564589
attn_in = self.config.head_dim * n_heads
565590
attn_out = self.config.hidden_size
566591

592+
# If the weight is FP4 packed as uint8, we need to divide attn_out
593+
# by 2.
594+
if w.dtype == torch.uint8 and w.shape[1] * 2 == attn_out:
595+
attn_out = attn_out // 2
596+
567597
return w.view(n_heads, attn_in // n_heads // 2, 2,
568598
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
569599

570600
modules = name.split(".")
571601

572-
# rotary embeds should be sliced
573-
if ("wk" in modules or "k_proj" in modules) \
574-
and modules[-1] == "weight":
575-
loaded_weight = permute(loaded_weight,
576-
self.config.num_key_value_heads)
577-
elif ("wq" in modules or "q_proj" in modules) \
578-
and modules[-1] == "weight":
579-
loaded_weight = permute(loaded_weight,
580-
self.config.num_attention_heads)
602+
# Permute Q/K weights and weight block scales for rotary embedding
603+
is_weight = modules[-1] == "weight"
604+
is_nvfp4_weight_scale = (modules[-1] == "weight_scale" and
605+
loaded_weight.dtype == torch.float8_e4m3fn)
606+
607+
if is_weight or is_nvfp4_weight_scale:
608+
if ("wk" in modules or "k_proj" in modules):
609+
loaded_weight = permute(loaded_weight,
610+
self.config.num_key_value_heads)
611+
elif ("wq" in modules or "q_proj" in modules):
612+
loaded_weight = permute(loaded_weight,
613+
self.config.num_attention_heads)
581614

582615
return name, loaded_weight

0 commit comments

Comments
 (0)