Skip to content

Commit 668f1ae

Browse files
nvpohanh authored and vadiklyutiy committed
[NVIDIA] Fix Llama4 Scout FP4 functionality issues (vllm-project#21499)
Signed-off-by: Po-Han Huang <[email protected]>
1 parent f01928d commit 668f1ae

File tree

3 files changed

+219
-70
lines changed

3 files changed

+219
-70
lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,14 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
874874
elif shard_id == "w2":
875875
param_data[expert_id] = loaded_weight
876876

877+
def _load_w13_weight_scale(self, shard_dim: int,
878+
loaded_weight: torch.Tensor,
879+
param: torch.Tensor, tp_rank: int):
880+
shard_size = param.shape[shard_dim]
881+
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
882+
shard_size)
883+
param.copy_(loaded_weight)
884+
877885
def _load_model_weight_or_group_weight_scale(self,
878886
shard_dim: int,
879887
expert_data: torch.Tensor,
@@ -1123,7 +1131,12 @@ def weight_loader(self,
11231131
"weight_scale_2" in weight_name if uses_weight_scale_2 else
11241132
"weight_scale" in weight_name) or "input_scale" in weight_name
11251133

1126-
if per_tensor_conditions:
1134+
if "w13_weight_scale" in weight_name:
1135+
self._load_w13_weight_scale(shard_dim=shard_dim,
1136+
loaded_weight=loaded_weight,
1137+
param=param,
1138+
tp_rank=self.tp_rank)
1139+
elif per_tensor_conditions:
11271140
self._load_per_tensor_weight_scale(
11281141
shard_id=shard_id,
11291142
param=param,

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,8 +778,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
778778
# Swizzle the weight blockscale.
779779
# contracting dimension is input dimension
780780
# block_size = 16;
781-
assert (layer.weight_scale.shape[1] % 16 == 0), (
782-
"Expected weight_scale.dim(1) to be divisible by 16")
783781
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
784782
"Weight Block scale must be represented as FP8-E4M3")
785783
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)

0 commit comments

Comments (0)