@@ -359,15 +359,12 @@ def _update_expert_bias(
                 if not transformer_block.moe_enabled:
                     continue
                 moe = transformer_block.moe
-                if moe.load_balance_coeff is None:
-                    continue
                 tokens_per_expert = transformer_block.moe.tokens_per_expert
-                if is_full_recompute(transformer_block.moe) or is_full_recompute(
-                    transformer_block
-                ):
+                if is_full_recompute(transformer_block):
                     # TODO: This is a hack, we assume with full AC, the tokens_per_expert is counted twice.
                     # This does not affect to expert choice, but affects the experts usage metrics.
                     # We divide by 2 to correct for this double-counting due to recomputation
+                    # TODO: new API to help determine if AC is enabled https://github.com/pytorch/pytorch/pull/160888
                     tokens_per_expert = tokens_per_expert // 2
                 tokens_per_expert_list.append(tokens_per_expert)

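The `// 2` correction above exists because full activation checkpointing re-runs the MoE forward during backward, so any counter incremented inside `forward()` accumulates twice per step. Here is a minimal toy sketch of that effect; the `CountingLayer` module and all values are made up for illustration and are not torchtitan code:

import torch
from torch.utils.checkpoint import checkpoint

class CountingLayer(torch.nn.Module):
    """Toy stand-in for an MoE layer that counts routed tokens in a buffer."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.register_buffer("tokens_seen", torch.zeros(1))

    def forward(self, x):
        # Side effect: this line also runs during the AC recomputation in backward.
        self.tokens_seen += x.shape[0]
        return self.linear(x)

layer = CountingLayer()
x = torch.randn(4, 8, requires_grad=True)
out = checkpoint(layer, x, use_reentrant=False)  # full recompute of forward in backward
out.sum().backward()
print(layer.tokens_seen)  # tensor([8.]) -- 4 tokens counted twice, hence the // 2 above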
@@ -390,30 +387,37 @@ def _update_expert_bias(
            tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
        )

-        layer_idx = 0
+        moe_layer_idx = 0
        with torch.no_grad():
            for model_part in model_parts:
-                for transformer_block in model_part.layers.values():
+                for layer_id, transformer_block in enumerate(
+                    model_part.layers.values()
+                ):
                    if not transformer_block.moe_enabled:
                        continue
                    moe = transformer_block.moe

+                    tokens_per_expert = tokens_per_expert_by_layer[
+                        moe_layer_idx
+                    ].float()
+                    moe_layer_idx += 1
+                    # uncomment to log expert usage once we fix https://github.com/pytorch/torchtitan/pull/1578
+                    # sum_tokens = tokens_per_expert.sum().clamp(min=1.0)
+                    # expert_usage_metrics = {
+                    #     f"moe_ep_usage/L-{layer_id}_EP-{ep_idx}": usage / sum_tokens
+                    #     for ep_idx, usage in enumerate(tokens_per_expert)
+                    # }
+
                    if moe.load_balance_coeff is None:
                        continue
-
-                    tokens_per_expert = tokens_per_expert_by_layer[layer_idx].float()
-                    layer_idx += 1
-
                    # update the expert bias
-                    # https://github.com/pytorch/torchtitan/issues/1506
                    # this is not exactly the same as https://arxiv.org/pdf/2408.15664 proposed
-                    expert_bias_delta = load_balance_coeff * torch.sign(
+                    expert_bias_delta = moe.load_balance_coeff * torch.sign(
                        tokens_per_expert.mean() - tokens_per_expert
                    )
                    expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
                    moe.expert_bias.add_(expert_bias_delta)
                    moe.tokens_per_expert.zero_()
-                    # placeholder to record and log the expert usage

    optimizers.register_step_pre_hook(
        lambda *args, **kwargs: _update_expert_bias(
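For reference, the bias update applied in the hunk above can be exercised in isolation. This is a standalone sketch with made-up token counts and coefficient, not torchtitan code: over-used experts have their routing bias pushed down, under-used experts pushed up, and the delta is re-centered to zero mean, in the spirit of the auxiliary-loss-free balancing of https://arxiv.org/pdf/2408.15664 (which, as the in-code comment notes, is not followed exactly):

import torch

load_balance_coeff = 1e-3                                      # assumed value for illustration
tokens_per_expert = torch.tensor([120.0, 80.0, 100.0, 100.0])  # made-up per-step counts
expert_bias = torch.zeros(4)

# Sign-based update: -coeff for over-used experts, +coeff for under-used ones.
expert_bias_delta = load_balance_coeff * torch.sign(
    tokens_per_expert.mean() - tokens_per_expert
)
# Re-center so the biases do not drift in aggregate.
expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
expert_bias += expert_bias_delta
print(expert_bias)  # tensor([-0.0010,  0.0010,  0.0000,  0.0000])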