Skip to content

Commit 765aaff

Browse files
committed
[NVIDIA] Fix Llama4 Scout FP4 functionality issues
Fix the weight-loading issues and accuracy issues when using the NVIDIA ModelOpt Llama4 Scout FP4 model. Signed-off-by: Po-Han Huang <[email protected]>
1 parent fe56180 commit 765aaff

File tree

3 files changed

+56
-30
lines changed

3 files changed

+56
-30
lines changed

vllm/engine/arg_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,7 @@ class EngineArgs:
362362
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
363363
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
364364

365+
device: Device = DeviceConfig.device
365366
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
366367
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
367368
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,8 +810,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
810810
# Swizzle the weight blockscale.
811811
# contracting dimension is input dimension
812812
# block_size = 16;
813-
assert (layer.weight_scale.shape[1] % 16 == 0), (
814-
"Expected weight_scale.dim(1) to be divisible by 16")
815813
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
816814
"Weight Block scale must be represented as FP8-E4M3")
817815
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)

vllm/model_executor/models/llama4.py

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,22 @@ def load_moe_expert_weights(
363363
continue
364364
param = params_dict[full_param_name]
365365
weight_loader = param.weight_loader
366+
367+
# Helper function to check if the weight is FP4.
368+
# We use uint8 to store FP4 weights for now.
369+
def is_fp4_weight(weight):
370+
return weight.dtype == torch.uint8
371+
366372
if fused:
367373
if "w13" in full_param_name:
368374
shard_idx = 0 if shard_id == "w1" else 1
369375
new_loaded_weight = new_loaded_weight[shard_idx]
370-
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
376+
377+
# Only transpose for non-FP4 weights
378+
# FP4 weights are already in the correct format and shouldn't be transposed here.
379+
if not is_fp4_weight(new_loaded_weight):
380+
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
381+
371382
layer_idx = extract_layer_index(name)
372383
# EP mapping
373384
expert_map = self.layers[
@@ -382,6 +393,11 @@ def load_moe_expert_weights(
382393
else:
383394
# TODO: add EP support for non fused weights
384395
pass
396+
397+
# Only transpose for FP4 weights
398+
if is_fp4_weight(new_loaded_weight):
399+
new_loaded_weight = new_loaded_weight.transpose(-1, -2)
400+
385401
weight_loader(param,
386402
new_loaded_weight,
387403
full_param_name,
@@ -402,6 +418,12 @@ def load_weights(self, weights: Iterable[tuple[str,
402418
(".gate_up_proj", ".gate_proj", 0),
403419
(".gate_up_proj", ".up_proj", 1),
404420
]
421+
expert_scale_params_mapping = [
422+
# (expert_name, expert_id, shard_id)
423+
("w13_", 0, 'w1'),
424+
("w13_", 0, 'w3'),
425+
("w2_", 0, 'w2')
426+
]
405427
fused_experts_params = False
406428
expert_params_mapping = FusedMoE.make_expert_params_mapping(
407429
ckpt_gate_proj_name="gate_proj",
@@ -483,19 +505,19 @@ def load_weights(self, weights: Iterable[tuple[str,
483505
'supports_moe_loading', False)
484506

485507
if supports_moe:
486-
# This is a MoE weight loader
487-
if "w13_" in name:
488-
shard_id = "w1"
489-
elif "w2_" in name:
490-
shard_id = "w2"
491-
else:
492-
shard_id = "w1"
493-
494-
weight_loader(param,
495-
loaded_weight,
496-
name,
497-
shard_id=shard_id,
498-
expert_id=0)
508+
# Transpose if the weights are FP8 or FP4.
509+
if loaded_weight.dtype == torch.uint8 or loaded_weight.dtype == torch.float8_e4m3fn:
510+
loaded_weight = loaded_weight.transpose(-1, -2)
511+
param.data.fill_(0)
512+
513+
for (expert_name, expert_id, shard_id) in expert_scale_params_mapping:
514+
if expert_name in name:
515+
weight_loader(param,
516+
loaded_weight,
517+
name,
518+
shard_id=shard_id,
519+
expert_id=expert_id)
520+
499521
else:
500522
# Regular weight loader (handles both
501523
# param.weight_loader and default_weight_loader)
@@ -560,23 +582,28 @@ def permute_qk_weight_for_rotary(
560582
loaded_weight: torch.Tensor,
561583
) -> tuple[str, torch.Tensor]:
562584

585+
# Helper function to permute the weight's channels
563586
def permute(w: torch.Tensor, n_heads: int):
564-
attn_in = self.config.head_dim * n_heads
565-
attn_out = self.config.hidden_size
566-
567-
return w.view(n_heads, attn_in // n_heads // 2, 2,
568-
attn_out).transpose(1, 2).reshape(attn_in, attn_out)
587+
head_dim = w.shape[0] // n_heads
588+
return (
589+
w.view(n_heads, head_dim // 2, 2, w.shape[1])
590+
.transpose(1, 2)
591+
.reshape(w.shape[0], w.shape[1])
592+
)
569593

570594
modules = name.split(".")
571595

572-
# rotary embeds should be sliced
573-
if ("wk" in modules or "k_proj" in modules) \
574-
and modules[-1] == "weight":
575-
loaded_weight = permute(loaded_weight,
576-
self.config.num_key_value_heads)
577-
elif ("wq" in modules or "q_proj" in modules) \
578-
and modules[-1] == "weight":
579-
loaded_weight = permute(loaded_weight,
580-
self.config.num_attention_heads)
596+
# Permute Q/K weights and weight block scales for rotary embedding
597+
is_weight = modules[-1] == "weight"
598+
is_nvfp4_weight_scale = (modules[-1] == "weight_scale"
599+
and loaded_weight.dtype == torch.float8_e4m3fn)
600+
601+
if is_weight or is_nvfp4_weight_scale:
602+
if ("wk" in modules or "k_proj" in modules):
603+
loaded_weight = permute(loaded_weight,
604+
self.config.num_key_value_heads)
605+
elif ("wq" in modules or "q_proj" in modules):
606+
loaded_weight = permute(loaded_weight,
607+
self.config.num_attention_heads)
581608

582609
return name, loaded_weight

0 commit comments

Comments
 (0)