
Commit cb8c9b8

wang55 authored and committed
Enhance MoE bias update in optimizer
We put the per-layer expert usage counts into a single buffer, so only one all-reduce is needed rather than one per MoE layer (this assumes all MoE layers have the same number of experts). Additionally, handle the case where tokens per expert are counted twice during full activation-checkpoint recompute, and count expert usage for MoE even when load balancing is disabled.
1 parent f9e8897 commit cb8c9b8
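
As a rough illustration of the batching idea behind this change (a minimal sketch, not the torchtitan code; it assumes `torch.distributed` is already initialized when run under a launcher, and simply skips the collective otherwise):

import torch
import torch.distributed as dist

# Per-layer expert-usage counters, e.g. gathered from each MoE layer's
# tokens_per_expert buffer (toy values; 3 layers x 4 experts).
tokens_per_expert_list = [
    torch.tensor([5.0, 1.0, 0.0, 2.0]),
    torch.tensor([3.0, 3.0, 1.0, 1.0]),
    torch.tensor([0.0, 4.0, 2.0, 2.0]),
]

# Stack into one [n_layers, n_experts] tensor so that a single all-reduce
# replaces one collective per layer.
tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)

if dist.is_available() and dist.is_initialized():
    # One collective for all layers instead of len(tokens_per_expert_list) calls.
    dist.all_reduce(tokens_per_expert_by_layer, op=dist.ReduceOp.SUM)

print(tokens_per_expert_by_layer)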

File tree

torchtitan/components/optimizer.py
torchtitan/models/moe.py

2 files changed: +79, -28 lines

torchtitan/components/optimizer.py

Lines changed: 65 additions & 15 deletions
@@ -20,6 +20,7 @@
 from torchtitan.components.ft import FTManager, has_torchft
 from torchtitan.config import Optimizer as OptimizerConfig
 from torchtitan.distributed import ParallelDims
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl

 __all__ = [
     "OptimizersContainer",
@@ -340,6 +341,9 @@ def build_optimizers_with_moe_load_balancing(
     )

     # for MoE auxiliary-loss-free load balancing
+    def is_full_recompute(module):
+        return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT
+
     def _update_expert_bias(
         model_parts: list[nn.Module],
         parallel_dims: ParallelDims,
@@ -349,25 +353,71 @@ def _update_expert_bias(
         )
         # TODO: Currently this sync is blocking (thus exposed) and happens on the
         # default compute stream. Need to assess if this is OK performance-wise.
+        tokens_per_expert_list = []
         for model_part in model_parts:
             for transformer_block in model_part.layers.values():
-                if transformer_block.moe_enabled:
+                if not transformer_block.moe_enabled:
+                    continue
+                moe = transformer_block.moe
+                tokens_per_expert = transformer_block.moe.tokens_per_expert
+                if is_full_recompute(transformer_block):
+                    # TODO: This is a hack; we assume that with full AC, tokens_per_expert is counted twice.
+                    # This does not affect the expert choice, but it does affect the expert usage metrics.
+                    # We divide by 2 to correct for this double counting due to recomputation.
+                    # TODO: new API to help determine if AC is enabled https://github.com/pytorch/pytorch/pull/160888
+                    tokens_per_expert = tokens_per_expert // 2
+                tokens_per_expert_list.append(tokens_per_expert)
+
+        if len(tokens_per_expert_list) == 0:
+            # avoid stacking an empty list of tensors
+            return
+
+        n_expert = tokens_per_expert_list[0].numel()
+        assert all(
+            t.numel() == n_expert for t in tokens_per_expert_list
+        ), "All MoE layers must have the same number of experts."
+
+        # shape: [n_layers, n_expert]
+        tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
+
+        if dp_cp_mesh is not None:
+            # Perform a single all-reduce to get global statistics across all processes
+            pg = dp_cp_mesh.get_group()
+            torch.distributed.all_reduce(
+                tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
+            )
+
+        moe_layer_idx = 0
+        with torch.no_grad():
+            for model_part in model_parts:
+                for layer_id, transformer_block in enumerate(
+                    model_part.layers.values()
+                ):
+                    if not transformer_block.moe_enabled:
+                        continue
                     moe = transformer_block.moe
+
+                    tokens_per_expert = tokens_per_expert_by_layer[
+                        moe_layer_idx
+                    ].float()
+                    moe_layer_idx += 1
+                    # uncomment to log expert usage once we fix https://github.com/pytorch/torchtitan/pull/1578
+                    # sum_tokens = tokens_per_expert.sum().clamp(min=1.0)
+                    # expert_usage_metrics = {
+                    #     f"moe_ep_usage/L-{layer_id}_EP-{ep_idx}": usage / sum_tokens
+                    #     for ep_idx, usage in enumerate(tokens_per_expert)
+                    # }
+
                     if moe.load_balance_coeff is None:
-                        return
-
-                    if dp_cp_mesh is not None:
-                        torch.distributed.all_reduce(
-                            moe.tokens_per_expert, group=dp_cp_mesh.get_group()
-                        )
-
-                    with torch.no_grad():
-                        expert_bias_delta = moe.load_balance_coeff * torch.sign(
-                            moe.tokens_per_expert.mean() - moe.tokens_per_expert
-                        )
-                        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
-                        moe.expert_bias.add_(expert_bias_delta)
-                        moe.tokens_per_expert.zero_()
+                        continue
+                    # update the expert bias
+                    # this is not exactly the same update as proposed in https://arxiv.org/pdf/2408.15664
+                    expert_bias_delta = moe.load_balance_coeff * torch.sign(
+                        tokens_per_expert.mean() - tokens_per_expert
+                    )
+                    expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
+                    moe.expert_bias.add_(expert_bias_delta)
+                    moe.tokens_per_expert.zero_()

     optimizers.register_step_pre_hook(
         lambda *args, **kwargs: _update_expert_bias(
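
For intuition, here is a minimal standalone sketch of the sign-based, auxiliary-loss-free bias update that the step pre-hook applies per layer (toy numbers; `load_balance_coeff` and the counts are made up, and this is illustrative rather than the exact torchtitan code path):

import torch

load_balance_coeff = 1e-3  # hypothetical coefficient
expert_bias = torch.zeros(4)

# Globally reduced token counts for one MoE layer (toy values).
tokens_per_expert = torch.tensor([120.0, 30.0, 40.0, 10.0])

with torch.no_grad():
    # Underloaded experts (count below the mean) get a positive nudge,
    # overloaded ones a negative nudge; the magnitude is fixed by the coefficient.
    expert_bias_delta = load_balance_coeff * torch.sign(
        tokens_per_expert.mean() - tokens_per_expert
    )
    # Re-center so the update does not drift the overall bias level.
    expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
    expert_bias.add_(expert_bias_delta)

print(expert_bias)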

torchtitan/models/moe.py

Lines changed: 14 additions & 13 deletions
@@ -350,13 +350,14 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
                 torch.zeros(num_experts, dtype=torch.float32),
                 persistent=True,
             )
-            self.register_buffer(
-                "tokens_per_expert",
-                torch.zeros(num_experts, dtype=torch.float32),
-                persistent=False,
-            )
         else:
             self.expert_bias = None
+        # We always create the tokens_per_expert buffer, to help us count the expert usage
+        self.register_buffer(
+            "tokens_per_expert",
+            torch.zeros(num_experts, dtype=torch.float32),
+            persistent=False,
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -378,12 +379,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ) = self.router(x, self.expert_bias)

         # tokens_per_expert will be used to update the expert bias for load balancing.
+        # and also to count the expert usage
         # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
         # first in the forward pass, and then in the backward pass. However, this has no
         # effect on the expert bias update thanks to the torch.sign() operator.
-        if self.load_balance_coeff is not None:
-            with torch.no_grad():
-                self.tokens_per_expert.add_(num_tokens_per_expert)
+        with torch.no_grad():
+            self.tokens_per_expert.add_(num_tokens_per_expert)

         # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
@@ -444,11 +445,11 @@ def init_weights(
         if self.shared_experts is not None:
             self.shared_experts.init_weights(init_std)

-        if self.load_balance_coeff is not None:
-            with torch.device(buffer_device):
+        with torch.device(buffer_device):
+            self.tokens_per_expert = torch.zeros(
+                self.experts.num_experts, dtype=torch.float32
+            )
+            if self.load_balance_coeff is not None:
                 self.expert_bias = torch.zeros(
                     self.experts.num_experts, dtype=torch.float32
                 )
-                self.tokens_per_expert = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
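
A small sketch of the buffer pattern used above: registering a non-persistent counter buffer in `__init__` and re-creating it on a target device in an init method (the `ToyMoE` module and its `buffer_device` argument are hypothetical, shown only to illustrate the pattern):

import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self, num_experts: int):
        super().__init__()
        self.num_experts = num_experts
        # Non-persistent: counts expert usage but is excluded from the state_dict / checkpoints.
        self.register_buffer(
            "tokens_per_expert",
            torch.zeros(num_experts, dtype=torch.float32),
            persistent=False,
        )

    def init_weights(self, buffer_device: torch.device):
        # Re-create the buffer directly on the desired device (e.g. after meta-device init);
        # assigning a tensor to a registered buffer name keeps it registered as a buffer.
        with torch.device(buffer_device):
            self.tokens_per_expert = torch.zeros(
                self.num_experts, dtype=torch.float32
            )

moe = ToyMoE(num_experts=4)
moe.init_weights(torch.device("cpu"))
print(moe.tokens_per_expert, "tokens_per_expert" in moe.state_dict())  # non-persistent, so: False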
