
Commit a85f74b

Edwardf0t1 authored and aarnphm committed
Add ModelOpt Qwen3 nvfp4 support (vllm-project#20101)
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 1aa0235 commit a85f74b

3 files changed (+58, −37 lines)

vllm/model_executor/model_loader/weight_utils.py

Lines changed: 34 additions & 32 deletions
@@ -764,39 +764,41 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
             return None
         return remapped_name
 
-    possible_scale_names = [".k_scale", ".v_scale"]
-    modelopt_scale_names = [
-        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
+    # Define scale name mapping patterns in order of precedence
+    scale_mapping_patterns = [
+        # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.([kv])_proj\.([kv])_scale$",
+         r".self_attn.attn.\2_scale"),
+        # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
+        # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"
+         ),
+        # Default format: .{k,v}_scale -> .attn.{k,v}_scale
+        (r"\.([kv])_scale$", r".attn.\1_scale"),
     ]
-    # Also support qkv_proj scale parameters (from stacked parameter processing)
-    qkv_proj_scale_names = [
-        ".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale"
-    ]
-    for scale_name in possible_scale_names:
-        if name.endswith(scale_name):
-            if any(mo_scale_name in name
-                   for mo_scale_name in modelopt_scale_names):
-                remapped_name = name.replace(
-                    f".self_attn.{scale_name[1]}_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            elif any(qkv_scale_name in name
-                     for qkv_scale_name in qkv_proj_scale_names):
-                # Handle qkv_proj scale parameters
-                remapped_name = name.replace(
-                    f".self_attn.qkv_proj{scale_name}",
-                    f".self_attn.attn{scale_name}")
-            else:
-                remapped_name = name.replace(scale_name, f".attn{scale_name}")
-            if remapped_name not in params_dict:
-                logger.warning_once(
-                    "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
-                    scale_name,
-                    name,
-                    remapped_name,
-                    scale_name,
-                )
-                return None
-            return remapped_name
+
+    # Check if name ends with k_scale or v_scale
+    if name.endswith((".k_scale", ".v_scale")):
+        import regex as re
+
+        for pattern, replacement in scale_mapping_patterns:
+            if re.search(pattern, name):
+                remapped_name = re.sub(pattern, replacement, name)
+                if remapped_name not in params_dict:
+                    scale_type = name.split(".")[-1]
+                    logger.warning_once(
+                        "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
+                        scale_type,
+                        name,
+                        remapped_name,
+                        scale_type,
+                    )
+                    return None
+                return remapped_name
 
     # If there were no matches, return the untouched param name
     return name
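
The new logic is an ordered pattern table: the most specific checkpoint layout is tried first, and the bare .k_scale/.v_scale fallback only fires when nothing else matches. Below is a minimal standalone sketch of that precedence, using the stdlib re module in place of the regex package the diff imports (these patterns are re-compatible); the example parameter names are illustrative, not taken from a real checkpoint.

import re

scale_mapping_patterns = [
    # ModelOpt per-projection scales.
    (r"\.self_attn\.([kv])_proj\.([kv])_scale$", r".self_attn.attn.\2_scale"),
    # Stacked qkv_proj scales.
    (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    # Qwen3 MoE fused qkqkv_proj scales.
    (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
    # Bare fallback.
    (r"\.([kv])_scale$", r".attn.\1_scale"),
]

def remap(name: str) -> str:
    # First matching pattern wins, mirroring the loop in the diff; putting
    # the bare fallback last keeps it from clobbering the specific formats.
    for pattern, replacement in scale_mapping_patterns:
        if re.search(pattern, name):
            return re.sub(pattern, replacement, name)
    return name

assert remap("model.layers.0.self_attn.k_proj.k_scale") == \
    "model.layers.0.self_attn.attn.k_scale"
assert remap("model.layers.0.self_attn.qkv_proj.v_scale") == \
    "model.layers.0.self_attn.attn.v_scale"
assert remap("model.layers.0.k_scale") == "model.layers.0.attn.k_scale"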

vllm/model_executor/models/qwen2.py

Lines changed: 11 additions & 2 deletions
@@ -408,9 +408,18 @@ def load_weights(self, weights: Iterable[tuple[str,
                     continue
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
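
Note the change from param.weight_loader to a getattr fallback: the kv-scale parameters surfaced by maybe_remap_kv_scale_name may be plain tensors with no custom loader attached, and default_weight_loader takes no shard_id. A self-contained toy of that dispatch, where SimpleNamespace stands in for vLLM's parameter objects:

from types import SimpleNamespace

def default_weight_loader(param, loaded_weight):
    # Shard-agnostic copy, used when a parameter has no custom loader.
    param.data = loaded_weight

def qkv_weight_loader(param, loaded_weight, shard_id):
    # Custom loader on a stacked parameter; needs to know which shard
    # ("q", "k", or "v") the checkpoint tensor fills.
    param.shards[shard_id] = loaded_weight

stacked = SimpleNamespace(shards={}, weight_loader=qkv_weight_loader)
k_scale = SimpleNamespace(data=None)  # remapped scale: no weight_loader

for param, weight, shard_id in [(stacked, "Wk", "k"), (k_scale, 1.0, "k")]:
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    if weight_loader == default_weight_loader:
        weight_loader(param, weight)            # scales load whole
    else:
        weight_loader(param, weight, shard_id)  # stacked weights need a shard

assert stacked.shards == {"k": "Wk"} and k_scale.data == 1.0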

vllm/model_executor/models/qwen3_moe.py

Lines changed: 13 additions & 3 deletions
@@ -48,7 +48,8 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 

@@ -471,12 +472,21 @@ def load_weights(self, weights: Iterable[tuple[str,
                 # Skip layers on other devices.
                 if is_pp_missing_parameter(name, self):
                     continue
+                if name.endswith("scale"):
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
                 if name not in params_dict:
                     continue
 
                 param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, loaded_weight)
+                else:
+                    weight_loader(param, loaded_weight, shard_id)
                 break
             else:
                 is_expert_weight = False
