[EPLB]: Add EPLB support for Grok1 [WIP] #21273
```diff
@@ -22,17 +22,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Grok1 model."""
+import typing
 from collections.abc import Iterable
-from typing import Optional, Union
+from typing import Callable, Optional, Union

 import torch
 import torch.nn.functional as F
 from torch import nn

 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
```
```diff
@@ -48,7 +50,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
```
```diff
@@ -76,10 +78,29 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  tp_size: Optional[int] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 enable_eplb: bool = False) -> None:
         super().__init__()
         self.hidden_size = hidden_size

+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        # No built-in separation between routed and shared experts.
+        self.n_routed_experts: int = num_experts
+        self.n_shared_experts: int = 0
+
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_logical_experts = self.n_routed_experts
+        self.n_physical_experts = (self.n_logical_experts +
+                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(hidden_size,
                                      num_experts,
```

**Review comment** (truncated) on the expert counts above: "The number of physical experts (…"
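For orientation, here is a small worked example of the bookkeeping the hunk above sets up, with illustrative numbers rather than Grok1's actual configuration:

```python
# Illustrative EPLB bookkeeping (numbers are made up, not Grok1's config):
# 8 routed experts, 2 redundant replicas, expert parallelism over 2 ranks.
n_routed_experts = 8                 # experts routed by the gate
n_redundant_experts = 2              # extra physical replicas added by EPLB
n_logical_experts = n_routed_experts
n_physical_experts = n_logical_experts + n_redundant_experts   # 10
ep_size = 2
n_local_physical_experts = n_physical_experts // ep_size       # 5 per rank
# The integer division only partitions evenly when ep_size divides
# n_physical_experts; otherwise some physical experts are unaccounted for.
assert n_physical_experts % ep_size == 0
print(n_physical_experts, n_local_physical_experts)
```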
```diff
@@ -88,17 +109,21 @@ def __init__(self,
                                      quant_config=None,
                                      prefix=f"{prefix}.gate")

-        self.experts = FusedMoE(num_experts=num_experts,
-                                top_k=top_k,
-                                hidden_size=hidden_size,
-                                intermediate_size=intermediate_size,
-                                params_dtype=params_dtype,
-                                reduce_results=True,
-                                renormalize=True,
-                                quant_config=quant_config,
-                                tp_size=tp_size,
-                                activation="gelu",
-                                prefix=f"{prefix}.experts")
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            activation="gelu",
+            prefix=f"{prefix}.experts",
+            enable_eplb=enable_eplb,
+            num_redundant_experts=self.n_redundant_experts,
+        )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
```
```diff
@@ -110,6 +135,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(orig_shape)

+    def get_expert_weights(self) -> list[torch.Tensor]:
+        """Returns the weights of all experts in this MoE block."""
+        return self.experts.get_expert_weights()
+

 class Grok1Attention(nn.Module):

```
```diff
@@ -208,6 +237,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
```
```diff
@@ -242,7 +272,8 @@ def __init__(
                                   hidden_size=config.hidden_size,
                                   intermediate_size=config.intermediate_size,
                                   quant_config=quant_config,
-                                  prefix=f"{prefix}.moe_block")
+                                  prefix=f"{prefix}.moe_block",
+                                  enable_eplb=enable_eplb)

         self.pre_attn_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
```
```diff
@@ -293,6 +324,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        enable_eplb = vllm_config.parallel_config.enable_eplb

         self.config = config
         self.quant_config = quant_config
```
```diff
@@ -314,9 +346,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: Grok1DecoderLayer(
-                config, cache_config, quant_config=quant_config, prefix=prefix
-            ),
+            lambda prefix: Grok1DecoderLayer(config,
+                                             cache_config,
+                                             quant_config=quant_config,
+                                             prefix=prefix,
+                                             enable_eplb=enable_eplb),
             prefix=f"{prefix}.layers")

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
```
```diff
@@ -376,7 +410,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ckpt_gate_proj_name="linear",  # Grok1 specific
             ckpt_down_proj_name="linear_1",  # Grok1 specific
             ckpt_up_proj_name="linear_v",  # Grok1 specific
-            num_experts=num_experts)
+            num_experts=num_experts,
+            num_redundant_experts=self.num_redundant_experts,
+        )

         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
```

**Review comment** (on `num_redundant_experts=self.num_redundant_experts`): The attribute `self.num_redundant_experts` is never set on `Grok1Model`, so this line will raise an `AttributeError` during weight loading. To fix this, you should initialize it in `Grok1Model.__init__` before `load_weights` is called.
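One way to address this, sketched below under the assumption that `Grok1Model` can read the same `parallel_config.num_redundant_experts` field that `Grok1MoE` reads earlier in this diff; the helper and its name are hypothetical and not part of the PR:

```python
# Hypothetical sketch, not the PR's code: make sure the attribute that
# Grok1Model.load_weights() reads exists before any weights are loaded.
from vllm.config import get_current_vllm_config


def init_redundant_expert_count(model) -> None:
    """Attach num_redundant_experts to `model` (e.g. call from __init__)."""
    parallel_config = get_current_vllm_config().parallel_config
    # Same config field the Grok1MoE block reads in this diff.
    model.num_redundant_experts = parallel_config.num_redundant_experts
```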
```diff
@@ -415,26 +451,47 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
+
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+
                     # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
-                    if ((name.endswith(".bias") or name.endswith("_bias"))
-                            and name not in params_dict):
+                    if ((name_mapped.endswith(".bias")
+                         or name_mapped.endswith("_bias"))
+                            and name_mapped not in params_dict):
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
+
                     # Skip loading extra bias for GPTQ models.
                     if ((name.endswith(".bias") or name.endswith("_bias"))
                             and name not in params_dict):
```
```diff
@@ -460,7 +517,7 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params


-class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
     fall_back_to_pt_during_load = False

     packed_modules_mapping = {
```
```diff
@@ -510,6 +567,38 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+        # All layers in Grok1 have MoE blocks.
+        self.expert_weights = [
+            layer.moe_block.get_expert_weights() for layer in self.model.layers
+        ]
+
+        # Set MoE hyperparameters.
+        self.moe_layers: list[FusedMoE] = [
+            layer.moe_block.experts for layer in self.model.layers
+        ]
+
+        example_moe = typing.cast(Grok1MoE, self.model.layers[0].moe_block)
+        self.num_logical_experts = example_moe.n_logical_experts
+        self.num_physical_experts = example_moe.n_physical_experts
+        self.num_local_physical_experts = example_moe.n_local_physical_experts
+        self.num_routed_experts = example_moe.n_routed_experts
+        self.num_redundant_experts = example_moe.n_redundant_experts
+        self.num_shared_experts = example_moe.n_shared_experts
```
**Review comment** (on lines +570 to +586): This block of code will fail when pipeline parallelism is enabled. You should iterate through the layers to find the first valid one to use as an example, and filter out placeholder layers that have no `moe_block` on this rank:

```python
        # All layers in Grok1 have MoE blocks.
        # NOTE: This may be empty if no layers are on this rank.
        local_layers = [
            layer for layer in self.model.layers
            if hasattr(layer, "moe_block")
        ]
        self.expert_weights = [
            layer.moe_block.get_expert_weights() for layer in local_layers
        ]

        # Set MoE hyperparameters.
        self.moe_layers: list[FusedMoE] = [
            layer.moe_block.experts for layer in local_layers
        ]

        if local_layers:
            example_moe = typing.cast(Grok1MoE, local_layers[0].moe_block)
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_redundant_experts = example_moe.n_redundant_experts
            self.num_shared_experts = example_moe.n_shared_experts
        else:
            # No layers on this rank, set defaults for the interface.
            self.num_logical_experts = 0
            self.num_physical_experts = 0
            self.num_local_physical_experts = 0
            self.num_routed_experts = 0
            self.num_redundant_experts = 0
            self.num_shared_experts = 0
```
```diff
+
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, experts in enumerate(self.moe_layers):
+            experts.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
```
**Review comment** (on `self.ep_group = get_ep_group().device_group`): The function `get_ep_group()` returns a `torch.distributed.ProcessGroup`, which does not have a `device_group` attribute. This will cause an `AttributeError` at runtime. You should use `get_ep_group()` directly.
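A minimal sketch of what the reviewer's suggestion would look like, assuming, as the comment states, that `get_ep_group()` returns an object behaving like a `torch.distributed.ProcessGroup` with `rank()` and `size()`; the actual return type may differ across vLLM versions:

```python
# Hypothetical sketch of the reviewer's suggestion, not the PR's code:
# use the expert-parallel group returned by get_ep_group() directly
# instead of reading a .device_group attribute from it.
from vllm.distributed import get_ep_group

ep_group = get_ep_group()
ep_rank = ep_group.rank()   # rank of this worker within the EP group
ep_size = ep_group.size()   # number of workers sharing the experts
```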