improve MoE bias update logic in optimizer #1593

Merged
merged 1 commit on Aug 22, 2025
Changes from all commits
63 changes: 47 additions & 16 deletions torchtitan/components/optimizer.py
@@ -9,6 +9,7 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl
 from torch.distributed.checkpoint.state_dict import (
     get_optimizer_state_dict,
     set_optimizer_state_dict,
@@ -340,6 +341,9 @@ def build_optimizers_with_moe_load_balancing(
     )
 
     # for MoE auxiliary-loss-free load balancing
+    def _is_recomputation_enabled(module):
+        return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT
+
     def _update_expert_bias(
         model_parts: list[nn.Module],
         parallel_dims: ParallelDims,
@@ -349,25 +353,52 @@ def _update_expert_bias(
         )
         # TODO: Currently this sync is blocking (thus exposed) and happens on the
         # default compute stream. Need to assess if this is OK performance-wise.
-        for model_part in model_parts:
-            for transformer_block in model_part.layers.values():
-                if transformer_block.moe_enabled:
-                    moe = transformer_block.moe
-                    if moe.load_balance_coeff is None:
-                        return
-
-                    if dp_cp_mesh is not None:
-                        torch.distributed.all_reduce(
-                            moe.tokens_per_expert, group=dp_cp_mesh.get_group()
-                        )
-
-                    with torch.no_grad():
-                        expert_bias_delta = moe.load_balance_coeff * torch.sign(
-                            moe.tokens_per_expert.mean() - moe.tokens_per_expert
-                        )
-                        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
-                        moe.expert_bias.add_(expert_bias_delta)
-                        moe.tokens_per_expert.zero_()
+        tokens_per_expert_list = []
+        for model_part in model_parts:
+            for transformer_block in model_part.layers.values():
+                if not transformer_block.moe_enabled:
+                    continue
+                if transformer_block.moe.load_balance_coeff is None:
+                    return
+                tokens_per_expert = transformer_block.moe.tokens_per_expert
+                if _is_recomputation_enabled(transformer_block):
+                    # TODO: This is a hack; we assume that with full AC, tokens_per_expert is counted twice.
+                    # This does not affect the expert choice, but it does affect the expert usage metrics.
+                    # We divide by 2 to correct for this double-counting due to recomputation.
+                    # TODO: new API to help determine if AC is enabled https://github.com/pytorch/pytorch/pull/160888
+                    tokens_per_expert = tokens_per_expert // 2
+                tokens_per_expert_list.append(tokens_per_expert)
+
+        tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
+
+        if dp_cp_mesh is not None:
+            # Perform a single all-reduce to get global statistics across all processes
+            pg = dp_cp_mesh.get_group()
+            torch.distributed.all_reduce(
+                tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
+            )
+
+        moe_layer_idx = 0
+        with torch.no_grad():
+            for model_part in model_parts:
+                for transformer_block in model_part.layers.values():
+                    if not transformer_block.moe_enabled:
+                        continue
+                    moe = transformer_block.moe
+                    if moe.load_balance_coeff is None:
+                        return
+
+                    tokens_per_expert = tokens_per_expert_by_layer[
+                        moe_layer_idx
+                    ].float()
+                    moe_layer_idx += 1
+
+                    # update the expert bias
+                    # this is not exactly the same as https://arxiv.org/pdf/2408.15664 proposed
+                    expert_bias_delta = moe.load_balance_coeff * torch.sign(
+                        tokens_per_expert.mean() - tokens_per_expert
+                    )
+                    expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
+                    moe.expert_bias.add_(expert_bias_delta)
+                    moe.tokens_per_expert.zero_()
 
     optimizers.register_step_pre_hook(
         lambda *args, **kwargs: _update_expert_bias(
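For intuition, here is a minimal single-process sketch of the update rule the new step pre-hook applies: per-layer token counts are stacked with torch.vstack so that one all-reduce (omitted here) can sum the whole matrix across ranks, and each layer's expert bias is then nudged by sign(mean - count) with a zero-mean correction. The layer/expert sizes, the coefficient value, the invented token counts, and the plain-tensor expert_bias below are illustrative assumptions, not torchtitan code.

import torch

# hypothetical values, for illustration only
load_balance_coeff = 1e-3
num_layers, num_experts = 2, 4

# pretend these are the per-layer token counts accumulated by the routers on this rank
tokens_per_expert_list = [
    torch.tensor([10.0, 2.0, 5.0, 3.0]),
    torch.tensor([4.0, 4.0, 9.0, 1.0]),
]
# stack so a single all-reduce could sum the whole matrix at once (the collective is omitted here)
tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)

expert_bias = [torch.zeros(num_experts) for _ in range(num_layers)]
with torch.no_grad():
    for i in range(num_layers):
        counts = tokens_per_expert_by_layer[i].float()
        # under-used experts get a positive nudge, over-used experts a negative one
        delta = load_balance_coeff * torch.sign(counts.mean() - counts)
        delta = delta - delta.mean()  # keep the overall bias shift zero-mean
        expert_bias[i].add_(delta)

print(expert_bias)  # e.g. layer 0: the most-used expert 0 is biased down, the least-used expert 1 up

Stacking first is what lets the hook replace one all-reduce per MoE layer with a single collective over all layers.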
27 changes: 14 additions & 13 deletions torchtitan/models/moe.py
@@ -350,13 +350,14 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
                 torch.zeros(num_experts, dtype=torch.float32),
                 persistent=True,
             )
-            self.register_buffer(
-                "tokens_per_expert",
-                torch.zeros(num_experts, dtype=torch.float32),
-                persistent=False,
-            )
         else:
             self.expert_bias = None
+        # tokens_per_expert will be used to track expert usage and to update the expert bias for load balancing
+        self.register_buffer(
+            "tokens_per_expert",
+            torch.zeros(num_experts, dtype=torch.float32),
+            persistent=False,
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -378,12 +379,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ) = self.router(x, self.expert_bias)
 
         # tokens_per_expert will be used to update the expert bias for load balancing.
+        # and also to count the expert usage
         # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
         # first in the forward pass, and then in the backward pass. However, this has no
         # effect on the expert bias update thanks to the torch.sign() operator.
-        if self.load_balance_coeff is not None:
-            with torch.no_grad():
-                self.tokens_per_expert.add_(num_tokens_per_expert)
+        with torch.no_grad():
+            self.tokens_per_expert.add_(num_tokens_per_expert)
 
         # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
@@ -444,11 +445,11 @@ def init_weights(
         if self.shared_experts is not None:
             self.shared_experts.init_weights(init_std)
 
-        if self.load_balance_coeff is not None:
-            with torch.device(buffer_device):
-                self.expert_bias = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
-                self.tokens_per_expert = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
+        with torch.device(buffer_device):
+            self.tokens_per_expert = torch.zeros(
+                self.experts.num_experts, dtype=torch.float32
+            )
+            if self.load_balance_coeff is not None:
+                self.expert_bias = torch.zeros(
+                    self.experts.num_experts, dtype=torch.float32
+                )
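The TODO comments in both files refer to double counting under full activation checkpointing. The sketch below is a minimal illustration of why that happens: the checkpointed forward runs a second time during backward recomputation, so an in-place add_() on a counter fires twice per step, which is what the // 2 correction in the optimizer hook compensates for. It uses plain torch.utils.checkpoint with use_reentrant=False rather than torchtitan's checkpoint wrapper, and the toy moe_like_forward and its counts are invented for illustration.

import torch
from torch.utils.checkpoint import checkpoint

tokens_per_expert = torch.zeros(4)

def moe_like_forward(x):
    # pretend the router sent every token of this call to expert 0
    with torch.no_grad():
        tokens_per_expert.add_(torch.tensor([float(x.numel()), 0.0, 0.0, 0.0]))
    return x * x  # an op whose backward needs a saved activation, so backward triggers recomputation

x = torch.randn(3, requires_grad=True)
y = checkpoint(moe_like_forward, x, use_reentrant=False)
y.sum().backward()  # recomputation reruns moe_like_forward, repeating the add_()

print(tokens_per_expert)  # expected: tensor([6., 0., 0., 0.]) -- 3 tokens counted twice in one step

The expert bias update itself is unaffected by the doubling, since torch.sign(mean - count) does not change when every count is scaled by the same factor; only the expert usage metrics would be skewed without the correction.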