
Conversation


@zhyajie zhyajie commented Oct 27, 2025

Purpose

This PR aims to fix the loading errors for the DeepSeek MTP weights when VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled (which is the default setting).
The issue occurs during model loading where a KeyError is thrown for the parameter 'model.layers.61.mtp_block.mlp.shared_experts.down_proj.weight_scale_inv'.
Root Cause: The issue was introduced by PR #24097, which added the fused shared experts optimization for ROCm but did not adapt it to the DeepSeek MTP model architecture. Weight loading therefore raises a KeyError because the shared experts in MTP blocks have no matching entry in the parameter mapping.
The fix applies the same handling that PR #24097 introduced in vllm/model_executor/models/deepseek_v2.py to the MTP model.
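For orientation, here is a minimal sketch of the failure and the kind of handling the fix adds (hypothetical names such as load_one_weight and fuse_shared_experts; not the actual patch): with fusion enabled, weights named ...mlp.shared_experts.* no longer have a standalone entry in params_dict, so a plain dictionary lookup raises the KeyError above; instead they need to be routed through the expert-aware mapping, mirroring the deepseek_v2.py handling.

def load_one_weight(name, weight, params_dict, expert_params_mapping,
                    fuse_shared_experts):
    # Illustrative sketch only: route shared-expert weights through the
    # fused-experts mapping instead of a direct params_dict lookup.
    if fuse_shared_experts and ".mlp.shared_experts." in name:
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            # Remap the checkpoint name onto the fused experts parameter and
            # delegate to its expert-aware weight_loader.
            mapped = name.replace(weight_name, param_name)
            param = params_dict[mapped]
            param.weight_loader(param, weight, mapped,
                                shard_id=shard_id, expert_id=expert_id)
            return mapped
    # Plain path: with fusion enabled, this lookup is what raised
    # KeyError: '...mtp_block.mlp.shared_experts.down_proj.weight_scale_inv'.
    param = params_dict[name]
    param.weight_loader(param, weight)
    return name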

Test Plan

The following tests validate DeepSeek models by collecting benchmark metrics and performing correctness tests through lm_eval.

vLLM server launch command:

AITER_ENABLE_VSKIP=0 \
VLLM_USE_V1=1 \
VLLM_ROCM_USE_AITER=1 \
vllm serve $MODEL \
--tensor-parallel-size 8 \
--disable-log-requests \
--compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
--trust-remote-code \
--speculative-config='{"method": "deepseek_mtp", "num_speculative_tokens": 1}' \
--block-size 1

lm_eval command:

lm_eval --model local-completions --tasks gsm8k --model_args model=${model_name},base_url=http://localhost:8000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=False

Test Result

Before this PR, the worker fails to start during MTP weight loading:

(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627] WorkerProc failed to start.
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627] Traceback (most recent call last):
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/executor/multiproc_executor.py", line 601, in worker_main
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     worker = WorkerProc(*args, **kwargs)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/executor/multiproc_executor.py", line 456, in __init__
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.worker.load_model()
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/worker/gpu_worker.py", line 233, in load_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.model_runner.load_model(eep_scale_up=eep_scale_up)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/worker/gpu_model_runner.py", line 2895, in load_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.drafter.load_model(self.model)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/v1/spec_decode/eagle.py", line 930, in load_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.model = get_model(
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/model_executor/model_loader/__init__.py", line 130, in get_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     return loader.load_model(vllm_config=vllm_config, model_config=model_config)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/model_executor/model_loader/base_loader.py", line 55, in load_model
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     self.load_weights(model, model_config)
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/model_executor/model_loader/default_loader.py", line 300, in load_weights
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     loaded_weights = model.load_weights(
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]   File "/home/yajizhan/code/vllm/vllm/model_executor/models/deepseek_mtp.py", line 296, in load_weights
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627]     param = params_dict[name]
(Worker_TP4 pid=309299) ERROR 10-27 08:09:37 [multiproc_executor.py:627] KeyError: 'model.layers.61.mtp_block.mlp.shared_experts.down_proj.weight_scale_inv'

After this PR, the server starts normally and the MTP weights load properly. The gsm8k results and MTP acceptance-rate metrics are shown below.

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9530|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.9522|±  |0.0059|
INFO 10-27 08:36:51 [metrics.py:100] SpecDecoding metrics: Mean acceptance length: 1.92, Accepted throughput: 76.00 tokens/s, Drafted throughput: 82.70 tokens/s, Accepted: 760 tokens, Drafted: 827 tokens, Per-position acceptance rate: 0.919, Avg Draft acceptance rate: 91.9%
INFO 10-27 08:37:01 [metrics.py:100] SpecDecoding metrics: Mean acceptance length: 1.94, Accepted throughput: 80.39 tokens/s, Drafted throughput: 85.59 tokens/s, Accepted: 804 tokens, Drafted: 856 tokens, Per-position acceptance rate: 0.939, Avg Draft acceptance rate: 93.9%

@zhyajie zhyajie requested a review from luccafong as a code owner October 27, 2025 08:57
@mergify mergify bot added the deepseek (Related to DeepSeek models) and rocm (Related to AMD ROCm) labels Oct 27, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This PR addresses a critical bug in the DeepSeek-MTP model loading process when VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS is enabled, which caused a KeyError due to missing parameters. The fix involves adjusting the expert parameters mapping and adding conditional logic to handle the fused shared experts layer. The changes ensure correct weight loading and proper model initialization, as validated by benchmark metrics and correctness tests.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +321 to +355
# Use expert_params_mapping to locate the destination
# param and delegate to its expert-aware weight_loader
# with expert_id.
for mapping in expert_params_mapping:
    param_name, weight_name, expert_id, shard_id = mapping
    if weight_name not in chunk_name:
        continue

    # Do not modify `name` since the loop may continue here
    # Instead, create a new variable
    name_mapped = chunk_name.replace(weight_name, param_name)

    param = params_dict[name_mapped]
    # We should ask the weight loader to return success or
    # not here since otherwise we may skip experts with
    # other available replicas.
    weight_loader = typing.cast(
        Callable[..., bool], param.weight_loader
    )
    success = weight_loader(
        param,
        weight_to_load,
        name_mapped,
        shard_id=shard_id,
        expert_id=expert_id,
        return_success=True,
    )
    if success:
        if not is_fuse_shared_experts_layer:
            name = name_mapped
        else:
            loaded_params.add(name_mapped)
        break
else:
    # Skip loading extra bias for GPTQ models.


P0: Skip remote experts instead of falling back to generic loader

The new loop only breaks when param.weight_loader(..., return_success=True) succeeds, but weight_loader returns False for experts that are not hosted on the current rank (the common case in distributed MoE loading). When this happens the for loop completes without a break, so the else clause executes and calls weight_loader(param, loaded_weight) without shard_id/expert_id. That generic path either raises TypeError or copies remote expert tensors into the wrong parameter, causing multi-rank DeepSeek checkpoints to fail. Previously the loop always broke after invoking weight_loader, allowing it to silently skip remote experts. The fallback should only trigger when the weight name does not correspond to an expert at all, not when the expert is simply non-local.
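
To make the failure mode concrete, here is a small self-contained sketch (hypothetical names, not the vLLM implementation) of the for/else interaction described above: a non-local expert makes weight_loader return False, the loop finishes without a break, and the else branch runs the generic loader with the wrong arguments. Tracking whether the name matched any expert mapping at all lets the code skip remote experts while keeping the generic fallback for non-expert weights.

def load_one(name, weight, params_dict, expert_params_mapping):
    # Sketch of the reviewer's point: fall back to the generic loader only
    # when the weight is not an expert weight at all.
    matched_expert = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        matched_expert = True
        mapped = name.replace(weight_name, param_name)
        param = params_dict[mapped]
        # Returns False when this expert is hosted on another rank.
        if param.weight_loader(param, weight, mapped,
                               shard_id=shard_id, expert_id=expert_id,
                               return_success=True):
            return mapped
    if matched_expert:
        # Expert weight that is simply non-local: skip it silently rather
        # than copying it into the wrong parameter via the generic path.
        return None
    # Genuinely non-expert weight: generic loading path.
    param = params_dict[name]
    param.weight_loader(param, weight)
    return name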


@zhyajie
Author

zhyajie commented Oct 27, 2025

@HAIAI @kliuae Please take a look

