vllm/model_executor/models/deepseek_mtp.py: 160 changes (118 additions, 42 deletions)
@@ -1,14 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable

 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig

 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_fusion_shared_expert_enabled,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -212,11 +217,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]

expert_params_mapping = FusedMoE.make_expert_params_mapping(
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_experts=self.config.n_routed_experts
+ (
self.config.n_shared_experts
if is_rocm_aiter_fusion_shared_expert_enabled()
else 0
),
)

params_dict = dict(self.named_parameters())
@@ -227,6 +237,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is None:
                 continue
+            is_fuse_shared_experts_layer = (
+                is_rocm_aiter_fusion_shared_expert_enabled()
+                and ("mlp.shared_experts" in name)
+            )
             name = self._rewrite_spec_layer_name(spec_layer, name)
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -240,6 +254,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
+                if is_fuse_shared_experts_layer:
+                    continue
                 name_mapped = name.replace(weight_name, param_name)

                 # QKV fusion is optional, fall back to normal
@@ -260,45 +276,105 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-
-                    # According to DeepSeek-V3 Technical Report, MTP modules
-                    # shares embedding layer. We only load the first weights.
-                    if (
-                        spec_layer != self.model.mtp_start_layer_idx
-                        and ".layers" not in name
-                    ):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
-            loaded_params.add(name)
+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
+                num_chunks = 1
+                if is_fuse_shared_experts_layer:
+                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
+                    split_dim = 1 if "down_proj.weight" in name else 0
+                    total = loaded_weight.shape[split_dim]
+                    assert total % num_chunks == 0, (
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
+                    )
+                    chunk_size = total // num_chunks
+
+                for j in range(num_chunks):
+                    chunk_name = name
+                    weight_to_load = loaded_weight
+
+                    if is_fuse_shared_experts_layer:
+                        if split_dim == 0:
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
+                        else:
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
+                        chunk_name = name.replace(
+                            "mlp.shared_experts",
+                            f"mlp.experts.{self.config.n_routed_experts + j}",
+                        )
+
+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in chunk_name:
+                            continue
+
+                        # Do not modify `name` since the loop may continue here
+                        # Instead, create a new variable
+                        name_mapped = chunk_name.replace(weight_name, param_name)
+
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = typing.cast(
+                            Callable[..., bool], param.weight_loader
+                        )
+                        success = weight_loader(
+                            param,
+                            weight_to_load,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True,
+                        )
+                        if success:
+                            if not is_fuse_shared_experts_layer:
+                                name = name_mapped
+                            else:
+                                loaded_params.add(name_mapped)
+                            break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
Comment on lines +321 to +355

P0: Skip remote experts instead of falling back to generic loader

The new loop only breaks when param.weight_loader(..., return_success=True) succeeds, but weight_loader returns False for experts that are not hosted on the current rank (the common case in distributed MoE loading). When this happens, the for loop completes without a break, so the else clause executes and calls weight_loader(param, loaded_weight) without shard_id/expert_id. That generic path either raises TypeError or copies remote expert tensors into the wrong parameter, causing multi-rank DeepSeek checkpoints to fail. Previously the loop always broke after invoking weight_loader, allowing it to silently skip remote experts. The fallback should only trigger when the weight name does not correspond to an expert at all, not when the expert is simply non-local.

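One way to address this, sketched here as a hypothetical standalone helper rather than as a patch to the diff above (the function name, signature, and three-way return convention are illustrative, not part of vLLM): record whether the weight name matched any expert mapping at all, and let the caller skip non-local experts instead of falling through to the generic loader.

```python
from collections.abc import Callable
from typing import Optional, cast

import torch


def try_load_expert_weight(
    params_dict: dict[str, torch.nn.Parameter],
    expert_params_mapping: list[tuple[str, str, int, str]],
    name: str,
    loaded_weight: torch.Tensor,
    loaded_params: set[str],
) -> Optional[bool]:
    """Return True if a local expert shard accepted the weight, False if the
    name is an expert weight hosted on another rank (caller should skip it),
    or None if the name is not an expert weight at all (caller may fall back
    to the generic loader)."""
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        # The name refers to an expert, even if no local shard accepts it.
        is_expert_weight = True
        name_mapped = name.replace(weight_name, param_name)
        param = params_dict[name_mapped]
        # Expert-aware loaders in this code path accept return_success=True.
        weight_loader = cast(Callable[..., bool], param.weight_loader)
        if weight_loader(
            param,
            loaded_weight,
            name_mapped,
            shard_id=shard_id,
            expert_id=expert_id,
            return_success=True,
        ):
            loaded_params.add(name_mapped)
            return True
    return False if is_expert_weight else None
```

In the diff itself, the equivalent change would be to set such a flag inside the `for mapping in expert_params_mapping:` loop and `continue` past non-local experts before ever reaching the `default_weight_loader` branch.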

if name.endswith(".bias") and name not in params_dict:
continue

name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

# According to DeepSeek-V3 Technical Report, MTP modules
# shares embedding layer. We only load the first weights.
if (
spec_layer != self.model.mtp_start_layer_idx
and ".layers" not in name
):
continue

param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if not is_fuse_shared_experts_layer:
loaded_params.add(name)
return loaded_params

def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
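To make the shared-expert splitting in the last hunk concrete, here is a small standalone sketch with made-up shapes and a made-up checkpoint name (it is not vLLM code): the fused shared_experts tensor is split evenly along the parallel dimension and each slice is renamed into one of the appended expert slots.

```python
import torch

# Illustrative values only; real values come from the model config.
n_routed_experts = 4
n_shared_experts = 2
intermediate_size, hidden_size = 8, 6

# Hypothetical checkpoint entry for a fused shared-experts gate projection.
name = "model.layers.61.mlp.shared_experts.gate_proj.weight"
# gate/up projections are column-parallel, so shared experts are stacked
# along dim 0; a down_proj weight would be split along dim 1 instead.
loaded_weight = torch.randn(n_shared_experts * intermediate_size, hidden_size)

split_dim = 1 if "down_proj.weight" in name else 0
chunk_size = loaded_weight.shape[split_dim] // n_shared_experts

for j in range(n_shared_experts):
    chunk = loaded_weight.narrow(split_dim, j * chunk_size, chunk_size)
    chunk_name = name.replace(
        "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
    )
    print(chunk_name, tuple(chunk.shape))
# Prints:
# model.layers.61.mlp.experts.4.gate_proj.weight (8, 6)
# model.layers.61.mlp.experts.5.gate_proj.weight (8, 6)
```

Each renamed slice is then routed through expert_params_mapping exactly like a regular routed-expert weight, which is why num_experts is widened by n_shared_experts when the AITER fusion path is enabled.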