
Commit a738a0c

Merge branch 'main' into parambole/maxtext_qwen_moe
2 parents: f73fe03 + 3078cd3

22 files changed: +906 −677 lines

MaxText/layers/deepseek.py

Lines changed: 1 addition & 1 deletion
@@ -259,7 +259,7 @@ def __call__(
     # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
     # The `name` represents the weight name in JAX/checkpoints and so the class name
     # is just for readability.
-    mlp_lnx = moe.RoutedAndSharedMoE(
+    mlp_lnx = moe.get_routed_and_shared_moe(
         name="DeepSeekMoeBlock_0",
         config=cfg,
         mesh=self.mesh,
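
Across deepseek.py, mixtral.py, and llama4.py this commit swaps direct construction of the MoE module classes (moe.RoutedAndSharedMoE, moe.RoutedMoE) for factory helpers (moe.get_routed_and_shared_moe, moe.get_routed_moe) that take the same keyword arguments. The sketch below only illustrates the shape of that swap with invented stand-ins (ToyRoutedMoE, get_toy_routed_moe); it is not MaxText's real signature.

from dataclasses import dataclass

@dataclass
class ToyRoutedMoE:
  """Stand-in for an MoE module class that call sites used to construct directly."""
  name: str
  num_experts: int

  def __call__(self, x):
    # Real MoE routing is elided; this stand-in just echoes its input.
    return x, None  # (outputs, load_balance_loss)

def get_toy_routed_moe(*, name: str, num_experts: int) -> ToyRoutedMoE:
  """Factory helper: call sites depend on a function instead of the class itself."""
  return ToyRoutedMoE(name=name, num_experts=num_experts)

# Before: mlp_lnx, loss = ToyRoutedMoE(name="MoeBlock_0", num_experts=8)(x)
# After:  mlp_lnx, loss = get_toy_routed_moe(name="MoeBlock_0", num_experts=8)(x)
mlp_lnx, load_balance_loss = get_toy_routed_moe(name="MoeBlock_0", num_experts=8)([1.0, 2.0])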

MaxText/layers/gemma.py

Lines changed: 4 additions & 10 deletions
@@ -105,17 +105,10 @@ def __call__(
     )
     attention_lnx += inputs
     residual = attention_lnx
-    attn_output = rms_norm(
-        num_features=attention_lnx.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="pre_ffw_norm",
-        kernel_axes=("norm",),
-    )(attention_lnx)

-    # MLP block.
+    # MLP block with pre-norm.
     mlp_lnx = mlp_block(
-        in_features=attn_output.shape[-1],
+        in_features=residual.shape[-1],
         intermediate_dim=cfg.mlp_dim,
         activations=cfg.mlp_activations,
         intermediate_dropout_rate=cfg.dropout_rate,
@@ -124,7 +117,8 @@ def __call__(
         name="mlp",
         config=cfg,
         quant=self.quant,
-    )(attn_output, deterministic=deterministic)
+        use_pre_norm=True,
+    )(residual, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

     next_layer_addition = mlp_lnx + residual
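
gemma.py, gemma2.py, gemma3.py, llama2.py, and mistral.py all follow the same pattern: the decoder layer drops its hand-built pre-FFN rms_norm and instead passes use_pre_norm=True to the MLP block, feeding it the raw residual. Below is a minimal flax.linen sketch of that idea, assuming a toy module (TinyMlpBlock) rather than MaxText's mlp_block; the module and parameter names are invented.

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyMlpBlock(nn.Module):
  """Toy MLP that can own its pre-norm, mirroring the use_pre_norm=True call sites."""
  intermediate_dim: int
  use_pre_norm: bool = False

  @nn.compact
  def __call__(self, x):
    if self.use_pre_norm:
      # Previously the decoder layer applied this norm itself ("pre_ffw_norm").
      x = nn.RMSNorm(name="pre_ffw_norm")(x)
    h = nn.gelu(nn.Dense(self.intermediate_dim)(x))
    return nn.Dense(x.shape[-1])(h)

residual = jnp.ones((2, 8, 16))                      # (batch, seq, embed)
block = TinyMlpBlock(intermediate_dim=64, use_pre_norm=True)
params = block.init(jax.random.PRNGKey(0), residual)
mlp_lnx = block.apply(params, residual)              # MLP output before the residual add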

MaxText/layers/gemma2.py

Lines changed: 4 additions & 11 deletions
@@ -116,17 +116,9 @@ def __call__(
     attention_lnx += inputs
     residual = attention_lnx

-    attn_output = rms_norm(
-        num_features=attention_lnx.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="pre_ffw_norm_local",
-        kernel_axes=("norm",),
-    )(attention_lnx)
-
-    # MLP block.
+    # MLP block with pre-norm.
     mlp_lnx = mlp_block(
-        in_features=attn_output.shape[-1],
+        in_features=attention_lnx.shape[-1],
         intermediate_dim=cfg.mlp_dim,
         activations=cfg.mlp_activations,
         intermediate_dropout_rate=cfg.dropout_rate,
@@ -136,7 +128,8 @@ def __call__(
         model_mode=model_mode,
         config=cfg,
         quant=self.quant,
-    )(attn_output, deterministic=deterministic)
+        use_pre_norm=True,
+    )(attention_lnx, deterministic=deterministic)

     if cfg.use_post_ffw_norm:
       mlp_lnx = rms_norm(

MaxText/layers/gemma3.py

Lines changed: 4 additions & 11 deletions
@@ -147,17 +147,9 @@ def __call__(
     attention_lnx += inputs
     residual = attention_lnx

-    attn_output = rms_norm(
-        num_features=attention_lnx.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="pre_ffw_norm",
-        kernel_axes=("norm",),
-    )(attention_lnx)
-
-    # MLP block.
+    # MLP block with pre-norm.
     mlp_lnx = mlp_block(
-        in_features=attn_output.shape[-1],
+        in_features=attention_lnx.shape[-1],
         intermediate_dim=cfg.mlp_dim,
         activations=cfg.mlp_activations,
         intermediate_dropout_rate=cfg.dropout_rate,
@@ -166,7 +158,8 @@ def __call__(
         name="mlp",
         config=cfg,
         quant=self.quant,
-    )(attn_output, deterministic=deterministic)
+        use_pre_norm=True,
+    )(attention_lnx, deterministic=deterministic)

     if cfg.use_post_ffw_norm:
       mlp_lnx = rms_norm(

MaxText/layers/linears.py

Lines changed: 12 additions & 1 deletion
@@ -315,7 +315,7 @@ def __init__(
       rngs: nnx.Rngs,
   ) -> None:
     """A MlpBlock module.
-
+
     Args:
       config: Config object containing model parameters.
      in_features: Number of input features.
@@ -423,8 +423,19 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
     """Applies Transformer MlpBlock module."""
     cfg = self.config

+
     if self.mlp_layer_norm is not None:
       inputs = self.mlp_layer_norm(inputs)
+      if self.model_mode == MODEL_MODE_PREFILL:
+        inputs = nn.with_logical_constraint(inputs, ("activation_batch",
+                                                     "prefill_activation_norm_length",
+                                                     "activation_embed")
+                                            )
+      else:
+        inputs = nn.with_logical_constraint(inputs, ("activation_batch",
+                                                     "activation_norm_length",
+                                                     "activation_embed")
+                                            )

     # Iterate over specified MLP input activation functions.
     # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
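
The new lines in MlpBlock.__call__ shard the normalized input with mode-dependent logical axis names: prefill uses "prefill_activation_norm_length" for the sequence axis, while every other mode uses "activation_norm_length". A small helper expressing that choice could look like the sketch below; the helper name and the string value assigned to MODEL_MODE_PREFILL are assumptions, since MaxText defines its own constant.

# Assumed stand-in for MaxText's MODEL_MODE_PREFILL constant.
MODEL_MODE_PREFILL = "prefill"

def norm_output_axes(model_mode: str) -> tuple[str, str, str]:
  """Logical axis names used to shard the pre-normed MLP input."""
  if model_mode == MODEL_MODE_PREFILL:
    return ("activation_batch", "prefill_activation_norm_length", "activation_embed")
  return ("activation_batch", "activation_norm_length", "activation_embed")

# Inside the block this would feed nn.with_logical_constraint, e.g.:
#   inputs = nn.with_logical_constraint(inputs, norm_output_axes(self.model_mode))
assert norm_output_axes("prefill")[1] == "prefill_activation_norm_length"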

MaxText/layers/llama2.py

Lines changed: 4 additions & 14 deletions
@@ -128,20 +128,9 @@ def __call__(
     attention_lnx = nn.with_logical_constraint(attention_lnx, activation_axis_names)
     intermediate_inputs = inputs + attention_lnx

-    # Fully Connected
-    hidden_states = rms_norm(
-        num_features=intermediate_inputs.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="post_self_attention_layer_norm",
-        kernel_axes=("norm",),
-        epsilon=cfg.normalization_layer_epsilon,
-    )(intermediate_inputs)
-    hidden_states = nn.with_logical_constraint(hidden_states, activation_axis_names)
-
-    # MLP block.
+    # MLP block with pre-norm.
     mlp_lnx = mlp_block(
-        in_features=hidden_states.shape[-1],
+        in_features=intermediate_inputs.shape[-1],
         intermediate_dim=cfg.mlp_dim,
         activations=cfg.mlp_activations,
         intermediate_dropout_rate=cfg.dropout_rate,
@@ -151,7 +140,8 @@ def __call__(
         config=cfg,
         quant=self.quant,
         model_mode=model_mode,
-    )(hidden_states, deterministic=deterministic)
+        use_pre_norm=True,
+    )(intermediate_inputs, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, activation_axis_names)

     layer_output = mlp_lnx + intermediate_inputs

MaxText/layers/llama4.py

Lines changed: 13 additions & 12 deletions
@@ -455,25 +455,24 @@ def __call__(
     )
     intermediate_inputs = inputs + attention_lnx

-    # Fully Connected
-    hidden_states = rms_norm(
+    load_balance_loss = None
+    if self.is_moe_layer:
+      # Fully Connected
+      hidden_states = rms_norm(
         num_features=intermediate_inputs.shape[-1],
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="post_self_attention_layer_norm",
         kernel_axes=("norm",),
         epsilon=cfg.normalization_layer_epsilon,
-    )(intermediate_inputs)
-    hidden_states = nn.with_logical_constraint(
-        hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
-    )
-
-    load_balance_loss = None
-    if self.is_moe_layer:
+      )(intermediate_inputs)
+      hidden_states = nn.with_logical_constraint(
+          hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
+      )
       # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
       # The `name` represents the weight name in JAX/checkpoints and so the class name
       # is just for readability.
-      mlp_lnx = moe.RoutedAndSharedMoE(
+      mlp_lnx = moe.get_routed_and_shared_moe(
          name="Llama4MoEBlock_0",
          config=cfg,
          mesh=self.mesh,
@@ -484,8 +483,9 @@ def __call__(
          quant=self.quant,
      )(hidden_states)
     else:
+      # MLP block with pre-norm.
      mlp_lnx = mlp_block(
-          in_features=hidden_states.shape[-1],
+          in_features=intermediate_inputs.shape[-1],
          intermediate_dim=cfg.mlp_dim,
          activations=cfg.mlp_activations,
          intermediate_dropout_rate=cfg.dropout_rate,
@@ -494,7 +494,8 @@ def __call__(
          name="mlp",
          config=cfg,
          quant=self.quant,
-      )(hidden_states, deterministic=deterministic)
+          use_pre_norm=True,
+      )(intermediate_inputs, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

     layer_output = mlp_lnx + intermediate_inputs
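
After this change, the explicit post_self_attention_layer_norm in llama4.py exists only on the MoE branch, while the dense branch relies on mlp_block's use_pre_norm=True. The sketch below captures that control flow with stand-in callables; the function and argument names are illustrative, not MaxText's API.

from typing import Callable, Optional, Tuple
import jax.numpy as jnp

Array = jnp.ndarray

def ffn_path(
    intermediate_inputs: Array,
    is_moe_layer: bool,
    rms_norm: Callable[[Array], Array],
    routed_and_shared_moe: Callable[[Array], Array],
    mlp_with_pre_norm: Callable[[Array], Array],
) -> Tuple[Array, Optional[Array]]:
  """Dense branch skips the explicit norm; MoE branch still normalizes first."""
  load_balance_loss = None
  if is_moe_layer:
    hidden_states = rms_norm(intermediate_inputs)      # explicit pre-norm, MoE branch only
    mlp_lnx = routed_and_shared_moe(hidden_states)
  else:
    mlp_lnx = mlp_with_pre_norm(intermediate_inputs)   # pre-norm handled inside the MLP block
  return mlp_lnx, load_balance_loss

x = jnp.ones((2, 4, 8))
out, loss = ffn_path(x, False, lambda t: t, lambda t: t, lambda t: t)  # dense path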

MaxText/layers/mistral.py

Lines changed: 4 additions & 15 deletions
@@ -118,21 +118,9 @@ def __call__(
     )
     intermediate_inputs = inputs + attention_lnx

-    # Fully Connected
-    hidden_states = rms_norm(
-        num_features=intermediate_inputs.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="post_self_attention_layer_norm",
-        kernel_axes=("norm",),
-        epsilon=cfg.normalization_layer_epsilon,
-    )(intermediate_inputs)
-    hidden_states = nn.with_logical_constraint(
-        hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
-    )
-
+    # MLP block with pre-norm.
     mlp_lnx = mlp_block(
-        in_features=hidden_states.shape[-1],
+        in_features=intermediate_inputs.shape[-1],
         intermediate_dim=cfg.mlp_dim,
         activations=cfg.mlp_activations,
         intermediate_dropout_rate=cfg.dropout_rate,
@@ -141,7 +129,8 @@ def __call__(
         name="mlp",
         config=cfg,
         quant=self.quant,
-    )(hidden_states, deterministic=deterministic)
+        use_pre_norm=True,
+    )(intermediate_inputs, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

     layer_output = mlp_lnx + intermediate_inputs

MaxText/layers/mixtral.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def __call__(
     # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
     # The `name` represents the weight name in JAX/checkpoints and so the class name
     # is just for readability.
-    mlp_lnx, load_balance_loss = moe.RoutedMoE(
+    mlp_lnx, load_balance_loss = moe.get_routed_moe(
         name="MoeBlock_0",
         config=cfg,
         num_experts=cfg.num_experts,
