@@ -455,25 +455,24 @@ def __call__(
     )
     intermediate_inputs = inputs + attention_lnx
 
-    # Fully Connected
-    hidden_states = rms_norm(
+    load_balance_loss = None
+    if self.is_moe_layer:
+      # Fully Connected
+      hidden_states = rms_norm(
           num_features=intermediate_inputs.shape[-1],
           dtype=cfg.dtype,
           weight_dtype=cfg.weight_dtype,
           name="post_self_attention_layer_norm",
           kernel_axes=("norm",),
           epsilon=cfg.normalization_layer_epsilon,
-    )(intermediate_inputs)
-    hidden_states = nn.with_logical_constraint(
-        hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
-    )
-
-    load_balance_loss = None
-    if self.is_moe_layer:
+      )(intermediate_inputs)
+      hidden_states = nn.with_logical_constraint(
+          hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
+      )
       # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
       # The `name` represents the weight name in JAX/checkpoints and so the class name
       # is just for readability.
-      mlp_lnx = moe.RoutedAndSharedMoE(
+      mlp_lnx = moe.get_routed_and_shared_moe(
           name="Llama4MoEBlock_0",
           config=cfg,
           mesh=self.mesh,
@@ -484,8 +483,9 @@ def __call__(
           quant=self.quant,
       )(hidden_states)
     else:
+      # MLP block with pre-norm.
       mlp_lnx = mlp_block(
-          in_features=hidden_states.shape[-1],
+          in_features=intermediate_inputs.shape[-1],
           intermediate_dim=cfg.mlp_dim,
           activations=cfg.mlp_activations,
           intermediate_dropout_rate=cfg.dropout_rate,
@@ -494,7 +494,8 @@ def __call__(
           name="mlp",
           config=cfg,
           quant=self.quant,
-      )(hidden_states, deterministic=deterministic)
+          use_pre_norm=True,
+      )(intermediate_inputs, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))
 
     layer_output = mlp_lnx + intermediate_inputs
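
For readers unfamiliar with the pattern above: with `use_pre_norm=True`, the dense branch's MLP block is expected to apply its own input normalization, which is why the call site now passes the raw residual `intermediate_inputs` instead of a separately normalized `hidden_states`. Below is a minimal sketch of that pattern in Flax; `PreNormMlp`, its layer names, and the SiLU/Dense shapes are illustrative stand-ins, not MaxText's actual `mlp_block`.

```python
# Minimal sketch of an MLP block with an optional built-in pre-norm.
# Hypothetical module for illustration only; not MaxText's `mlp_block`.
import jax
import jax.numpy as jnp
from flax import linen as nn


class PreNormMlp(nn.Module):
  """Feed-forward block that optionally normalizes its own input."""

  intermediate_dim: int
  use_pre_norm: bool = False

  @nn.compact
  def __call__(self, x):
    out_dim = x.shape[-1]
    if self.use_pre_norm:
      # Pre-norm: normalize the raw residual-stream input before projecting,
      # so the caller does not need a separate rms_norm step.
      x = nn.RMSNorm(name="mlp_layer_norm")(x)
    h = nn.silu(nn.Dense(self.intermediate_dim, name="wi")(x))
    return nn.Dense(out_dim, name="wo")(h)


# Usage mirrors the diff's dense branch: feed the un-normalized residual input.
x = jnp.ones((2, 8, 16))
block = PreNormMlp(intermediate_dim=64, use_pre_norm=True)
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)
print(y.shape)  # (2, 8, 16)
```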