 # pylint: disable=no-name-in-module


-from typing import Optional
+from typing import Any

 from jax.ad_checkpoint import checkpoint_name
 from jax.sharding import Mesh
 import jax.numpy as jnp

 from flax import linen as nn
+from flax import nnx

-from MaxText.layers.linears import mlp_block
-from MaxText.layers import models
+from MaxText.layers import initializers, nnx_wrappers
+from MaxText.layers.linears import MlpBlock
+from MaxText.layers.models import Config
+from MaxText.layers.attentions import Attention
 from MaxText.layers import quantizations
-from MaxText.layers.attentions import attention_as_linen
 from MaxText.layers.quantizations import AqtQuantization as Quant
-from MaxText.layers.normalizations import rms_norm
+from MaxText.layers.normalizations import RMSNorm


 # -----------------------------------------
 # The Decoder Layer for Mistral
 # -----------------------------------------


-class MistralDecoderLayer(nn.Module):
+class MistralDecoderLayer(nnx.Module):
   """Transformer decoder layer that attends to the encoder."""

-  config: models.Config
-  mesh: Mesh
-  quant: Optional[Quant] = None
+  def __init__(
+      self,
+      config: Config,
+      mesh: Mesh,
+      quant: Quant | None = None,
+      rngs: nnx.Rngs | None = None,
+      **kwargs: Any,
+  ):
+    self.config = config
+    self.mesh = mesh
+    self.quant = quant
+    self.rngs = rngs if rngs else kwargs.get("rngs", nnx.Rngs(0))

-  @nn.compact
   def __call__(
       self,
       inputs,
@@ -59,47 +69,43 @@ def __call__(
       page_state=None,
       slot=None,
   ):
-    cfg = self.config
-    mesh = self.mesh

     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")
-    lnx_rms = rms_norm(
+    lnx_rms = RMSNorm(
         num_features=inputs.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="pre_self_attention_layer_norm",
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
         kernel_axes=("norm",),
-        epsilon=cfg.normalization_layer_epsilon,
+        epsilon=self.config.normalization_layer_epsilon,
+        rngs=self.rngs,
     )
     lnx = lnx_rms(inputs)

     lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

     # Self-attention block
-    attention_layer = attention_as_linen(
-        config=cfg,
-        num_query_heads=cfg.num_query_heads,
-        num_kv_heads=cfg.num_kv_heads,
-        head_dim=cfg.head_dim,
-        max_target_length=cfg.max_target_length,
-        max_prefill_predict_length=cfg.max_prefill_predict_length,
-        attention_kernel=cfg.attention,
+    attention_layer = Attention(
+        config=self.config,
+        num_query_heads=self.config.num_query_heads,
+        num_kv_heads=self.config.num_kv_heads,
+        head_dim=self.config.head_dim,
+        max_target_length=self.config.max_target_length,
+        max_prefill_predict_length=self.config.max_prefill_predict_length,
+        attention_kernel=self.config.attention,
+        mesh=self.mesh,
+        dtype=self.config.dtype,
         inputs_q=lnx,
         inputs_kv=lnx,
-        mesh=mesh,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        dropout_rate=cfg.dropout_rate,
-        name="self_attention",
-        float32_qk_product=cfg.float32_qk_product,
-        float32_logits=cfg.float32_logits,
+        weight_dtype=self.config.weight_dtype,
+        dropout_rate=self.config.dropout_rate,
+        float32_qk_product=self.config.float32_qk_product,
+        float32_logits=self.config.float32_logits,
         quant=self.quant,
-        kv_quant=quantizations.configure_kv_quant(cfg),
-        prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))),
-        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
-        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
-        model_mode=model_mode,
+        kv_quant=quantizations.configure_kv_quant(self.config),
+        prefill_cache_axis_order=tuple(map(int, self.config.prefill_cache_axis_order.split(","))),
+        ar_cache_axis_order=tuple(map(int, self.config.ar_cache_axis_order.split(","))),
+        compute_axis_order=tuple(map(int, self.config.compute_axis_order.split(","))),
     )

     attention_lnx = attention_layer(
@@ -118,40 +124,40 @@ def __call__(
     intermediate_inputs = inputs + attention_lnx

     # Fully Connected
-    hidden_states = rms_norm(
+    hidden_states = RMSNorm(
         num_features=intermediate_inputs.shape[-1],
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="post_self_attention_layer_norm",
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
         kernel_axes=("norm",),
-        epsilon=cfg.normalization_layer_epsilon,
+        epsilon=self.config.normalization_layer_epsilon,
+        rngs=self.rngs,
     )(intermediate_inputs)
     hidden_states = nn.with_logical_constraint(
         hidden_states, ("activation_batch", "activation_norm_length", "activation_embed")
     )

-    mlp_lnx = mlp_block(
+    mlp_lnx = MlpBlock(
         in_features=hidden_states.shape[-1],
-        intermediate_dim=cfg.mlp_dim,
-        activations=cfg.mlp_activations,
-        intermediate_dropout_rate=cfg.dropout_rate,
-        dtype=cfg.dtype,
-        weight_dtype=cfg.weight_dtype,
-        name="mlp",
-        config=cfg,
+        intermediate_dim=self.config.mlp_dim,
+        activations=self.config.mlp_activations,
+        intermediate_dropout_rate=self.config.dropout_rate,
+        dtype=self.config.dtype,
+        weight_dtype=self.config.weight_dtype,
+        config=self.config,
         quant=self.quant,
+        rngs=self.rngs,
     )(hidden_states, deterministic=deterministic)
     mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_norm_length", "activation_embed"))

     layer_output = mlp_lnx + intermediate_inputs
-    layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)
+    layer_output = nn.Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic)

     layer_output = nn.with_logical_constraint(
         layer_output,
         ("activation_batch", "activation_norm_length", "activation_embed"),
     )

-    if cfg.record_internal_nn_metrics:
+    if self.config.record_internal_nn_metrics:
       self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
       self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
       self.sow(
@@ -160,7 +166,14 @@ def __call__(
           jnp.sum(layer_output == 0) / jnp.size(layer_output),
       )

-    if cfg.scan_layers:
+    if self.config.scan_layers:
       return layer_output, None
     else:
       return layer_output
+
+def mistral_decoder_layer_class() -> nn.Module:
+  """Create a MistralDecoderLayer Linen module"""
+  return nnx_wrappers.to_linen_class(
+      MistralDecoderLayer,
+      metadata_fn=initializers.variable_to_logically_partitioned,
+  )
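
# Usage sketch (not part of this commit): a minimal, self-contained illustration of
# the Linen-to-NNX pattern the change applies. Hyper-parameters move from dataclass
# fields read inside an @nn.compact __call__ into an explicit __init__ that builds
# sublayers eagerly from an nnx.Rngs stream. `TinyBlock` and its sizes are
# illustrative assumptions only; they do not exist in MaxText.
import jax.numpy as jnp
from flax import nnx


class TinyBlock(nnx.Module):
  """Toy pre-norm MLP block mirroring the shape of MistralDecoderLayer's rewrite."""

  def __init__(self, features: int, hidden: int, *, rngs: nnx.Rngs):
    # Sublayers are constructed here (eagerly), not inside __call__ as in Linen.
    self.norm = nnx.LayerNorm(num_features=features, rngs=rngs)
    self.up = nnx.Linear(features, hidden, rngs=rngs)
    self.down = nnx.Linear(hidden, features, rngs=rngs)

  def __call__(self, x):
    # Pre-normalization, MLP, then residual add, as in the decoder layer above.
    h = self.norm(x)
    h = self.down(nnx.relu(self.up(h)))
    return x + h


block = TinyBlock(features=8, hidden=32, rngs=nnx.Rngs(params=0))
y = block(jnp.ones((2, 4, 8)))  # output shape matches input: (2, 4, 8)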