vllm/model_executor/models/grok1.py (157 changes: 123 additions & 34 deletions)
@@ -22,17 +22,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Grok1 model."""
import typing
from collections.abc import Iterable
from typing import Optional, Union
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -48,7 +50,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -76,10 +78,29 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = ""):
prefix: str = "",
enable_eplb: bool = False) -> None:
super().__init__()
self.hidden_size = hidden_size

self.ep_group = get_ep_group().device_group
Contributor review comment (critical):
The function get_ep_group() returns a torch.distributed.ProcessGroup, which does not have a device_group attribute. This will cause an AttributeError at runtime. You should use get_ep_group() directly.

Suggested change:
- self.ep_group = get_ep_group().device_group
+ self.ep_group = get_ep_group()

self.ep_rank = self.ep_group.rank()
self.ep_size = self.ep_group.size()
# No built-in separation between routed and shared experts.
self.n_routed_experts: int = num_experts
self.n_shared_experts: int = 0

# Load balancing settings.
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.enable_eplb = enable_eplb

self.n_redundant_experts = parallel_config.num_redundant_experts
self.n_logical_experts = self.n_routed_experts
self.n_physical_experts = (self.n_logical_experts +
self.n_redundant_experts)
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
Contributor review comment (high):
The number of physical experts (self.n_physical_experts) should be divisible by the expert parallelism size (self.ep_size). If it's not, this integer division will lead to an incorrect number of local experts on some ranks, which can cause issues later. It's best to add an assertion to ensure this condition is met.

Suggested change:
- self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+ assert self.n_physical_experts % self.ep_size == 0, \
+     "The number of physical experts must be divisible by the expert parallel size."
+ self.n_local_physical_experts = self.n_physical_experts // self.ep_size


# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(hidden_size,
num_experts,
@@ -88,17 +109,21 @@
quant_config=None,
prefix=f"{prefix}.gate")

self.experts = FusedMoE(num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
activation="gelu",
prefix=f"{prefix}.experts")
self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
activation="gelu",
prefix=f"{prefix}.experts",
enable_eplb=enable_eplb,
num_redundant_experts=self.n_redundant_experts,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
@@ -110,6 +135,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states = self.experts(hidden_states, router_logits)
return final_hidden_states.view(orig_shape)

def get_expert_weights(self) -> list[torch.Tensor]:
"""Returns the weights of all experts in this MoE block."""
return self.experts.get_expert_weights()


class Grok1Attention(nn.Module):

@@ -208,6 +237,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -242,7 +272,8 @@ def __init__(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.moe_block")
prefix=f"{prefix}.moe_block",
enable_eplb=enable_eplb)

self.pre_attn_norm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -293,6 +324,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
enable_eplb = vllm_config.parallel_config.enable_eplb

self.config = config
self.quant_config = quant_config
@@ -314,9 +346,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Grok1DecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
lambda prefix: Grok1DecoderLayer(config,
cache_config,
quant_config=quant_config,
prefix=prefix,
enable_eplb=enable_eplb),
prefix=f"{prefix}.layers")

self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -376,7 +410,9 @@ def load_weights(self, weights: Iterable[tuple[str,
ckpt_gate_proj_name="linear", # Grok1 specific
ckpt_down_proj_name="linear_1", # Grok1 specific
ckpt_up_proj_name="linear_v", # Grok1 specific
num_experts=num_experts)
num_experts=num_experts,
num_redundant_experts=self.num_redundant_experts,
Contributor review comment (critical):
The attribute self.num_redundant_experts is used here but it is not defined on the Grok1Model class. It is defined on Grok1ForCausalLM, but load_weights is a method of Grok1Model. This will result in an AttributeError.

To fix this, you should initialize num_redundant_experts in Grok1Model.__init__ and set it as an attribute there, similar to how enable_eplb is handled.
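A minimal sketch of that fix (illustrative only; it assumes Grok1Model.__init__ keeps reading settings from vllm_config.parallel_config the way it already does for enable_eplb):

        # Hypothetical addition to Grok1Model.__init__, next to the existing
        # enable_eplb lookup, so load_weights() can read it from self:
        self.num_redundant_experts = (
            vllm_config.parallel_config.num_redundant_experts)

With that attribute in place, the make_expert_params_mapping call above can stay as written.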

)

params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
@@ -415,26 +451,47 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
name_mapped = name.replace(weight_name, param_name)

# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
if is_pp_missing_parameter(name_mapped, self):
continue
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
if ((name_mapped.endswith(".bias")
or name_mapped.endswith("_bias"))
and name_mapped not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break

param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(Callable[..., bool],
param.weight_loader)
success = weight_loader(param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True)
if success:
name = name_mapped
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue

# Skip loading extra bias for GPTQ models.
if ((name.endswith(".bias") or name.endswith("_bias"))
and name not in params_dict):
@@ -460,7 +517,7 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params


class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
fall_back_to_pt_during_load = False

packed_modules_mapping = {
@@ -510,6 +567,38 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

# All layers in Grok1 have MoE blocks.
self.expert_weights = [
layer.moe_block.get_expert_weights() for layer in self.model.layers
]

# Set MoE hyperparameters.
self.moe_layers: list[FusedMoE] = [
layer.moe_block.experts for layer in self.model.layers
]

example_moe = typing.cast(Grok1MoE, self.model.layers[0].moe_block)
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_redundant_experts = example_moe.n_redundant_experts
self.num_shared_experts = example_moe.n_shared_experts
Comment on lines +570 to +586
Contributor review comment (critical):
This block of code will fail when pipeline parallelism is enabled. self.model.layers can contain PPMissingLayer placeholders, which do not have a moe_block attribute. Accessing self.model.layers[0].moe_block directly is not safe.

You should iterate through the layers to find the first valid one to use as an example, and filter out PPMissingLayer when building expert_weights and moe_layers. A safer way is to check for the existence of moe_block or filter for layers that are not PPMissingLayer.

Suggested change:
        # All layers in Grok1 have MoE blocks.
        # NOTE: This may be empty if no layers are on this rank.
        local_layers = [
            layer for layer in self.model.layers
            if hasattr(layer, "moe_block")
        ]
        self.expert_weights = [
            layer.moe_block.get_expert_weights() for layer in local_layers
        ]

        # Set MoE hyperparameters.
        self.moe_layers: list[FusedMoE] = [
            layer.moe_block.experts for layer in local_layers
        ]

        if local_layers:
            example_moe = typing.cast(Grok1MoE, local_layers[0].moe_block)
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_redundant_experts = example_moe.n_redundant_experts
            self.num_shared_experts = example_moe.n_shared_experts
        else:
            # No layers on this rank, set defaults for the interface.
            self.num_logical_experts = 0
            self.num_physical_experts = 0
            self.num_local_physical_experts = 0
            self.num_routed_experts = 0
            self.num_redundant_experts = 0
            self.num_shared_experts = 0


def set_eplb_state(
self,
expert_load_view: torch.Tensor,
logical_to_physical_map: torch.Tensor,
logical_replica_count: torch.Tensor,
) -> None:
for layer_idx, experts in enumerate(self.moe_layers):
experts.set_eplb_state(
moe_layer_idx=layer_idx,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
