diff --git a/MaxText/inference/kvcache.py b/MaxText/inference/kvcache.py index ffe9e35ad..b4357619a 100644 --- a/MaxText/inference/kvcache.py +++ b/MaxText/inference/kvcache.py @@ -310,8 +310,9 @@ def __init__( self.model_mode = model_mode self.use_chunked_prefill = use_chunked_prefill - self._initialize_prefill_caches(model_mode) - self._initialize_ar_cache_vars(model_mode) + if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + self._initialize_prefill_caches(model_mode) + self._initialize_ar_cache_vars(model_mode) @property def prefill_key_vars(self): diff --git a/MaxText/layers/decoders.py b/MaxText/layers/decoders.py index 40ec5deea..dcb030015 100644 --- a/MaxText/layers/decoders.py +++ b/MaxText/layers/decoders.py @@ -85,12 +85,12 @@ def __call__( ): cfg = self.config mesh = self.mesh - if model_mode == MODEL_MODE_PREFILL: + if self.model_mode == MODEL_MODE_PREFILL: logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") else: logical_axis_names = ("activation_batch", "activation_length", "activation_embed") - if model_mode == MODEL_MODE_PREFILL: + if self.model_mode == MODEL_MODE_PREFILL: inputs = nn.with_logical_constraint(inputs, logical_axis_names) else: inputs = nn.with_logical_constraint(inputs, logical_axis_names) @@ -105,7 +105,7 @@ def __call__( epsilon=cfg.normalization_layer_epsilon, kernel_axes=("norm",), )(inputs) - if model_mode == MODEL_MODE_PREFILL: + if self.model_mode == MODEL_MODE_PREFILL: lnx = nn.with_logical_constraint(lnx, logical_axis_names) else: lnx = nn.with_logical_constraint(lnx, logical_axis_names) @@ -133,7 +133,7 @@ def __call__( ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), reshape_q=cfg.reshape_q, - model_mode=model_mode, + model_mode=self.model_mode, ) attention_lnx = attention_layer( @@ -142,10 +142,10 @@ def __call__( decoder_positions, decoder_segment_ids=decoder_segment_ids, deterministic=deterministic, - model_mode=model_mode, + model_mode=self.model_mode, ) - if model_mode == MODEL_MODE_PREFILL: + if self.model_mode == MODEL_MODE_PREFILL: attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names) else: attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names) @@ -159,11 +159,11 @@ def __call__( dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, name="mlp", - model_mode=model_mode, + model_mode=self.model_mode, config=cfg, quant=self.quant, )(lnx, deterministic=deterministic) - if model_mode == MODEL_MODE_PREFILL: + if self.model_mode == MODEL_MODE_PREFILL: mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names) else: mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names) @@ -175,7 +175,7 @@ def __call__( ) layer_output = next_layer_addition_dropped_out + inputs - if model_mode == MODEL_MODE_PREFILL: + if self.model_mode == MODEL_MODE_PREFILL: layer_output = nn.with_logical_constraint( layer_output, logical_axis_names, @@ -221,13 +221,13 @@ def __call__( ) -> jnp.ndarray: for lyr in range(self.num_decoder_layers): inputs = self.decoder_layer( - config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode + config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode )( inputs, decoder_segment_ids, decoder_positions, deterministic, - model_mode, + self.model_mode, slot=slot, page_state=page_state, ) @@ -341,7 +341,7 @@ def get_decoder_layers(self): case DecoderBlockType.DEFAULT: return [DecoderLayer] case DecoderBlockType.LLAMA2: - return [llama2.LlamaDecoderLayer] + return [llama2.LlamaDecoderLayerToLinen] case DecoderBlockType.MISTRAL: # TODO(ranran): update to Mistral with sliding window attention return [mistral.MistralDecoderLayer] @@ -358,9 +358,9 @@ def get_decoder_layers(self): case DecoderBlockType.GPT3: return [gpt3.Gpt3DecoderLayer] case DecoderBlockType.QWEN3: - return [qwen3.Qwen3DecoderLayer] + return [qwen3.Qwen3DecoderLayerToLinen] case DecoderBlockType.QWEN3_MOE: - return [qwen3.Qwen3MoeDecoderLayer] + return [qwen3.Qwen3MoeDecoderLayerToLinen] case DecoderBlockType.SIMPLE: return [simple_layer.SimpleDecoderLayer] case DecoderBlockType.SIMPLE_MLP: @@ -382,9 +382,7 @@ def move_to_device(variables): def map_fn(path, value): max_logging.log(f"models.py: Moving parameter {path} to device") - return jax.device_put( - value, max_utils.device_space() - ) + return jax.device_put(value, max_utils.device_space()) return jax.tree_util.tree_map_with_path(map_fn, variables) @@ -446,7 +444,14 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me length=length, metadata_params={nn.PARTITION_NAME: metadata_axis_name}, ) - return scan_fn(config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, model_mode=model_mode, **kwargs) + return scan_fn( + config=cfg, + mesh=mesh, + name=metadata_axis_name, + quant=self.quant, + model_mode=self.model_mode, + **kwargs + ) def get_pipeline_stage_module(self, decoder_blocks): """get pipeline stage module""" @@ -498,7 +503,7 @@ def _apply_embedding( """Applies token and positional embeddings to the input tokens.""" cfg = self.config - y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) + y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=self.model_mode) # Merge the image embeddings with the text embeddings for multimodal models if image_embeddings is not None and cfg.use_multimodal: @@ -526,7 +531,7 @@ def _apply_embedding( embedding_init=nn.initializers.normal(stddev=1.0), name="position_embedder", config=cfg, - )(decoder_positions, model_mode=model_mode) + )(decoder_positions, model_mode=self.model_mode) return y @nn.compact @@ -572,7 +577,7 @@ def _apply_output_head(self, y, deterministic, model_mode): )( y ) # We do not quantize the logits matmul. - if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab")) else: logits = nn.with_logical_constraint( @@ -604,7 +609,7 @@ def __call__( # [batch, length] -> [batch, length, emb_dim] y = self._apply_embedding( - decoder_input_tokens, decoder_positions, deterministic, model_mode, image_embeddings, bidirectional_mask + decoder_input_tokens, decoder_positions, deterministic, self.model_mode, image_embeddings, bidirectional_mask ) policy = self.get_remat_policy() @@ -614,12 +619,12 @@ def __call__( decoder_segment_ids, decoder_positions, deterministic, - model_mode, + self.model_mode, ) if cfg.using_pipeline_parallelism: if cfg.pipeline_fsdp_ag_once: partition_spec = self.pipeline_module.get_weight_sharding( - y, decoder_segment_ids, decoder_positions, deterministic, model_mode + y, decoder_segment_ids, decoder_positions, deterministic, self.model_mode ) else: partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. @@ -639,7 +644,7 @@ def __call__( "dense_layers", mesh, in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, + model_mode=self.model_mode, )(y, *broadcast_args) if num_moe_layers_outside_pp > 0: y, _ = self.scan_decoder_layers( @@ -649,7 +654,7 @@ def __call__( "moe_layers", mesh, in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, + model_mode=self.model_mode, )(y, *broadcast_args) y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) else: # Not DeepSeek @@ -665,7 +670,7 @@ def __call__( "layers_outside_pipeline", mesh, in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, + model_mode=self.model_mode, )(y, *broadcast_args) else: if cfg.scan_layers: @@ -685,7 +690,7 @@ def __call__( "dense_layers", mesh, in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, + model_mode=self.model_mode, )(y, *broadcast_args) moe_layer = RemattedBlockLayers[1] moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) @@ -697,7 +702,7 @@ def __call__( "moe_layers", mesh, in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, + model_mode=self.model_mode, )(y, *broadcast_args) elif cfg.decoder_block == DecoderBlockType.GEMMA3: y = self._apply_gemma3_scanned_blocks( @@ -705,7 +710,7 @@ def __call__( decoder_segment_ids, decoder_positions, deterministic, - model_mode, + self.model_mode, bidirectional_mask, previous_chunk, page_state, @@ -728,7 +733,7 @@ def __call__( "layers", mesh, in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, + model_mode=self.model_mode, **layer_kwargs, )(y, *broadcast_args) else: @@ -749,7 +754,7 @@ def __call__( decoder_segment_ids, decoder_positions, deterministic, - model_mode, + self.model_mode, previous_chunk=previous_chunk, page_state=page_state, slot=slot, @@ -777,7 +782,7 @@ def __call__( decoder_segment_ids, decoder_positions, deterministic, - model_mode, + self.model_mode, previous_chunk=previous_chunk, page_state=page_state, slot=slot, @@ -789,7 +794,7 @@ def __call__( # After the final transformer layer, `y` holds the raw, un-normalized hidden state. hidden_state = y - logits = self._apply_output_head(hidden_state, deterministic, model_mode) + logits = self._apply_output_head(hidden_state, deterministic, self.model_mode) # The API of the Decoder is now a tuple, providing both the main output # and the raw hidden state needed for auxiliary tasks. @@ -828,7 +833,7 @@ def _apply_gemma3_scanned_blocks( decoder_segment_ids, decoder_positions, deterministic, - model_mode, + self.model_mode, ) y, _ = self.scan_decoder_layers( cfg, @@ -837,7 +842,7 @@ def _apply_gemma3_scanned_blocks( "layers", mesh, in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, + model_mode=self.model_mode, **layer_kwargs, )(y, *broadcast_args, **layer_call_kwargs) @@ -854,7 +859,7 @@ def _apply_gemma3_scanned_blocks( decoder_segment_ids, decoder_positions, deterministic, - model_mode, + self.model_mode, previous_chunk=previous_chunk, page_state=page_state, slot=slot, diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py index 2d0eb487a..d5c88f4a1 100644 --- a/MaxText/layers/llama2.py +++ b/MaxText/layers/llama2.py @@ -21,18 +21,20 @@ import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name from jax.sharding import Mesh -# from jax.experimental.pallas.ops.tpu import flash_attention from flax import linen as nn +from flax import nnx from MaxText.inference import page_manager from MaxText.common_types import Config -from MaxText.layers.linears import mlp_block +from MaxText.layers.linears import MlpBlock +from MaxText.layers import initializers +from MaxText.layers import nnx_wrappers from MaxText.layers import quantizations -from MaxText.layers.attentions import attention_as_linen +from MaxText.layers.attentions import Attention from MaxText.layers.quantizations import AqtQuantization as Quant -from MaxText.layers.normalizations import rms_norm -from MaxText.common_types import MODEL_MODE_PREFILL +from MaxText.layers.normalizations import RMSNorm +from MaxText.common_types import MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE # ----------------------------------------- @@ -40,15 +42,95 @@ # ----------------------------------------- -class LlamaDecoderLayer(nn.Module): +class LlamaDecoderLayer(nnx.Module): """Transformer decoder layer that attends to the encoder.""" - config: Config - mesh: Mesh - model_mode: str - quant: Optional[Quant] = None + def __init__( + self, + config: Config, + model_mode: str, + mesh: Mesh, + rngs: nnx.Rngs, + quant: Optional[Quant] = None, + ): + + self.config = config + self.mesh = mesh + self.quant = quant + + batch_size = config.micro_batch_size_to_train_on + + if model_mode == MODEL_MODE_PREFILL: + seq_len = config.max_prefill_predict_length + elif model_mode == MODEL_MODE_AUTOREGRESSIVE: + seq_len = 1 + else: + seq_len = config.max_target_length + + dummy_inputs_shape = (batch_size, seq_len, config.emb_dim) + + self.pre_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + self.self_attention = Attention( + config=config, + num_query_heads=config.num_query_heads, + num_kv_heads=config.num_kv_heads, + head_dim=config.head_dim, + max_target_length=config.max_target_length, + max_prefill_predict_length=config.max_prefill_predict_length, + attention_kernel=config.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + dropout_rate=config.dropout_rate, + float32_qk_product=config.float32_qk_product, + float32_logits=config.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(config), + prefill_cache_axis_order=tuple(map(int, config.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, config.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, config.compute_axis_order.split(","))), + reshape_q=config.reshape_q, + use_ragged_attention=config.use_ragged_attention, + ragged_block_size=config.ragged_block_size, + model_mode=model_mode, + rngs=rngs, + ) + + self.post_self_attention_layer_norm = RMSNorm( + num_features=config.emb_dim, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + kernel_axes=("norm",), + epsilon=config.normalization_layer_epsilon, + rngs=rngs, + ) + + self.mlp = MlpBlock( + in_features=config.emb_dim, + intermediate_dim=config.mlp_dim, + activations=config.mlp_activations, + intermediate_dropout_rate=config.dropout_rate, + dtype=config.dtype, + weight_dtype=config.weight_dtype, + config=config, + quant=self.quant, + model_mode=model_mode, + rngs=rngs, + ) + + self.dropout = nnx.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=rngs) + - @nn.compact def __call__( self, inputs, @@ -61,7 +143,6 @@ def __call__( previous_chunk=None, ): cfg = self.config - mesh = self.mesh if model_mode == MODEL_MODE_PREFILL: activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed") @@ -70,48 +151,12 @@ def __call__( inputs = nn.with_logical_constraint(inputs, activation_axis_names) inputs = checkpoint_name(inputs, "decoder_layer_input") - lnx_rms = rms_norm( - num_features=inputs.shape[-1], - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="pre_self_attention_layer_norm", - kernel_axes=("norm",), - epsilon=cfg.normalization_layer_epsilon, - ) - lnx = lnx_rms(inputs) + lnx = self.pre_self_attention_layer_norm(inputs) lnx = nn.with_logical_constraint(lnx, activation_axis_names) # Self-attention block - attention_layer = attention_as_linen( - config=cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - inputs_q_shape=lnx.shape, - inputs_kv_shape=lnx.shape, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name="self_attention", - float32_qk_product=cfg.float32_qk_product, - float32_logits=cfg.float32_logits, - quant=self.quant, - kv_quant=quantizations.configure_kv_quant(cfg), - prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), - ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), - compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), - reshape_q=cfg.reshape_q, - use_ragged_attention=cfg.use_ragged_attention, - ragged_block_size=cfg.ragged_block_size, - model_mode=model_mode, - ) - - attention_lnx = attention_layer( + attention_lnx = self.self_attention( lnx, lnx, decoder_positions, @@ -127,35 +172,15 @@ def __call__( intermediate_inputs = inputs + attention_lnx # Fully Connected - hidden_states = rms_norm( - num_features=intermediate_inputs.shape[-1], - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="post_self_attention_layer_norm", - kernel_axes=("norm",), - epsilon=cfg.normalization_layer_epsilon, - )(intermediate_inputs) + hidden_states = self.post_self_attention_layer_norm(intermediate_inputs) hidden_states = nn.with_logical_constraint(hidden_states, activation_axis_names) # MLP block. - mlp_lnx = mlp_block( - in_features=hidden_states.shape[-1], - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="mlp", - config=cfg, - quant=self.quant, - model_mode=model_mode, - )(hidden_states, deterministic=deterministic) + mlp_lnx = self.mlp(hidden_states, deterministic=deterministic) mlp_lnx = nn.with_logical_constraint(mlp_lnx, activation_axis_names) layer_output = mlp_lnx + intermediate_inputs - - layer_output = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(layer_output, deterministic=deterministic) - + layer_output = self.dropout(layer_output, deterministic=deterministic) layer_output = nn.with_logical_constraint(layer_output, activation_axis_names) if cfg.record_internal_nn_metrics: @@ -171,3 +196,9 @@ def __call__( return layer_output, None else: return layer_output + + +LlamaDecoderLayerToLinen = nnx_wrappers.to_linen_class( + LlamaDecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 6482ff9bf..82ca5da04 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -54,6 +54,15 @@ class Transformer(nn.Module): model_mode: str = MODEL_MODE_TRAIN # May be different than the model_mode passed to __call__ # pylint: enable=attribute-defined-outside-init + def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs): + """Initializes the model.""" + module = self.clone(model_mode=model_mode) + return nn.Module.init(module, *args, **kwargs) + + def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs): + """Initializes the model.""" + module = self.clone(model_mode=model_mode) + return nn.Module.apply(module, *args, **kwargs) def setup(self): """Initialize shared_embedding & decoder layers.""" @@ -91,7 +100,6 @@ def __call__( decoder_segment_ids=None, encoder_images: Optional[jnp.ndarray] = None, enable_dropout=True, - model_mode=MODEL_MODE_TRAIN, previous_chunk=None, true_length: Optional[int] = None, slot: Optional[int] = None, @@ -107,7 +115,7 @@ def __call__( for this request. """ - if decoder_segment_ids is not None and model_mode == MODEL_MODE_AUTOREGRESSIVE: + if decoder_segment_ids is not None and self.model_mode == MODEL_MODE_AUTOREGRESSIVE: raise ValueError( f"During autoregressive decoding we assume the tokens are in the active sequence" f" which is always {DECODING_ACTIVE_SEQUENCE_INDICATOR}." @@ -128,7 +136,7 @@ def __call__( decoder_positions=decoder_positions, decoder_segment_ids=decoder_segment_ids, deterministic=not enable_dropout, - model_mode=model_mode, + model_mode=self.model_mode, previous_chunk=previous_chunk, slot=slot, page_state=page_state, @@ -164,7 +172,7 @@ def __call__( position_ids=decoder_positions, decoder_segment_ids=decoder_segment_ids, deterministic=not enable_dropout, - model_mode=model_mode, + model_mode=self.model_mode, ) return logits @@ -205,7 +213,6 @@ def __call__( decoder_segment_ids=None, encoder_images: Optional[jnp.ndarray] = None, enable_dropout=True, - model_mode=MODEL_MODE_TRAIN, previous_chunk=None, true_length: Optional[int] = None, slot: Optional[int] = None, @@ -221,7 +228,7 @@ def __call__( decoder_segment_ids, encoder_images, enable_dropout, - model_mode, + self.model_mode, previous_chunk, true_length, slot, @@ -238,7 +245,7 @@ def __call__( decoder_segment_ids=decoder_segment_ids, encoder_images=encoder_images, enable_dropout=enable_dropout, - model_mode=model_mode, + model_mode=self.model_mode, previous_chunk=previous_chunk, true_length=true_length, slot=slot, diff --git a/MaxText/layers/nnx_wrappers.py b/MaxText/layers/nnx_wrappers.py index 76e38738a..d158a76ee 100644 --- a/MaxText/layers/nnx_wrappers.py +++ b/MaxText/layers/nnx_wrappers.py @@ -530,68 +530,70 @@ def to_linen_class( base_skip_rng: bool = False, **partial_kwargs: tp.Any, ) -> type[ToLinen]: - """Dynamically wraps an NNX module class into a Flax Linen module class.""" + """A dynamically created Linen Module that wraps a specific NNX Module. + + This class is not meant to be used directly. Instead, it is created and + returned by the `to_linen_class` function. It acts as a "partially applied" + version of the `ToLinen` wrapper, where the NNX module to be wrapped and + its default arguments are pre-configured. + + When you instantiate this class, it behaves like a standard Linen module. + The arguments you provide during instantiation can override the defaults + that were set when this class was created by `to_linen_class`. + + For example: + >>> from flax import linen as nn, nnx + >>> from MaxText.layers import linears + >>> # Create a specialized Linen wrapper for linears.DenseGeneral + >>> LinenDenseGeneral = to_linen_class(linears.DenseGeneral) + >>> # Now, LinenDenseGeneral can be used like a regular Linen module + >>> class MyModel(nn.Module): + ... def setup(self): + ... # Instantiate the wrapped linears.DenseGeneral with its arguments + ... self.dense = LinenDenseGeneral( + ... in_features_shape=10, out_features_shape=5 + ... ) + ... def __call__(self, x): + ... return self.dense(x) + + Attributes: + (The attributes are dynamically set by the `ToLinen` parent class based + on the arguments provided during instantiation.) + """ + + def __init__( + self, + args=None, + kwargs=None, + nnx_class=None, + skip_rng=None, + metadata_fn=None, + name=_MISSING, + parent=_MISSING, + **other_kwargs, + ): + linen_kwargs = {} + if not isinstance(parent, _Missing): + linen_kwargs["parent"] = parent + if not isinstance(name, _Missing): + linen_kwargs["name"] = name + ToLinen.__init__( + self, + nnx_class=nnx_class or base_nnx_class, + args=args or (), + metadata_fn=metadata_fn or base_metadata_fn, + skip_rng=skip_rng or base_skip_rng, + kwargs=FrozenDict({**partial_kwargs, **(kwargs or {}), **other_kwargs}), + **linen_kwargs, + ) class ToLinenPartial(ToLinen): - """A dynamically created Linen Module that wraps a specific NNX Module. - - This class is not meant to be used directly. Instead, it is created and - returned by the `to_linen_class` function. It acts as a "partially applied" - version of the `ToLinen` wrapper, where the NNX module to be wrapped and - its default arguments are pre-configured. - - When you instantiate this class, it behaves like a standard Linen module. - The arguments you provide during instantiation can override the defaults - that were set when this class was created by `to_linen_class`. - - For example: - >>> from flax import linen as nn, nnx - >>> from MaxText.layers import linears - >>> # Create a specialized Linen wrapper for linears.DenseGeneral - >>> LinenDenseGeneral = to_linen_class(linears.DenseGeneral) - >>> # Now, LinenDenseGeneral can be used like a regular Linen module - >>> class MyModel(nn.Module): - ... def setup(self): - ... # Instantiate the wrapped linears.DenseGeneral with its arguments - ... self.dense = LinenDenseGeneral( - ... in_features_shape=10, out_features_shape=5 - ... ) - ... def __call__(self, x): - ... return self.dense(x) - - Attributes: - (The attributes are dynamically set by the `ToLinen` parent class based - on the arguments provided during instantiation.) - """ + """A dynamically created Linen Module that wraps a specific NNX Module.""" def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - - def __init__(self, - args=None, - kwargs=None, - nnx_class=None, - skip_rng=None, - metadata_fn=None, - name=_MISSING, - parent=_MISSING, - **other_kwargs, - ): - linen_kwargs = {} - if not isinstance(parent, _Missing): - linen_kwargs["parent"] = parent - if not isinstance(name, _Missing): - linen_kwargs["name"] = name - ToLinen.__init__( - self, - nnx_class=nnx_class or base_nnx_class, - args=args or (), - metadata_fn=metadata_fn or base_metadata_fn, - skip_rng=skip_rng or base_skip_rng, - kwargs=FrozenDict({**partial_kwargs, **(kwargs or {}), **other_kwargs}), - **linen_kwargs, - ) - cls.__init__ = __init__ + ToLinenPartial.__init__ = __init__ + return ToLinenPartial diff --git a/MaxText/layers/qwen3.py b/MaxText/layers/qwen3.py index 292e0ede0..f32bbb6fc 100644 --- a/MaxText/layers/qwen3.py +++ b/MaxText/layers/qwen3.py @@ -1,4 +1,4 @@ -# Copyright 2023–2025 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,112 +22,166 @@ from jax.sharding import Mesh import jax.numpy as jnp -from flax import linen as nn +from flax import nnx -from MaxText.common_types import Config +from MaxText.common_types import Config, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE from MaxText.layers import attentions from MaxText.layers import initializers from MaxText.layers import linears from MaxText.layers import moe +from MaxText.layers import nnx_wrappers from MaxText.layers import quantizations -from MaxText.layers.normalizations import rms_norm +from MaxText.layers.normalizations import RMSNorm from MaxText.layers.quantizations import AqtQuantization as Quant from MaxText.inference import page_manager # ----------------------------------------- -# Helper functions for Qwen3 layers +# Self-Attention Block for Qwen3 # ----------------------------------------- -def self_attention_with_norm( - inputs: jnp.ndarray, - cfg: Config, - mesh: Mesh, - quant: Optional[Quant], - decoder_segment_ids: Optional[jnp.ndarray], - decoder_positions: Optional[jnp.ndarray], - deterministic: bool, - model_mode: str, -): - """A helper function for self-attention block with normalization.""" - - inputs_checkpoint = checkpoint_name(inputs, "decoder_layer_input") - - # Corresponds to Qwen3's `input_layernorm` - lnx = rms_norm( - num_features=inputs.shape[-1], - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="pre_self_attention_layer_norm", - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=("norm",), - )(inputs_checkpoint) - lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) - - # Self-attention block - attention_layer = attentions.attention_as_linen( - config=cfg, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - inputs_q_shape=lnx.shape, - inputs_kv_shape=lnx.shape, - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - name="self_attention", - quant=quant, - kv_quant=quantizations.configure_kv_quant(cfg), - use_qk_norm=cfg.use_qk_norm, - query_pre_attn_scalar=(cfg.head_dim**-0.5), # Qwen3 specific scaling - model_mode=model_mode, - ) - - attention_output = attention_layer( - lnx, # inputs_q - lnx, # inputs_kv - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode, - ) - attention_output = nn.with_logical_constraint( - attention_output, ("activation_batch", "activation_length", "activation_embed") - ) - - # Residual connection after attention - residual_after_attention = inputs_checkpoint + attention_output - - # Post Attention LayerNorm (corresponds to Qwen3's `post_attention_layernorm`) - hidden_states = rms_norm( - num_features=residual_after_attention.shape[-1], - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="post_self_attention_layer_norm", - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=("norm",), - )(residual_after_attention) - hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed")) - - return hidden_states, residual_after_attention +class Qwen3SelfAttentionWithNorm(nnx.Module): + """A self-attention block with pre and post normalization.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + quant: Optional[Quant], + model_mode: str, + rngs: nnx.Rngs, + ): + self.config = config + self.mesh = mesh + self.quant = quant + cfg = self.config + + if model_mode == MODEL_MODE_PREFILL: + seq_len = cfg.max_prefill_predict_length + elif model_mode == MODEL_MODE_AUTOREGRESSIVE: + seq_len = 1 + else: + seq_len = cfg.max_target_length + + dummy_inputs_shape = (cfg.micro_batch_size_to_train_on, seq_len, cfg.emb_dim) + + self.pre_self_attention_layer_norm = RMSNorm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="pre_self_attention_layer_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=rngs, + ) + self.self_attention = attentions.Attention( + config=cfg, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=dummy_inputs_shape, + inputs_kv_shape=dummy_inputs_shape, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=quant, + kv_quant=quantizations.configure_kv_quant(cfg), + use_qk_norm=cfg.use_qk_norm, + query_pre_attn_scalar=(cfg.head_dim**-0.5), # Qwen3 specific scaling + model_mode=model_mode, + rngs=rngs, + ) + self.post_self_attention_layer_norm = RMSNorm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="post_self_attention_layer_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=rngs, + ) + + def __call__( + self, + inputs: jnp.ndarray, + decoder_segment_ids: Optional[jnp.ndarray], + decoder_positions: Optional[jnp.ndarray], + deterministic: bool, + model_mode: str, + activation_axis_names: tuple[str, ...], + ): + """Helper function for self-attention block with normalization.""" + inputs_checkpoint = checkpoint_name(inputs, "decoder_layer_input") + + lnx = self.pre_self_attention_layer_norm(inputs_checkpoint) + lnx = nnx.with_logical_constraint(lnx, activation_axis_names) + + attention_output = self.self_attention( + lnx, # inputs_q + lnx, # inputs_kv + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) + attention_output = nnx.with_logical_constraint(attention_output, activation_axis_names) + + residual_after_attention = inputs_checkpoint + attention_output + + hidden_states = self.post_self_attention_layer_norm(residual_after_attention) + hidden_states = nnx.with_logical_constraint(hidden_states, activation_axis_names) + + return hidden_states, residual_after_attention # ----------------------------------------- # The Dense Decoder Layer for Qwen3 # ----------------------------------------- -class Qwen3DecoderLayer(nn.Module): + + +class Qwen3DecoderLayer(nnx.Module): """Qwen3 Transformer decoder layer (dense).""" - config: Config - mesh: Mesh - model_mode: str - quant: Optional[Quant] = None + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: Optional[Quant] = None, + ): + self.config = config + self.mesh = mesh + self.quant = quant + cfg = self.config + + self.attention = Qwen3SelfAttentionWithNorm( + config=cfg, + mesh=self.mesh, + quant=self.quant, + model_mode=model_mode, + rngs=rngs, + ) + + self.mlp = linears.MlpBlock( + in_features=cfg.emb_dim, + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="mlp", + config=cfg, + quant=self.quant, + model_mode=model_mode, + rngs=rngs, + ) - @nn.compact def __call__( self, inputs: jnp.ndarray, @@ -141,55 +195,74 @@ def __call__( ): cfg = self.config - hidden_states, residual_after_attention = self_attention_with_norm( + if model_mode == MODEL_MODE_PREFILL: + activation_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") + else: + activation_axis_names = ("activation_batch", "activation_length", "activation_embed") + + hidden_states, residual_after_attention = self.attention( inputs, - cfg, - self.mesh, - self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode, + activation_axis_names, ) - # Dense MLP block - mlp_output = linears.mlp_block( - in_features=hidden_states.shape[-1], - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="mlp", - config=cfg, - quant=self.quant, - )(hidden_states, deterministic=deterministic) + mlp_output = self.mlp(hidden_states, deterministic=deterministic) - # Final residual connection layer_output = residual_after_attention + mlp_output - layer_output = nn.with_logical_constraint( - layer_output, - ("activation_batch", "activation_length", "activation_embed"), - ) + layer_output = nnx.with_logical_constraint(layer_output, activation_axis_names) if cfg.scan_layers: return layer_output, None - else: - return layer_output + return layer_output # ----------------------------------------- # The MoE Decoder Layer for Qwen3 # ----------------------------------------- -class Qwen3MoeDecoderLayer(nn.Module): + + +class Qwen3MoeDecoderLayer(nnx.Module): """Qwen3 Transformer decoder layer (MoE).""" - config: Config - mesh: Mesh - model_mode: str - quant: Optional[Quant] = None + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + rngs: nnx.Rngs, + quant: Optional[Quant] = None, + ): + self.config = config + self.mesh = mesh + self.quant = quant + cfg = self.config + + self.attention = Qwen3SelfAttentionWithNorm( + config=cfg, + mesh=self.mesh, + quant=self.quant, + model_mode=model_mode, + rngs=rngs, + ) + + self.moe_block = moe.MoeBlock( + config=cfg, + num_experts=cfg.num_experts, + num_experts_per_tok=cfg.num_experts_per_tok, + mesh=self.mesh, + kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), + kernel_axes=("embed", None), + intermediate_dim=cfg.moe_mlp_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="moe_block", + quant=self.quant, + rngs=rngs, + ) - @nn.compact def __call__( self, inputs: jnp.ndarray, @@ -203,45 +276,42 @@ def __call__( ): cfg = self.config - hidden_states, residual_after_attention = self_attention_with_norm( + if model_mode == MODEL_MODE_PREFILL: + activation_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") + else: + activation_axis_names = ("activation_batch", "activation_length", "activation_embed") + + hidden_states, residual_after_attention = self.attention( inputs, - cfg, - self.mesh, - self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode, + activation_axis_names, ) - # Mixture of Experts block - mlp_output, load_balance_loss = moe.get_routed_moe( - config=cfg, - num_experts=cfg.num_experts, - num_experts_per_tok=cfg.num_experts_per_tok, - mesh=self.mesh, - kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"), - kernel_axes=("embed", None), - intermediate_dim=cfg.moe_mlp_dim, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="moe_block", - quant=self.quant, - )(hidden_states) + mlp_output, load_balance_loss = self.moe_block(hidden_states) if load_balance_loss is not None: self.sow("intermediates", "moe_lb_loss", load_balance_loss) - mlp_output = nn.with_logical_constraint(mlp_output, ("activation_batch", "activation_length", "activation_embed")) + mlp_output = nnx.with_logical_constraint(mlp_output, activation_axis_names) # Final residual connection layer_output = residual_after_attention + mlp_output - layer_output = nn.with_logical_constraint( - layer_output, - ("activation_batch", "activation_length", "activation_embed"), - ) + layer_output = nnx.with_logical_constraint(layer_output, activation_axis_names) if cfg.scan_layers: return layer_output, None - else: - return layer_output + return layer_output + + +# Linen wrappers for backward compatibility +Qwen3DecoderLayerToLinen = nnx_wrappers.to_linen_class( + Qwen3DecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) +Qwen3MoeDecoderLayerToLinen = nnx_wrappers.to_linen_class( + Qwen3MoeDecoderLayer, + base_metadata_fn=initializers.variable_to_logically_partitioned, +) diff --git a/MaxText/maxtext_utils.py b/MaxText/maxtext_utils.py index be2091cf8..29f016116 100644 --- a/MaxText/maxtext_utils.py +++ b/MaxText/maxtext_utils.py @@ -1068,10 +1068,7 @@ def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: Optio """Get a shaped abstraction of the state (including optimizer)""" def init_kv_cache(model, config): - input_shape = ( - config.global_batch_size_to_load, - config.max_prefill_predict_length, - ) + input_shape = (config.micro_batch_size_to_train_on, config.max_prefill_predict_length) image_shape = get_dummy_image_shape_for_init(config) model_vars = model.init( @@ -1098,10 +1095,7 @@ def get_kv_cache_annotations(model, config, rng, mesh, page_state: Optional[Page """Get a shaped abstraction of the state (including optimizer)""" def init_kv_cache(model, config): - input_shape = ( - config.global_batch_size_to_load, - 1, - ) + input_shape = (config.micro_batch_size_to_train_on, 1) image_shape = get_dummy_image_shape_for_init(config) model_vars = model.init(