Commit b2bfd1d: mistral migration

1 parent a50d0f2
2 files changed (+67, -54 lines)

MaxText/layers/decoders.py

Lines changed: 1 addition & 1 deletion
@@ -338,7 +338,7 @@ def get_decoder_layers(self):
         return [llama2.LlamaDecoderLayer]
       case DecoderBlockType.MISTRAL:
         # TODO(ranran): update to Mistral with sliding window attention
-        return [mistral.MistralDecoderLayer]
+        return [mistral.mistral_decoder_layer_class()]
       case DecoderBlockType.MIXTRAL:
         return [mixtral.MixtralDecoderLayer]
       case DecoderBlockType.DEEPSEEK:
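
Note: the decoder registry still hands out Linen layer classes, so the NNX-based MistralDecoderLayer is exposed through a factory that wraps it back into a Linen-compatible class (see mistral_decoder_layer_class in mistral.py below). As a rough sketch of the same idea using Flax's generic NNX-to-Linen bridge; the toy layer, feature sizes, and the bridge call are illustrative assumptions, not MaxText code:

import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge


class ToyNNXLayer(nnx.Module):
  """Stand-in for an NNX-migrated layer; not the real decoder layer."""

  def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(in_features, out_features, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)


x = jnp.ones((2, 16))
# Wrap the NNX module so Linen-style callers can init/apply it as usual.
linen_mod = bridge.to_linen(ToyNNXLayer, 16, out_features=8)
variables = linen_mod.init(jax.random.key(0), x)
y = linen_mod.apply(variables, x)

MaxText's nnx_wrappers.to_linen_class plays an analogous role, but it returns a class (so the registry can instantiate it later) and threads a metadata_fn for logical partitioning.
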

MaxText/layers/mistral.py

Lines changed: 66 additions & 53 deletions
@@ -19,35 +19,45 @@
 # pylint: disable=no-name-in-module


-from typing import Optional
+from typing import Any

 from jax.ad_checkpoint import checkpoint_name
 from jax.sharding import Mesh
 import jax.numpy as jnp

 from flax import linen as nn
+from flax import nnx

-from MaxText.layers.linears import mlp_block
-from MaxText.layers import models
+from MaxText.layers import initializers, nnx_wrappers
+from MaxText.layers.linears import MlpBlock
+from MaxText.layers.models import Config
+from MaxText.layers.attentions import Attention
 from MaxText.layers import quantizations
-from MaxText.layers.attentions import attention_as_linen
 from MaxText.layers.quantizations import AqtQuantization as Quant
-from MaxText.layers.normalizations import rms_norm
+from MaxText.layers.normalizations import RMSNorm


 # -----------------------------------------
 # The Decoder Layer for Mistral
 # -----------------------------------------


-class MistralDecoderLayer(nn.Module):
+class MistralDecoderLayer(nnx.Module):
   """Transformer decoder layer that attends to the encoder."""

-  config: models.Config
-  mesh: Mesh
-  quant: Optional[Quant] = None
+  def __init__(
+      self,
+      config: Config,
+      mesh: Mesh,
+      quant: Quant | None = None,
+      rngs: nnx.Rngs | None = None,
+      **kwargs: Any,
+  ):
+    self.config = config
+    self.mesh = mesh
+    self.quant = quant
+    self.rngs = rngs if rngs else kwargs.get("rngs", nnx.Rngs(0))

-  @nn.compact
   def __call__(
       self,
       inputs,
@@ -59,47 +69,43 @@ def __call__(
       page_state=None,
       slot=None,
   ):
-    cfg = self.config
-    mesh = self.mesh

     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")
-    lnx_rms = rms_norm(
+    lnx_rms = RMSNorm(
         num_features=inputs.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="pre_self_attention_layer_norm",
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
         kernel_axes=("norm",),
-        epsilon=cfg.normalization_layer_epsilon,
+        epsilon=self.config.normalization_layer_epsilon,
+        rngs=self.rngs,
     )
     lnx = lnx_rms(inputs)

     lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

     # Self-attention block
-    attention_layer = attention_as_linen(
-        config=cfg,
-        num_query_heads=cfg.num_query_heads,
-        num_kv_heads=cfg.num_kv_heads,
-        head_dim=cfg.head_dim,
-        max_target_length=cfg.max_target_length,
-        max_prefill_predict_length=cfg.max_prefill_predict_length,
-        attention_kernel=cfg.attention,
+    attention_layer = Attention(
+        config=self.config,
+        num_query_heads=self.config.num_query_heads,
+        num_kv_heads=self.config.num_kv_heads,
+        head_dim=self.config.head_dim,
+        max_target_length=self.config.max_target_length,
+        max_prefill_predict_length=self.config.max_prefill_predict_length,
+        attention_kernel=self.config.attention,
+        mesh=self.mesh,
+        dtype=self.config.dtype,
         inputs_q=lnx,
         inputs_kv=lnx,
-        mesh=mesh,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        dropout_rate=cfg.dropout_rate,
-        name="self_attention",
-        float32_qk_product=cfg.float32_qk_product,
-        float32_logits=cfg.float32_logits,
+        weight_dtype=self.config.weight_dtype,
+        dropout_rate=self.config.dropout_rate,
+        float32_qk_product=self.config.float32_qk_product,
+        float32_logits=self.config.float32_logits,
         quant=self.quant,
-        kv_quant=quantizations.configure_kv_quant(cfg),
-        prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
-        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
-        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
-        model_mode=model_mode,
+        kv_quant=quantizations.configure_kv_quant(self.config),
+        prefill_cache_axis_order=tuple(map(int, self.config.prefill_cache_axis_order.split(","))),
+        ar_cache_axis_order=tuple(map(int, self.config.ar_cache_axis_order.split(","))),
+        compute_axis_order=tuple(map(int, self.config.compute_axis_order.split(","))),
     )

     attention_lnx = attention_layer(
@@ -118,40 +124,40 @@ def __call__(
     intermediate_inputs = inputs + attention_lnx

     # Fully Connected
-    hidden_states = rms_norm(
+    hidden_states = RMSNorm(
         num_features=intermediate_inputs.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="post_self_attention_layer_norm",
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
         kernel_axes=("norm",),
-        epsilon=cfg.normalization_layer_epsilon,
+        epsilon=self.config.normalization_layer_epsilon,
+        rngs=self.rngs,
     )(intermediate_inputs)
     hidden_states = nn.with_logical_constraint(
         hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
     )

-    mlp_lnx = mlp_block(
+    mlp_lnx = MlpBlock(
         in_features=hidden_states.shape[-1],
-        intermediate_dim=cfg.mlp_dim,
-        activations=cfg.mlp_activations,
-        intermediate_dropout_rate=cfg.dropout_rate,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="mlp",
-        config=cfg,
+        intermediate_dim=self.config.mlp_dim,
+        activations=self.config.mlp_activations,
+        intermediate_dropout_rate=self.config.dropout_rate,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        config=self.config,
         quant=self.quant,
+        rngs=self.rngs,
     )(hidden_states, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

     layer_output = mlp_lnx + intermediate_inputs
-    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
+    layer_output = nn.Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)

     layer_output = nn.with_logical_constraint(
         layer_output,
         ("activation_batch", "activation_norm_length", "activation_embed"),
     )

-    if cfg.record_internal_nn_metrics:
+    if self.config.record_internal_nn_metrics:
       self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
       self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
       self.sow(
@@ -160,7 +166,14 @@ def __call__(
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )

-    if cfg.scan_layers:
+    if self.config.scan_layers:
       return layer_output, None
     else:
       return layer_output
+
+def mistral_decoder_layer_class() -> nn.Module:
+  """Create a MistralDecoderLayer Linen module"""
+  return nnx_wrappers.to_linen_class(
+      MistralDecoderLayer,
+      metadata_fn=initializers.variable_to_logically_partitioned,
+  )
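
For readers less familiar with the two Flax APIs, the shape of this migration is: a Linen layer defined with dataclass fields and @nn.compact becomes an NNX module with an explicit __init__ that receives nnx.Rngs and builds its submodules eagerly. A minimal side-by-side sketch (toy layer names and sizes are illustrative, not MaxText code):

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx


class ToyLinenLayer(nn.Module):
  """Old style: hyperparameters as dataclass fields, submodules built lazily."""

  features: int

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)


class ToyNNXLayer(nnx.Module):
  """New style: explicit __init__, parameters created eagerly from rngs."""

  def __init__(self, in_features: int, features: int, *, rngs: nnx.Rngs):
    self.dense = nnx.Linear(in_features, features, rngs=rngs)

  def __call__(self, x):
    return self.dense(x)


x = jnp.ones((2, 16))
linen_params = ToyLinenLayer(features=8).init(jax.random.PRNGKey(0), x)  # lazy init at trace time
nnx_layer = ToyNNXLayer(16, 8, rngs=nnx.Rngs(0))                         # parameters exist immediately
y = nnx_layer(x)                                                         # direct call, no apply()

In the diff above, the rngs fallback in MistralDecoderLayer.__init__ (nnx.Rngs(0) when none is passed) lets the layer be constructed without an explicit Rngs argument, and the trailing factory re-exposes it to the existing Linen decoder stack via nnx_wrappers.to_linen_class.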
