Commit cb4e41a

wang55 committed
count the expert usage as well for MoE
1 parent d791c85 commit cb4e41a

File tree (2 files changed: +32, -27 lines)

torchtitan/components/optimizer.py
torchtitan/models/moe.py

torchtitan/components/optimizer.py

Lines changed: 18 additions & 14 deletions
@@ -359,15 +359,12 @@ def _update_expert_bias(
                 if not transformer_block.moe_enabled:
                     continue
                 moe = transformer_block.moe
-                if moe.load_balance_coeff is None:
-                    continue
                 tokens_per_expert = transformer_block.moe.tokens_per_expert
-                if is_full_recompute(transformer_block.moe) or is_full_recompute(
-                    transformer_block
-                ):
+                if is_full_recompute(transformer_block):
                     # TODO: This is a hack; we assume that with full AC, tokens_per_expert is counted twice.
                     # This does not affect the expert choice, but it does affect the expert usage metrics.
                     # We divide by 2 to correct for this double-counting due to recomputation.
+                    # TODO: new API to help determine if AC is enabled https://github.com/pytorch/pytorch/pull/160888
                     tokens_per_expert = tokens_per_expert // 2
                 tokens_per_expert_list.append(tokens_per_expert)
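A rough illustration of the halving above (a standalone sketch with made-up numbers, not part of the commit): under full activation checkpointing the MoE forward runs once more during recomputation in the backward pass, so each entry of the buffer ends up roughly doubled by the time the optimizer hook reads it.

import torch

# Hypothetical counts: each of 4 experts actually received [100, 40, 60, 56] tokens,
# but the buffer was incremented in both the original and the recomputed forward.
tokens_per_expert = torch.tensor([200.0, 80.0, 120.0, 112.0])

# Floor division by 2 recovers the true counts while keeping dtype and shape.
corrected = tokens_per_expert // 2
print(corrected)  # tensor([100., 40., 60., 56.])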

@@ -390,30 +387,37 @@ def _update_expert_bias(
             tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
         )
 
-        layer_idx = 0
+        moe_layer_idx = 0
         with torch.no_grad():
             for model_part in model_parts:
-                for transformer_block in model_part.layers.values():
+                for layer_id, transformer_block in enumerate(
+                    model_part.layers.values()
+                ):
                     if not transformer_block.moe_enabled:
                         continue
                     moe = transformer_block.moe
 
+                    tokens_per_expert = tokens_per_expert_by_layer[
+                        moe_layer_idx
+                    ].float()
+                    moe_layer_idx += 1
+                    # uncomment to log expert usage once we fix https://github.com/pytorch/torchtitan/pull/1578
+                    # sum_tokens = tokens_per_expert.sum().clamp(min=1.0)
+                    # expert_usage_metrics = {
+                    #     f"moe_ep_usage/L-{layer_id}_EP-{ep_idx}": usage / sum_tokens
+                    #     for ep_idx, usage in enumerate(tokens_per_expert)
+                    # }
+
                     if moe.load_balance_coeff is None:
                         continue
-
-                    tokens_per_expert = tokens_per_expert_by_layer[layer_idx].float()
-                    layer_idx += 1
-
                     # update the expert bias
-                    # https://github.com/pytorch/torchtitan/issues/1506
                     # this is not exactly the same as https://arxiv.org/pdf/2408.15664 proposed
-                    expert_bias_delta = load_balance_coeff * torch.sign(
+                    expert_bias_delta = moe.load_balance_coeff * torch.sign(
                         tokens_per_expert.mean() - tokens_per_expert
                     )
                     expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
                     moe.expert_bias.add_(expert_bias_delta)
                     moe.tokens_per_expert.zero_()
-                    # placeholder to record and log the expert usage
 
     optimizers.register_step_pre_hook(
         lambda *args, **kwargs: _update_expert_bias(
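For reference, a self-contained sketch of what the hook computes per MoE layer (illustrative tensors and a made-up coefficient; the real code reads rows of the all-reduced tokens_per_expert_by_layer and writes into moe.expert_bias):

import torch

tokens_per_expert = torch.tensor([300.0, 120.0, 500.0, 104.0])  # hypothetical counts for one layer
load_balance_coeff = 1e-3                                        # assumed value of moe.load_balance_coeff

# Expert usage fractions, as in the commented-out metrics block above.
sum_tokens = tokens_per_expert.sum().clamp(min=1.0)
expert_usage = tokens_per_expert / sum_tokens

# Sign-based update from the diff: over-used experts get a negative delta,
# under-used ones a positive delta; subtracting the mean keeps the bias centered.
expert_bias_delta = load_balance_coeff * torch.sign(tokens_per_expert.mean() - tokens_per_expert)
expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()

expert_bias = torch.zeros_like(tokens_per_expert)
expert_bias.add_(expert_bias_delta)   # in the real hook this is moe.expert_bias.add_(...)
tokens_per_expert.zero_()             # counts are reset after every bias update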

torchtitan/models/moe.py

Lines changed: 14 additions & 13 deletions
@@ -350,13 +350,14 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
                 torch.zeros(num_experts, dtype=torch.float32),
                 persistent=True,
             )
-            self.register_buffer(
-                "tokens_per_expert",
-                torch.zeros(num_experts, dtype=torch.float32),
-                persistent=False,
-            )
         else:
             self.expert_bias = None
+        # We create the tokens_per_expert buffer regardless, to help count the expert usage.
+        self.register_buffer(
+            "tokens_per_expert",
+            torch.zeros(num_experts, dtype=torch.float32),
+            persistent=False,
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -378,12 +379,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         ) = self.router(x, self.expert_bias)
 
         # tokens_per_expert will be used to update the expert bias for load balancing.
+        # and also to count the expert usage
         # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert --
         # first in the forward pass, and then in the backward pass. However, this has no
         # effect on the expert bias update thanks to the torch.sign() operator.
-        if self.load_balance_coeff is not None:
-            with torch.no_grad():
-                self.tokens_per_expert.add_(num_tokens_per_expert)
+        with torch.no_grad():
+            self.tokens_per_expert.add_(num_tokens_per_expert)
 
         # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,)
         # num_tokens_per_expert shape (num_experts,)
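The tokens_per_expert buffer that forward() accumulates into is registered with persistent=False in __init__ (see the first hunk of this file), which means it follows the module across devices but is excluded from checkpoints. A generic PyTorch sketch of that behavior, not torchtitan code:

import torch
import torch.nn as nn

class UsageCounter(nn.Module):
    def __init__(self, num_experts: int = 4):
        super().__init__()
        # persistent=True (the default): saved in and loaded from state_dict.
        self.register_buffer("expert_bias", torch.zeros(num_experts), persistent=True)
        # persistent=False: moves with .to()/.cuda(), but omitted from state_dict.
        self.register_buffer("tokens_per_expert", torch.zeros(num_experts), persistent=False)

m = UsageCounter()
print(sorted(m.state_dict().keys()))  # ['expert_bias'] -- the usage counter is never checkpointed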
@@ -444,11 +445,11 @@ def init_weights(
         if self.shared_experts is not None:
             self.shared_experts.init_weights(init_std)
 
-        if self.load_balance_coeff is not None:
-            with torch.device(buffer_device):
+        with torch.device(buffer_device):
+            self.tokens_per_expert = torch.zeros(
+                self.experts.num_experts, dtype=torch.float32
+            )
+            if self.load_balance_coeff is not None:
                 self.expert_bias = torch.zeros(
                     self.experts.num_experts, dtype=torch.float32
                 )
-                self.tokens_per_expert = torch.zeros(
-                    self.experts.num_experts, dtype=torch.float32
-                )
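A minimal illustration of the torch.device context manager used in init_weights above (generic PyTorch, with "cpu" standing in for the real buffer_device argument): factory calls such as torch.zeros made inside the block allocate on that device, which is typically how buffers get re-materialized on a real device after a meta-device initialization.

import torch

buffer_device = torch.device("cpu")  # stand-in for the buffer_device argument

with torch.device(buffer_device):
    # Allocated on buffer_device because of the surrounding context manager.
    tokens_per_expert = torch.zeros(8, dtype=torch.float32)

print(tokens_per_expert.device)  # cpu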
