diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index aa176ef05fcc..ba3571d7755d 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 
 import torch
 import torch.nn as nn
@@ -8,7 +9,10 @@
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+    is_rocm_aiter_fusion_shared_expert_enabled,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -212,11 +216,16 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts
+            + (
+                self.config.n_shared_experts
+                if is_rocm_aiter_fusion_shared_expert_enabled()
+                else 0
+            ),
         )
 
         params_dict = dict(self.named_parameters())
@@ -227,6 +236,10 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
             if spec_layer is None:
                 continue
+            is_fuse_shared_experts_layer = (
+                is_rocm_aiter_fusion_shared_expert_enabled()
+                and ("mlp.shared_experts" in name)
+            )
             name = self._rewrite_spec_layer_name(spec_layer, name)
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
@@ -240,6 +253,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                 if ("mlp.experts." in name) and name not in params_dict:
                     continue
+                if is_fuse_shared_experts_layer:
+                    continue
                 name_mapped = name.replace(weight_name, param_name)
 
                 # QKV fusion is optional, fall back to normal
@@ -260,45 +275,105 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
-
-                    # According to DeepSeek-V3 Technical Report, MTP modules
-                    # shares embedding layer. We only load the first weights.
-                    if (
-                        spec_layer != self.model.mtp_start_layer_idx
-                        and ".layers" not in name
-                    ):
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
+                # Special handling: when AITER fusion_shared_experts is enabled,
+                # checkpoints may provide a single widened shared_experts tensor
+                # without explicit expert indices
+                # (e.g. ...mlp.shared_experts.gate_proj.weight).
+                # For models with multiple shared experts, split that tensor
+                # evenly into per-shared-expert slices and load them into
+                # appended expert slots mlp.experts.{n_routed_experts + j}.*
+                # accordingly.
+                num_chunks = 1
+                if is_fuse_shared_experts_layer:
+                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+                    # Determine split axis based on op type
+                    # gate/up: ColumnParallel → split along dim 0
+                    # down: RowParallel → split along dim 1
+                    split_dim = 1 if "down_proj.weight" in name else 0
+                    total = loaded_weight.shape[split_dim]
+                    assert total % num_chunks == 0, (
+                        f"Shared expert weight dim {total} "
+                        f"not divisible by num_chunks {num_chunks}"
                     )
-                    weight_loader(param, loaded_weight)
-                    loaded_params.add(name)
+                    chunk_size = total // num_chunks
+
+                for j in range(num_chunks):
+                    chunk_name = name
+                    weight_to_load = loaded_weight
+
+                    if is_fuse_shared_experts_layer:
+                        if split_dim == 0:
+                            weight_to_load = loaded_weight[
+                                j * chunk_size : (j + 1) * chunk_size, :
+                            ]
+                        else:
+                            weight_to_load = loaded_weight[
+                                :, j * chunk_size : (j + 1) * chunk_size
+                            ]
+                        # Synthesize an expert-style name so expert mapping
+                        # can route it
+                        chunk_name = name.replace(
+                            "mlp.shared_experts",
+                            f"mlp.experts.{self.config.n_routed_experts + j}",
+                        )
+
+                    # Use expert_params_mapping to locate the destination
+                    # param and delegate to its expert-aware weight_loader
+                    # with expert_id.
+                    for mapping in expert_params_mapping:
+                        param_name, weight_name, expert_id, shard_id = mapping
+                        if weight_name not in chunk_name:
+                            continue
+
+                        # Do not modify `name` since the loop may continue here
+                        # Instead, create a new variable
+                        name_mapped = chunk_name.replace(weight_name, param_name)
+
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = typing.cast(
+                            Callable[..., bool], param.weight_loader
+                        )
+                        success = weight_loader(
+                            param,
+                            weight_to_load,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                            return_success=True,
+                        )
+                        if success:
+                            if not is_fuse_shared_experts_layer:
+                                name = name_mapped
+                            else:
+                                loaded_params.add(name_mapped)
+                            break
+                    else:
+                        # Skip loading extra bias for GPTQ models.
+                        if name.endswith(".bias") and name not in params_dict:
+                            continue
+
+                        name = maybe_remap_kv_scale_name(name, params_dict)
+                        if name is None:
+                            continue
+
+                        # According to DeepSeek-V3 Technical Report, MTP modules
+                        # shares embedding layer. We only load the first weights.
+                        if (
+                            spec_layer != self.model.mtp_start_layer_idx
+                            and ".layers" not in name
+                        ):
+                            continue
+
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                        if not is_fuse_shared_experts_layer:
+                            loaded_params.add(name)
         return loaded_params
 
     def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
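Reviewer's note (not part of the patch): a minimal standalone sketch of the split-and-rename scheme the new branch implements, with made-up tensor shapes and an illustrative checkpoint name; only `n_routed_experts` / `n_shared_experts` mirror the config fields used above.

```python
# Toy illustration of the fused shared-experts handling (assumed shapes/names).
import torch

n_routed_experts = 256
n_shared_experts = 2
name = "model.layers.61.mlp.shared_experts.gate_proj.weight"  # illustrative

# gate/up projections are column-parallel, so the fused tensor stacks the
# shared experts along dim 0; a down_proj weight would be split along dim 1.
split_dim = 1 if "down_proj.weight" in name else 0
loaded_weight = torch.randn(n_shared_experts * 4, 8)  # toy tensor

chunk_size = loaded_weight.shape[split_dim] // n_shared_experts
for j in range(n_shared_experts):
    # Slice out shared expert j and reroute it to an appended expert slot.
    chunk = loaded_weight.narrow(split_dim, j * chunk_size, chunk_size)
    chunk_name = name.replace(
        "mlp.shared_experts", f"mlp.experts.{n_routed_experts + j}"
    )
    print(chunk_name, tuple(chunk.shape))
# -> model.layers.61.mlp.experts.256.gate_proj.weight (4, 8)
# -> model.layers.61.mlp.experts.257.gate_proj.weight (4, 8)
```

Each renamed chunk is then matched against expert_params_mapping and loaded with its appended expert_id, which is why make_expert_params_mapping is sized with n_routed_experts + n_shared_experts when the AITER fusion path is enabled.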