Commit f340bcb

Author: wang55
Commit message: add comments, rename ac check func name
1 parent: cb8c9b8

File tree

2 files changed: +3 -10 lines

torchtitan/components/optimizer.py
torchtitan/models/moe.py

torchtitan/components/optimizer.py

Lines changed: 2 additions & 9 deletions
@@ -341,7 +341,7 @@ def build_optimizers_with_moe_load_balancing(
     )
 
     # for MoE auxiliary-loss-free load balancing
-    def is_full_recompute(module):
+    def _is_recomputation_enabled(module):
         return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT
 
     def _update_expert_bias(
@@ -358,9 +358,8 @@ def _update_expert_bias(
             for transformer_block in model_part.layers.values():
                 if not transformer_block.moe_enabled:
                     continue
-                moe = transformer_block.moe
                 tokens_per_expert = transformer_block.moe.tokens_per_expert
-                if is_full_recompute(transformer_block):
+                if _is_recomputation_enabled(transformer_block):
                     # TODO: This is a hack, we assume with full AC, the tokens_per_expert is counted twice.
                     # This does not affect to expert choice, but affects the experts usage metrics.
                     # We divide by 2 to correct for this double-counting due to recomputation
@@ -372,12 +371,6 @@ def _update_expert_bias(
             # avoid cat empty tensor
             return
 
-        n_expert = tokens_per_expert_list[0].numel()
-        assert all(
-            t.numel() == n_expert for t in tokens_per_expert_list
-        ), "All MoE layers must have the same number of experts."
-
-        # [n_layers, n_expert], int32
         tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
 
         if dp_cp_mesh is not None:
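To make the renamed check concrete, here is a minimal sketch of what _is_recomputation_enabled detects and why the "// 2" correction follows. It assumes PyTorch's checkpoint_wrapper API (whose CheckpointWrapper stores a checkpoint_impl attribute on the wrapped module); the nn.Linear block and the tensor values are illustrative, not from this commit.

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

def _is_recomputation_enabled(module):
    # CheckpointWrapper records which checkpoint implementation it was built
    # with; a plain module has no such attribute, so getattr returns None.
    return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT

block = nn.Linear(8, 8)
wrapped = checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.NO_REENTRANT)

print(_is_recomputation_enabled(block))    # False: no activation checkpointing
print(_is_recomputation_enabled(wrapped))  # True: forward will be recomputed

# With full activation checkpointing, the block's forward runs twice (once in
# the forward pass, once recomputed during backward), so a counter incremented
# inside forward is doubled; dividing by 2 recovers the true token counts.
tokens_per_expert = torch.tensor([6, 2, 4, 0])  # hypothetical doubled counts
if _is_recomputation_enabled(wrapped):
    tokens_per_expert = tokens_per_expert // 2
print(tokens_per_expert)  # tensor([3, 1, 2, 0])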

torchtitan/models/moe.py

Lines changed: 1 addition & 1 deletion
@@ -352,7 +352,7 @@ def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int):
             )
         else:
             self.expert_bias = None
-        # We create tokens_per_expert buffer anyhow to help us conunt the expert usage
+        # tokens_per_expert will be used to track expert usage and to update the expert bias for load balancing
         self.register_buffer(
             "tokens_per_expert",
             torch.zeros(num_experts, dtype=torch.float32),
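The updated comment says the buffer both tracks usage and drives the expert-bias update. As an illustration of the auxiliary-loss-free load balancing this enables (a sketch only; the update rate, tensor values, and exact rule are assumptions for illustration, not torchtitan's implementation), the per-step bias adjustment can be as simple as a sign update toward the mean load:

import torch

num_experts = 4
bias_update_rate = 1e-3  # hypothetical hyperparameter

expert_bias = torch.zeros(num_experts)
tokens_per_expert = torch.tensor([120.0, 80.0, 110.0, 90.0])  # accumulated counts

# Nudge underused experts up and overloaded experts down relative to the mean
# load; the bias only perturbs routing scores for expert choice, so no
# auxiliary loss term enters the gradient.
mean_load = tokens_per_expert.mean()
expert_bias += bias_update_rate * torch.sign(mean_load - tokens_per_expert)
print(expert_bias)  # tensor([-0.0010,  0.0010, -0.0010,  0.0010])

# The counters are then typically zeroed so the next window starts fresh.
tokens_per_expert.zero_()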
