@@ -341,7 +341,7 @@ def build_optimizers_with_moe_load_balancing(
     )
 
     # for MoE auxiliary-loss-free load balancing
-    def is_full_recompute(module):
+    def _is_recomputation_enabled(module):
         return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT
 
     def _update_expert_bias(
@@ -358,9 +358,8 @@ def _update_expert_bias(
             for transformer_block in model_part.layers.values():
                 if not transformer_block.moe_enabled:
                     continue
-                moe = transformer_block.moe
                 tokens_per_expert = transformer_block.moe.tokens_per_expert
-                if is_full_recompute(transformer_block):
+                if _is_recomputation_enabled(transformer_block):
                     # TODO: This is a hack, we assume with full AC, the tokens_per_expert is counted twice.
                     # This does not affect to expert choice, but affects the experts usage metrics.
                     # We divide by 2 to correct for this double-counting due to recomputation
@@ -372,12 +371,6 @@ def _update_expert_bias(
                 # avoid cat empty tensor
                 return
 
-            n_expert = tokens_per_expert_list[0].numel()
-            assert all(
-                t.numel() == n_expert for t in tokens_per_expert_list
-            ), "All MoE layers must have the same number of experts."
-
-            # [n_layers, n_expert], int32
             tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)
 
             if dp_cp_mesh is not None:
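For context, here is a minimal sketch of what the renamed helper and the divide-by-2 correction amount to, pulled out of the optimizer builder. Only the CheckpointImpl.NO_REENTRANT check, the moe_enabled / tokens_per_expert attributes, and the halving under full activation checkpointing come from the diff above; the helper name stack_tokens_per_expert and the blocks iterable are hypothetical stand-ins for illustration.

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl


def _is_recomputation_enabled(module):
    # Blocks wrapped with non-reentrant activation checkpointing expose
    # checkpoint_impl = CheckpointImpl.NO_REENTRANT; unwrapped blocks do not.
    return getattr(module, "checkpoint_impl", None) is CheckpointImpl.NO_REENTRANT


def stack_tokens_per_expert(blocks):
    # Hypothetical helper: `blocks` stands in for the transformer blocks of one
    # model part; each MoE block is assumed to hold an integer counter
    # `moe.tokens_per_expert` of shape [n_experts].
    tokens_per_expert_list = []
    for block in blocks:
        if not getattr(block, "moe_enabled", False):
            continue
        tokens_per_expert = block.moe.tokens_per_expert
        if _is_recomputation_enabled(block):
            # With full activation checkpointing the MoE forward runs twice,
            # so the counter accumulates roughly 2x the real token count.
            tokens_per_expert = tokens_per_expert // 2
        tokens_per_expert_list.append(tokens_per_expert)
    if not tokens_per_expert_list:
        # avoid stacking an empty list
        return None
    # [n_layers, n_experts]; torch.vstack itself errors out if the per-layer
    # expert counts differ, so no separate assert is strictly needed.
    return torch.vstack(tokens_per_expert_list)

Presumably the stacked [n_layers, n_experts] counts are then reduced across the dp_cp mesh (the `if dp_cp_mesh is not None:` branch above) before the per-layer expert bias is updated; that part is outside this diff.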