From 9a623b80f5df35eac91775792b87a14439e4ec51 Mon Sep 17 00:00:00 2001
From: wang55
Date: Tue, 19 Aug 2025 04:25:32 +0200
Subject: [PATCH] Enhance MoE bias update in optimizer

We put the expert usage counts of all MoE layers into a single buffer so
that we only need one all-reduce rather than one per layer (this assumes
all MoE layers have the same number of experts).

Additionally, handle the case where tokens per expert are counted twice
during full recompute (full activation checkpointing).
---
 torchtitan/components/optimizer.py | 63 ++++++++++++++++++++++--------
 torchtitan/models/moe.py           | 27 +++++++------
 2 files changed, 61 insertions(+), 29 deletions(-)

diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py
index ce71ac7f0c..d3e9628103 100644
--- a/torchtitan/components/optimizer.py
+++ b/torchtitan/components/optimizer.py
@@ -9,6 +9,7 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl
 from torch.distributed.checkpoint.state_dict import (
     get_optimizer_state_dict,
     set_optimizer_state_dict,
@@ -340,6 +341,9 @@ def build_optimizers_with_moe_load_balancing(
     )
 
     # for MoE auxiliary-loss-free load balancing
+    def _is_recomputation_enabled(module):
+        return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT
+
     def _update_expert_bias(
         model_parts: list[nn.Module],
         parallel_dims: ParallelDims,
@@ -349,25 +353,52 @@ def _update_expert_bias(
         )
         # TODO: Currently this sync is blocking (thus exposed) and happens on the
         # default compute stream. Need to assess if this is OK performance-wise.
+        tokens_per_expert_list = []
         for model_part in model_parts:
             for transformer_block in model_part.layers.values():
-                if transformer_block.moe_enabled:
+                if not transformer_block.moe_enabled:
+                    continue
+                if transformer_block.moe.load_balance_coeff is None:
+                    return
+                tokens_per_expert = transformer_block.moe.tokens_per_expert
+                if _is_recomputation_enabled(transformer_block):
+                    # TODO: This is a hack; we assume that with full AC, tokens_per_expert is counted twice.
+                    # This does not affect the expert choice, but it does affect the expert usage metrics.
+                    # We divide by 2 to correct for this double-counting due to recomputation.
+                    # TODO: new API to help determine if AC is enabled https://github.com/pytorch/pytorch/pull/160888
+                    tokens_per_expert = tokens_per_expert // 2
+                tokens_per_expert_list.append(tokens_per_expert)
+
+        tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
+
+        if dp_cp_mesh is not None:
+            # Perform a single all-reduce to get global statistics across all processes
+            pg = dp_cp_mesh.get_group()
+            torch.distributed.all_reduce(
+                tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
+            )
+
+        moe_layer_idx = 0
+        with torch.no_grad():
+            for model_part in model_parts:
+                for transformer_block in model_part.layers.values():
+                    if not transformer_block.moe_enabled:
+                        continue
                     moe = transformer_block.moe
-                    if moe.load_balance_coeff is None:
-                        return
-
-                    if dp_cp_mesh is not None:
-                        torch.distributed.all_reduce(
-                            moe.tokens_per_expert, group=dp_cp_mesh.get_group()
-                        )
-
-                    with torch.no_grad():
-                        expert_bias_delta = moe.load_balance_coeff * torch.sign(
-                            moe.tokens_per_expert.mean() - moe.tokens_per_expert
-                        )
-                        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
-                        moe.expert_bias.add_(expert_bias_delta)
-                        moe.tokens_per_expert.zero_()
+
+                    tokens_per_expert = tokens_per_expert_by_layer[
+                        moe_layer_idx
+                    ].float()
+                    moe_layer_idx += 1
+
+                    # update the expert bias
+                    # this is not exactly the same update as proposed in https://arxiv.org/pdf/2408.15664
+                    expert_bias_delta = moe.load_balance_coeff * torch.sign(
+                        tokens_per_expert.mean() - tokens_per_expert
+                    )
+                    expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
+                    moe.expert_bias.add_(expert_bias_delta)
+                    moe.tokens_per_expert.zero_()
 
     optimizers.register_step_pre_hook(
         lambda *args, **kwargs: _update_expert_bias(
diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index 40bd6c2cca..0d63a30c6e 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -350,13 +350,14 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
                 torch.zeros(num_experts, dtype=torch.float32),
                 persistent=True,
             )
-            self.register_buffer(
-                "tokens_per_expert",
-                torch.zeros(num_experts, dtype=torch.float32),
-                persistent=False,
-            )
         else:
             self.expert_bias = None
+        # tokens_per_expert will be used to track expert usage and to update the expert bias for load balancing
+        self.register_buffer(
+            "tokens_per_expert",
+            torch.zeros(num_experts, dtype=torch.float32),
+            persistent=False,
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -378,12 +379,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ) = self.router(x, self.expert_bias)
 
         # tokens_per_expert will be used to update the expert bias for load balancing.
+        # It is also used to count the expert usage.
         # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
         # first in the forward pass, and then in the backward pass. However, this has no
         # effect on the expert bias update thanks to the torch.sign() operator.
-        if self.load_balance_coeff is not None:
-            with torch.no_grad():
-                self.tokens_per_expert.add_(num_tokens_per_expert)
+        with torch.no_grad():
+            self.tokens_per_expert.add_(num_tokens_per_expert)
 
         # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
@@ -444,11 +445,11 @@ def init_weights(
         if self.shared_experts is not None:
             self.shared_experts.init_weights(init_std)
 
-        if self.load_balance_coeff is not None:
-            with torch.device(buffer_device):
+        with torch.device(buffer_device):
+            self.tokens_per_expert = torch.zeros(
+                self.experts.num_experts, dtype=torch.float32
+            )
+            if self.load_balance_coeff is not None:
                 self.expert_bias = torch.zeros(
                     self.experts.num_experts, dtype=torch.float32
                 )
-                self.tokens_per_expert = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
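
Note: the batched load-balancing update introduced by this patch can be illustrated with a minimal, self-contained sketch. This is not the torchtitan code itself; `update_expert_bias`, `moe_layers`, `load_balance_coeff`, and `full_recompute` are stand-in names for what the registered hook derives from `model_parts`, `moe.load_balance_coeff`, and the `_is_recomputation_enabled` check.

import torch
import torch.distributed as dist


def update_expert_bias(moe_layers, load_balance_coeff, full_recompute=False):
    # 1) Stack every layer's per-expert token counts into one
    #    (num_layers, num_experts) tensor so a single collective covers all
    #    MoE layers (assumes every layer has the same number of experts).
    counts_per_layer = []
    for moe in moe_layers:
        counts = moe.tokens_per_expert
        if full_recompute:
            # Full activation checkpointing replays the forward pass, so the
            # counter was incremented twice; halve it to recover true usage.
            counts = counts // 2
        counts_per_layer.append(counts)
    tokens_per_expert_by_layer = torch.vstack(counts_per_layer)

    # 2) One all-reduce over the whole stack instead of one per layer.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tokens_per_expert_by_layer, op=dist.ReduceOp.SUM)

    # 3) Sign-based bias nudge (a variant of the auxiliary-loss-free update in
    #    arXiv:2408.15664): under-used experts move up, over-used experts move
    #    down, and the delta is re-centered so the bias keeps a zero mean.
    with torch.no_grad():
        for layer_idx, moe in enumerate(moe_layers):
            tokens_per_expert = tokens_per_expert_by_layer[layer_idx].float()
            delta = load_balance_coeff * torch.sign(
                tokens_per_expert.mean() - tokens_per_expert
            )
            delta = delta - delta.mean()
            moe.expert_bias.add_(delta)
            moe.tokens_per_expert.zero_()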
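A toy run of the sketch above (single layer, hypothetical counts, no process group initialized so the all-reduce is skipped) shows the direction of the update: the overloaded expert's bias is pushed down, the idle experts' biases are pushed up, and the counter is reset for the next accumulation window.

from types import SimpleNamespace

import torch

layer = SimpleNamespace(
    tokens_per_expert=torch.tensor([10.0, 2.0, 2.0, 2.0]),  # expert 0 overloaded
    expert_bias=torch.zeros(4),
)
update_expert_bias([layer], load_balance_coeff=0.01)
# sign(mean - counts) = [-1, +1, +1, +1]; the re-centered delta is
# [-0.015, 0.005, 0.005, 0.005], which sums to (approximately) zero.
print(layer.expert_bias)        # tensor([-0.0150,  0.0050,  0.0050,  0.0050])
print(layer.tokens_per_expert)  # tensor([0., 0., 0., 0.])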