Skip to content

Commit afaf28d

Browse files
committed
[NVIDIA] Fix Llama4 Scout FP4 functionality issues
Fix the weight loading issues and accuray issues when using the NVIDIA ModelOpt Llama4 Scout FP4 model. Signed-off-by: Po-Han Huang <[email protected]>
1 parent fe56180 commit afaf28d

File tree

3 files changed

+61
-18
lines changed

3 files changed

+61
-18
lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,17 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
874874
elif shard_id == "w2":
875875
param_data[expert_id] = loaded_weight
876876

877+
def _load_w13_weight_scale(self,
878+
shard_dim: int,
879+
loaded_weight: torch.Tensor,
880+
param: torch.Tensor,
881+
tp_rank: int):
882+
shard_size = param.shape[shard_dim]
883+
loaded_weight = loaded_weight.narrow(shard_dim,
884+
shard_size * tp_rank,
885+
shard_size)
886+
param.copy_(loaded_weight)
887+
877888
def _load_model_weight_or_group_weight_scale(self,
878889
shard_dim: int,
879890
expert_data: torch.Tensor,
@@ -1123,7 +1134,14 @@ def weight_loader(self,
11231134
"weight_scale_2" in weight_name if uses_weight_scale_2 else
11241135
"weight_scale" in weight_name) or "input_scale" in weight_name
11251136

1126-
if per_tensor_conditions:
1137+
if "w13_weight_scale" in weight_name:
1138+
self._load_w13_weight_scale(
1139+
shard_dim=shard_dim,
1140+
loaded_weight=loaded_weight,
1141+
param=param,
1142+
tp_rank=self.tp_rank
1143+
)
1144+
elif per_tensor_conditions:
11271145
self._load_per_tensor_weight_scale(
11281146
shard_id=shard_id,
11291147
param=param,

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,8 +810,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
810810
# Swizzle the weight blockscale.
811811
# contracting dimension is input dimension
812812
# block_size = 16;
813-
assert (layer.weight_scale.shape[1] % 16 == 0), (
814-
"Expected weight_scale.dim(1) to be divisible by 16")
815813
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
816814
"Weight Block scale must be represented as FP8-E4M3")
817815
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)

vllm/model_executor/models/llama4.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,22 @@ def load_moe_expert_weights(
363363
continue
364364
param = params_dict[full_param_name]
365365
weight_loader = param.weight_loader
366+
367+
# Helper function to check if the weight is FP4.
368+
# We use uint8 to store FP4 weights for now.
369+
def is_fp4_weight(weight):
370+
return weight.dtype == torch.uint8
371+
366372
if fused:
367373
if "w13" in full_param_name:
368374
shard_idx = 0 if shard_id == "w1" else 1
369375
new_loaded_weight = new_loaded_weight[shard_idx]
370-
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
376+
377+
# Only transpose for non-FP4 weights
378+
# FP4 weights are already in the correct format and shouldn't be transposed here.
379+
if not is_fp4_weight(new_loaded_weight):
380+
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
381+
371382
layer_idx = extract_layer_index(name)
372383
# EP mapping
373384
expert_map = self.layers[
@@ -382,6 +393,11 @@ def load_moe_expert_weights(
382393
else:
383394
# TODO: add EP support for non fused weights
384395
pass
396+
397+
# Only transpose for FP4 weights
398+
if is_fp4_weight(new_loaded_weight):
399+
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
400+
385401
weight_loader(param,
386402
new_loaded_weight,
387403
full_param_name,
@@ -491,11 +507,17 @@ def load_weights(self, weights: Iterable[tuple[str,
491507
else:
492508
shard_id = "w1"
493509

510+
# Transpose if the weights are FP8 or FP4.
511+
if loaded_weight.dtype == torch.uint8 \
512+
or loaded_weight.dtype == torch.float8_e4m3fn:
513+
loaded_weight = loaded_weight.transpose(-1, -2)
514+
494515
weight_loader(param,
495516
loaded_weight,
496517
name,
497518
shard_id=shard_id,
498519
expert_id=0)
520+
499521
else:
500522
# Regular weight loader (handles both
501523
# param.weight_loader and default_weight_loader)
@@ -560,23 +582,28 @@ def permute_qk_weight_for_rotary(
560582
loaded_weight: torch.Tensor,
561583
) -> tuple[str, torch.Tensor]:
562584

585+
# Helper function to permute the weight's channels
563586
def permute(w: torch.Tensor, n_heads: int):
564-
attn_in = self.config.head_dim * n_heads
565-
attn_out = self.config.hidden_size
566-
567-
return w.view(n_heads, attn_in // n_heads // 2, 2,
568-
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
587+
head_dim = w.shape[0] // n_heads
588+
return (
589+
w.view(n_heads, head_dim // 2, 2, w.shape[1])
590+
.transpose(1, 2)
591+
.reshape(w.shape[0], w.shape[1])
592+
)
569593

570594
modules = name.split(".")
571595

572-
# rotary embeds should be sliced
573-
if ("wk" in modules or "k_proj" in modules) \
574-
and modules[-1] == "weight":
575-
loaded_weight = permute(loaded_weight,
576-
self.config.num_key_value_heads)
577-
elif ("wq" in modules or "q_proj" in modules) \
578-
and modules[-1] == "weight":
579-
loaded_weight = permute(loaded_weight,
580-
self.config.num_attention_heads)
596+
# Permute Q/K weights and weight block scales for rotary embedding
597+
is_weight = modules[-1] == "weight"
598+
is_nvfp4_weight_scale = (modules[-1] == "weight_scale"
599+
and loaded_weight.dtype == torch.float8_e4m3fn)
600+
601+
if is_weight or is_nvfp4_weight_scale:
602+
if ("wk" in modules or "k_proj" in modules):
603+
loaded_weight = permute(loaded_weight,
604+
self.config.num_key_value_heads)
605+
elif ("wq" in modules or "q_proj" in modules):
606+
loaded_weight = permute(loaded_weight,
607+
self.config.num_attention_heads)
581608

582609
return name, loaded_weight

0 commit comments

Comments
 (0)