diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py
index 2d930527b2be..5d7d1bf0c596 100644
--- a/vllm/model_executor/models/grok1.py
+++ b/vllm/model_executor/models/grok1.py
@@ -22,8 +22,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Grok1 model."""
+import typing
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -31,8 +32,9 @@
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -48,7 +50,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -76,10 +78,29 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 enable_eplb: bool = False) -> None:
         super().__init__()
         self.hidden_size = hidden_size
 
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        # No built-in separation between routed and shared experts.
+        self.n_routed_experts: int = num_experts
+        self.n_shared_experts: int = 0
+
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(hidden_size,
                                      num_experts,
@@ -88,17 +109,21 @@ def __init__(self,
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")
 
-        self.experts = FusedMoE(num_experts=num_experts,
-                                top_k=top_k,
-                                hidden_size=hidden_size,
-                                intermediate_size=intermediate_size,
-                                params_dtype=params_dtype,
-                                reduce_results=True,
-                                renormalize=True,
-                                quant_config=quant_config,
-                                tp_size=tp_size,
-                                activation="gelu",
-                                prefix=f"{prefix}.experts")
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            activation="gelu",
+            prefix=f"{prefix}.experts",
+            enable_eplb=enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+        )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -110,6 +135,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(orig_shape)
 
+    def get_expert_weights(self) -> list[torch.Tensor]:
+        """Returns the weights of all experts in this MoE block."""
+        return self.experts.get_expert_weights()
+
 
 class Grok1Attention(nn.Module):
 
@@ -208,6 +237,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -242,7 +272,8 @@ def __init__(
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
-            prefix=f"{prefix}.moe_block")
+            prefix=f"{prefix}.moe_block",
+            enable_eplb=enable_eplb)
 
         self.pre_attn_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
@@ -293,6 +324,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        enable_eplb = vllm_config.parallel_config.enable_eplb
 
         self.config = config
         self.quant_config = quant_config
@@ -314,9 +346,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Grok1DecoderLayer(
-                config, cache_config, quant_config=quant_config, prefix=prefix
-            ),
+            lambda prefix: Grok1DecoderLayer(config,
+                                             cache_config,
+                                             quant_config=quant_config,
+                                             prefix=prefix,
+                                             enable_eplb=enable_eplb),
             prefix=f"{prefix}.layers")
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -376,7 +410,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ckpt_gate_proj_name="linear",  # Grok1 specific
             ckpt_down_proj_name="linear_1",  # Grok1 specific
             ckpt_up_proj_name="linear_v",  # Grok1 specific
-            num_experts=num_experts)
+            num_experts=num_experts,
+            num_redundant_experts=self.num_redundant_experts,
+        )
 
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
@@ -415,26 +451,47 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
+
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+
                     # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
+                    if ((name_mapped.endswith(".bias")
+                         or name_mapped.endswith("_bias"))
+                            and name_mapped not in params_dict):
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
+
                     # Skip loading extra bias for GPTQ models.
                     if ((name.endswith(".bias") or name.endswith("_bias"))
                             and name not in params_dict):
@@ -460,7 +517,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
     fall_back_to_pt_during_load = False
 
     packed_modules_mapping = {
@@ -510,6 +567,38 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+        # All layers in Grok1 have MoE blocks.
+        self.expert_weights = [
+            layer.moe_block.get_expert_weights() for layer in self.model.layers
+        ]
+
+        # Set MoE hyperparameters.
+        self.moe_layers: list[FusedMoE] = [
+            layer.moe_block.experts for layer in self.model.layers
+        ]
+
+        example_moe = typing.cast(Grok1MoE, self.model.layers[0].moe_block)
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+        self.num_shared_experts = example_moe.n_shared_experts
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, experts in enumerate(self.moe_layers):
+            experts.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
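
The counters introduced in Grok1MoE.__init__ encode the EPLB sizing rule: physical experts = logical experts + redundant replicas, split evenly across expert-parallel ranks, and Grok1 contributes no shared experts. The arithmetic can be sketched in isolation as below; the dataclass and the example numbers are assumptions for illustration only, the real values come from the checkpoint config, FusedMoE, and ParallelConfig.

from dataclasses import dataclass


@dataclass
class MoeExpertCounts:
    n_routed_experts: int      # logical experts defined by the checkpoint
    n_redundant_experts: int   # extra physical replicas added for EPLB
    ep_size: int               # expert-parallel world size

    @property
    def n_logical_experts(self) -> int:
        # Grok1 has no shared experts, so logical == routed.
        return self.n_routed_experts

    @property
    def n_physical_experts(self) -> int:
        return self.n_logical_experts + self.n_redundant_experts

    @property
    def n_local_physical_experts(self) -> int:
        # Physical experts are assumed to divide evenly across EP ranks.
        return self.n_physical_experts // self.ep_size


# Grok1 has 8 routed experts; with e.g. 8 redundant replicas and EP=4,
# each rank hosts (8 + 8) // 4 = 4 physical experts.
counts = MoeExpertCounts(n_routed_experts=8, n_redundant_experts=8, ep_size=4)
print(counts.n_physical_experts, counts.n_local_physical_experts)  # 16 4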
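
The return_success=True path in load_weights is what makes redundant experts loadable: with EPLB, one checkpoint tensor for a logical expert can map to several physical slots, only some of which live on the current rank, so the loop has to keep trying mapping entries rather than breaking on the first textual match. The standalone sketch below mimics that for/else structure with toy stand-ins; toy_weight_loader, LOCAL_PHYSICAL_EXPERTS, and the mapping entries are invented for illustration and are not vLLM APIs.

import typing
from typing import Callable

# Physical expert slots hosted on this (pretend) rank.
LOCAL_PHYSICAL_EXPERTS = {0, 1}


def toy_weight_loader(param: list, weight: float, name: str, *,
                      shard_id: str, expert_id: int,
                      return_success: bool = False) -> bool:
    # Accept the weight only if the physical expert lives on this rank.
    if expert_id not in LOCAL_PHYSICAL_EXPERTS:
        return False  # replica lives on another rank; not an error
    param.append((name, shard_id, weight))
    return True


def load_checkpoint_weight(name, weight, expert_params_mapping, params_dict):
    # Mirrors the for/else structure of the diff above.
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        # Remember that this tensor is an expert weight, so a miss below
        # means "owned by another rank", not "unknown weight".
        is_expert_weight = True
        name_mapped = name.replace(weight_name, param_name)
        param = params_dict[name_mapped]
        weight_loader = typing.cast(Callable[..., bool], toy_weight_loader)
        if weight_loader(param, weight, name_mapped, shard_id=shard_id,
                         expert_id=expert_id, return_success=True):
            return True
    if is_expert_weight:
        return False  # expert weight with no local replica: skip it
    raise KeyError(f"unexpected weight: {name}")


# Logical expert 0 is replicated as physical experts 0 and 2; only physical
# expert 0 is local, so the loop must keep trying mapping entries instead of
# stopping at the first one whose loader reports failure.
mapping = [
    ("experts.w13_", "experts.0.linear.", 2, "w1"),  # remote replica
    ("experts.w13_", "experts.0.linear.", 0, "w1"),  # local replica
]
params = {"experts.w13_weight": []}  # hypothetical fused parameter name
print(load_checkpoint_weight("experts.0.linear.weight", 0.5, mapping, params))
print(params)

The is_expert_weight flag is what separates "expert weight owned by another rank" (silently skipped) from "weight that matched nothing" (still an error).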
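
set_eplb_state at the bottom of the diff fans shared EPLB tensors out to every MoE layer together with that layer's index. The toy sketch below shows the intent; ToyMoELayer and the tensor shapes are assumptions made for illustration (the real consumer is vLLM's FusedMoE): each layer keeps views into the shared tensors, so updates made during load balancing are visible everywhere without copying.

import torch


class ToyMoELayer:
    def set_eplb_state(self, moe_layer_idx: int,
                       expert_load_view: torch.Tensor,
                       logical_to_physical_map: torch.Tensor,
                       logical_replica_count: torch.Tensor) -> None:
        # Keep per-layer views (not copies) of the shared tensors.
        self.expert_load_view = expert_load_view[moe_layer_idx]
        self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
        self.logical_replica_count = logical_replica_count[moe_layer_idx]


num_layers, num_logical, max_replicas = 2, 8, 2
moe_layers = [ToyMoELayer() for _ in range(num_layers)]
expert_load = torch.zeros(num_layers, num_logical + max_replicas)
l2p = torch.zeros(num_layers, num_logical, max_replicas, dtype=torch.long)
replica_count = torch.ones(num_layers, num_logical, dtype=torch.long)

for layer_idx, layer in enumerate(moe_layers):
    layer.set_eplb_state(layer_idx, expert_load, l2p, replica_count)

moe_layers[0].expert_load_view[3] += 1  # mutates the shared tensor in place
print(expert_load[0, 3].item())         # 1.0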