from torchtitan.components.ft import FTManager, has_torchft
from torchtitan.config import Optimizer as OptimizerConfig
from torchtitan.distributed import ParallelDims
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl

__all__ = [
    "OptimizersContainer",
@@ -340,6 +341,9 @@ def build_optimizers_with_moe_load_balancing(
    )

    # for MoE auxiliary-loss-free load balancing
+    def is_full_recompute(module):
+        return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT
+
    def _update_expert_bias(
        model_parts: list[nn.Module],
        parallel_dims: ParallelDims,
@@ -349,25 +353,71 @@ def _update_expert_bias(
        )
        # TODO: Currently this sync is blocking (thus exposed) and happens on the
        # default compute stream. Need to assess if this is OK performance-wise.
+        tokens_per_expert_list = []
        for model_part in model_parts:
            for transformer_block in model_part.layers.values():
-                if transformer_block.moe_enabled:
+                if not transformer_block.moe_enabled:
+                    continue
+                moe = transformer_block.moe
+                tokens_per_expert = transformer_block.moe.tokens_per_expert
+                if is_full_recompute(transformer_block):
+                    # TODO: This is a hack; we assume that with full AC, tokens_per_expert is counted twice.
+                    # This does not affect expert choice, but it does skew the expert usage metrics,
+                    # so we divide by 2 to correct for the double counting caused by recomputation.
+                    # TODO: new API to help determine if AC is enabled https://github.com/pytorch/pytorch/pull/160888
+                    tokens_per_expert = tokens_per_expert // 2
+                tokens_per_expert_list.append(tokens_per_expert)
+
+        if len(tokens_per_expert_list) == 0:
+            # no MoE layers; avoid stacking an empty list
+            return
+
+        n_expert = tokens_per_expert_list[0].numel()
+        assert all(
+            t.numel() == n_expert for t in tokens_per_expert_list
+        ), "All MoE layers must have the same number of experts."
+
+        # [n_layers, n_expert], int32
+        tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
+
+        if dp_cp_mesh is not None:
+            # Perform a single all-reduce to get global statistics across all processes
+            pg = dp_cp_mesh.get_group()
+            torch.distributed.all_reduce(
+                tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
+            )
+
+        moe_layer_idx = 0
+        with torch.no_grad():
+            for model_part in model_parts:
+                for layer_id, transformer_block in enumerate(
+                    model_part.layers.values()
+                ):
+                    if not transformer_block.moe_enabled:
+                        continue
                    moe = transformer_block.moe
+
+                    tokens_per_expert = tokens_per_expert_by_layer[
+                        moe_layer_idx
+                    ].float()
+                    moe_layer_idx += 1
+                    # uncomment to log expert usage once we fix https://github.com/pytorch/torchtitan/pull/1578
+                    # sum_tokens = tokens_per_expert.sum().clamp(min=1.0)
+                    # expert_usage_metrics = {
+                    #     f"moe_ep_usage/L-{layer_id}_EP-{ep_idx}": usage / sum_tokens
+                    #     for ep_idx, usage in enumerate(tokens_per_expert)
+                    # }
+
                    if moe.load_balance_coeff is None:
-                        return
-
-                    if dp_cp_mesh is not None:
-                        torch.distributed.all_reduce(
-                            moe.tokens_per_expert, group=dp_cp_mesh.get_group()
-                        )
-
-                    with torch.no_grad():
-                        expert_bias_delta = moe.load_balance_coeff * torch.sign(
-                            moe.tokens_per_expert.mean() - moe.tokens_per_expert
-                        )
-                        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
-                        moe.expert_bias.add_(expert_bias_delta)
-                        moe.tokens_per_expert.zero_()
+                        continue
+                    # update the expert bias
+                    # this is not exactly the same update as https://arxiv.org/pdf/2408.15664 proposes
+                    expert_bias_delta = moe.load_balance_coeff * torch.sign(
+                        tokens_per_expert.mean() - tokens_per_expert
+                    )
+                    expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
+                    moe.expert_bias.add_(expert_bias_delta)
+                    moe.tokens_per_expert.zero_()

    optimizers.register_step_pre_hook(
        lambda *args, **kwargs: _update_expert_bias(
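
Note on the new `is_full_recompute` helper and the `// 2` correction: full activation checkpointing wraps each TransformerBlock in `checkpoint_wrapper`, and the resulting wrapper records the `CheckpointImpl` it was built with, so a plain attribute probe identifies fully recomputed blocks; that recompute runs the block's forward a second time during backward, which is what double-counts `tokens_per_expert` by the time the optimizer step pre-hook runs. A minimal sketch of both effects, using a hypothetical `CountingBlock` as a stand-in for an MoE block (illustration only, not part of this PR):

```python
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)


class CountingBlock(torch.nn.Module):
    """Stand-in for an MoE block: its forward bumps a counter, like tokens_per_expert."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.register_buffer("forward_calls", torch.zeros(1, dtype=torch.int64))

    def forward(self, x):
        self.forward_calls += 1
        return self.linear(x)


block = CountingBlock()
wrapped = checkpoint_wrapper(block)  # full-AC style wrapping; defaults to NO_REENTRANT

# The wrapper stores the CheckpointImpl it was constructed with, which is exactly
# what is_full_recompute() probes via getattr.
assert getattr(wrapped, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT
assert getattr(block, "checkpoint_impl", None) is None  # unwrapped block has no such attribute

# Checkpointing reruns the forward during backward, so after one forward/backward
# step the counter reads 2 rather than 1, which is why the hook halves the counts.
x = torch.randn(2, 4, requires_grad=True)
wrapped(x).sum().backward()
print(block.forward_calls.item())  # 2
```

The second TODO in the diff points at pytorch/pytorch#160888, which is expected to eventually replace this attribute probe with a supported API.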
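
For reference, the relocated bias update keeps the same sign-based, auxiliary-loss-free rule as before; it is just fed by the stacked, all-reduced counts. A toy run with made-up per-expert counts (values purely illustrative) shows the effect: over-used experts are nudged down, under-used ones up, and recentring keeps the bias zero-mean.

```python
import torch

load_balance_coeff = 1e-3
expert_bias = torch.zeros(4)

# Hypothetical global token counts for one MoE layer after the all-reduce.
tokens_per_expert = torch.tensor([120.0, 80.0, 100.0, 100.0])

# Experts above the mean load get a negative nudge, those below a positive one;
# subtracting the delta's mean keeps expert_bias centred around zero.
expert_bias_delta = load_balance_coeff * torch.sign(
    tokens_per_expert.mean() - tokens_per_expert
)
expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
expert_bias += expert_bias_delta
print(expert_bias)  # tensor([-0.0010,  0.0010,  0.0000,  0.0000])
```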