diff --git a/MaxText/layers/decoders.py b/MaxText/layers/decoders.py index 3f12461eb..dd8fd5785 100644 --- a/MaxText/layers/decoders.py +++ b/MaxText/layers/decoders.py @@ -16,7 +16,7 @@ # pylint: disable=arguments-differ # pylint: disable=no-name-in-module -from typing import Any +from typing import Any, Callable, Union, Type import functools import jax @@ -25,20 +25,23 @@ from jax.sharding import Mesh from flax import linen as nn -from flax.linen.partitioning import ScanIn +from flax.core.spmd import logical_axis_rules as nn_logical_axis_rules +from flax import nnx +import numpy as np -from MaxText.common_types import DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from MaxText.common_types import Array, DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE from MaxText import max_logging from MaxText import max_utils from MaxText.inference import page_manager +from MaxText.layers import initializers from MaxText.layers import linears from MaxText.layers import quantizations from MaxText.layers import pipeline from MaxText import maxtext_utils from MaxText import multimodal_utils -from MaxText.layers.attentions import attention_as_linen -from MaxText.layers.normalizations import rms_norm -from MaxText.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen +from MaxText.layers.attentions import Attention +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.embeddings import attend_on_embedding, Embed, PositionalEmbedding,_MAX_WAVELENGTH from MaxText.layers.quantizations import AqtQuantization as Quant from MaxText.layers import ( deepseek, @@ -59,84 +62,102 @@ # ------------------------------------------------------------------------------ -class DecoderLayer(nn.Module): +class DecoderLayer(nnx.Module): """ Transformer decoder layer that attends to the encoder. This is the core, reusable building block for both the main model's decoder stack and the auxiliary MTP layers. 
""" - config: Config - mesh: Mesh - model_mode: str - quant: None | Quant = None + def __init__(self, config: Config, mesh: Mesh, quant: Quant | None = None, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs | None = None): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs if rngs is not None else nnx.Rngs(0) + inputs_shape = ( + int(self.config.per_device_batch_size), + int(self.config.max_target_length), + int(self.config.emb_dim), + ) - @nn.compact - def __call__( - self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=None, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - ): - cfg = self.config - mesh = self.mesh - if model_mode == MODEL_MODE_PREFILL: - logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") - else: - logical_axis_names = ("activation_batch", "activation_length", "activation_embed") + self.mlp = linears.MlpBlock( + in_features=inputs_shape[-1], + intermediate_dim=self.config.mlp_dim, + activations=self.config.mlp_activations, + intermediate_dropout_rate=self.config.dropout_rate, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + config=self.config, + quant=quant, + model_mode=model_mode, + rngs=self.rngs + ) - if model_mode == MODEL_MODE_PREFILL: - inputs = nn.with_logical_constraint(inputs, logical_axis_names) - else: - inputs = nn.with_logical_constraint(inputs, logical_axis_names) + self.drop_out = nnx.Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,),rngs=self.rngs) - inputs = checkpoint_name(inputs, "decoder_layer_input") - # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] - lnx = rms_norm( - num_features=inputs.shape[-1], - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="pre_self_attention_norm", - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=("norm",), - )(inputs) - if model_mode == MODEL_MODE_PREFILL: - lnx = nn.with_logical_constraint(lnx, logical_axis_names) - else: - lnx = nn.with_logical_constraint(lnx, logical_axis_names) + self.pre_self_attention_norm = RMSNorm( + num_features=inputs_shape[-1], + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + kernel_axes=("norm", ), + epsilon=self.config.normalization_layer_epsilon, + rngs=self.rngs + ) - attention_layer = attention_as_linen( + self.self_attention = Attention( config=self.config, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - inputs_q_shape=lnx.shape, - inputs_kv_shape=lnx.shape, + num_query_heads=self.config.num_query_heads, + num_kv_heads=self.config.num_kv_heads, + head_dim=self.config.head_dim, + max_target_length=self.config.max_target_length, + max_prefill_predict_length=self.config.max_prefill_predict_length, + attention_kernel=self.config.attention, + inputs_q_shape=inputs_shape, + inputs_kv_shape=inputs_shape, mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, + dtype=self.config.dtype, + weight_dtype=self.config.weight_dtype, + dropout_rate=self.config.dropout_rate, name="self_attention", - float32_qk_product=cfg.float32_qk_product, - float32_logits=cfg.float32_logits, + float32_qk_product=self.config.float32_qk_product, + float32_logits=self.config.float32_logits, quant=self.quant, - 
kv_quant=quantizations.configure_kv_quant(cfg), - prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), - ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), - compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), - reshape_q=cfg.reshape_q, + kv_quant=quantizations.configure_kv_quant(self.config), + prefill_cache_axis_order=tuple(map(int, self.config.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, self.config.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, self.config.compute_axis_order.split(","))), + reshape_q=self.config.reshape_q, model_mode=model_mode, + rngs=self.rngs ) - attention_lnx = attention_layer( + def __call__( + self, + inputs : Array, + decoder_segment_ids : Array, + decoder_positions : Array, + deterministic : bool, + model_mode : str , + previous_chunk: Array | None = None, + slot: int | None = None, + page_state: page_manager.PageState | None = None, + ): + cfg = self.config + logical_axis_names = ( + ("activation_batch", "prefill_activation_length", "activation_embed") + if model_mode == MODEL_MODE_PREFILL + else ("activation_batch", "activation_length", "activation_embed") + ) + + inputs = nn.with_logical_constraint(inputs, logical_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] + lnx = self.pre_self_attention_norm(inputs) + lnx = nn.with_logical_constraint(lnx, logical_axis_names) + + attention_lnx = self.self_attention( lnx, lnx, decoder_positions, @@ -144,53 +165,29 @@ def __call__( deterministic=deterministic, model_mode=model_mode, ) - - if model_mode == MODEL_MODE_PREFILL: - attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names) - else: - attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names) + attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names) # MLP block. 
- mlp_lnx = linears.mlp_block( - in_features=lnx.shape[-1], - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="mlp", - model_mode=model_mode, - config=cfg, - quant=self.quant, - )(lnx, deterministic=deterministic) - if model_mode == MODEL_MODE_PREFILL: - mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names) - else: - mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names) + mlp_lnx = self.mlp(lnx, deterministic=deterministic) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names) next_layer_addition = mlp_lnx + attention_lnx - next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + next_layer_addition_dropped_out = self.drop_out( next_layer_addition, deterministic=deterministic ) layer_output = next_layer_addition_dropped_out + inputs - if model_mode == MODEL_MODE_PREFILL: - layer_output = nn.with_logical_constraint( - layer_output, - logical_axis_names, - ) - else: - layer_output = nn.with_logical_constraint( - layer_output, - logical_axis_names, - ) + layer_output = nn.with_logical_constraint( + layer_output, + logical_axis_names, + ) if cfg.record_internal_nn_metrics: - self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) - self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) + self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output)) + self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output)) self.sow( - "intermediates", + nnx.Intermediate, "activation_fraction_zero", jnp.sum(layer_output == 0) / jnp.size(layer_output), ) @@ -198,27 +195,28 @@ def __call__( return layer_output, None if cfg.scan_layers else layer_output -class SequentialBlockDecoderLayers(nn.Module): +class SequentialBlockDecoderLayers(nnx.Module): """Sequential unscanned series of decoder layers.""" - decoder_layer: Any - num_decoder_layers: int - config: Config - mesh: Mesh - quant: Quant - model_mode: str + def __init__(self,decoder_layer:Any,num_decoder_layers:int, config: Config, mesh: Mesh, quant: Quant|None=None, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rngs | None = None): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs if rngs is not None else nnx.Rngs(0) + self.num_decoder_layers = num_decoder_layers + self.decoder_layer = decoder_layer - @nn.compact def __call__( self, - inputs: jnp.ndarray, - decoder_segment_ids, - decoder_positions, + inputs: Array, + decoder_segment_ids: Array, + decoder_positions : Array, deterministic: bool, - model_mode, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - ) -> jnp.ndarray: + model_mode : str, + slot: int | None = None, + page_state: page_manager.PageState | None = None, + ) -> Union[Array, tuple[Array, None]]: for lyr in range(self.num_decoder_layers): inputs = self.decoder_layer( config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode @@ -235,162 +233,208 @@ def __call__( inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None). if self.config.scan_layers: return inputs, None # pytype: disable=bad-return-type - else: - return inputs + return inputs +class Decoder(nnx.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture. 
+ """ -class Decoder(nn.Module): - """A stack of decoder layers as a part of an encoder-decoder architecture.""" + BROADCAST_ARGS_LENGTH : int = 4 - config: Config - shared_embedding: nn.Module - mesh: Mesh - quant: None | Quant = None - model_mode: str = MODEL_MODE_TRAIN + def __init__( + self, + config: Config, + shared_embedding: nn.Module, + mesh: Mesh, + quant: Quant | None=None, + model_mode: str = MODEL_MODE_TRAIN, + rngs : nnx.Rngs | None = None, + ): + self.config = config + self.shared_embedding = shared_embedding + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs if rngs is not None else nnx.Rngs(0) + + layer_classes = self.get_decoder_layers() + policy = self.get_remat_policy() + self.rematted_layer_classes = self.set_remat_policy(layer_classes, policy) + + self.embedding_dropout = nnx.Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,)) + self.output_head_dropout = nnx.Dropout(rate=self.config.dropout_rate, broadcast_dims=(-2,)) + + self.decoder_norm = self.get_norm_layer(num_features=self.config.emb_dim)( + dtype=config.dtype, + weight_dtype=config.weight_dtype, + epsilon=config.normalization_layer_epsilon, + kernel_axes=("norm",), + parameter_memory_host_offload=config.parameter_memory_host_offload, + rngs=self.rngs, + ) - def setup(self): - """Initialize decoder layer.""" - self.decoder_layer = self.get_decoder_layers() - self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim) - if self.config.using_pipeline_parallelism: - pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) + self._pipeline_module: pipeline.Pipeline | None = None + if config.using_pipeline_parallelism: + pipeline_stage_module = self.get_pipeline_stage_module(layer_classes) remat_policy = self.get_remat_policy() - self.pipeline_module = pipeline.Pipeline( - config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy + self._pipeline_module = pipeline.Pipeline( + config=config, + mesh=self.mesh, + layers=pipeline_stage_module, + remat_policy=remat_policy, ) + if config.decoder_block == DecoderBlockType.DEEPSEEK: + self._build_exec_deepseek_pipeline() + else: + self._build_exec_standard_pipeline() + else: + if config.scan_layers: + self._build_exec_scanned() + else: + self._build_exec_unscanned() - def get_remat_policy(self): - """Get remat policy""" - policy = None - cfg = self.config - if cfg.remat_policy != "none": - if cfg.remat_policy == "minimal": - policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - elif cfg.remat_policy == "save_dot_with_context_except_mlp": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "context", - "out_proj", - ) - elif cfg.remat_policy == "save_dot_except_mlpwi": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - "mlpwo", - ) - elif cfg.remat_policy == "save_dot_except_mlp": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", + sequence_length = int(config.max_target_length) + if model_mode == MODEL_MODE_PREFILL: + sequence_length = int(config.max_prefill_predict_length) + elif model_mode == MODEL_MODE_AUTOREGRESSIVE: + sequence_length = 1 + + inputs_shape = ( + int(config.micro_batch_size_to_train_on), + sequence_length, + int(self.config.emb_dim), + ) + + self.logits_dense = linears.dense_general( + 
inputs_shape=inputs_shape, + out_features_shape=self.config.vocab_size, + weight_dtype=self.config.weight_dtype, + dtype=jnp.float32 if self.config.logits_dot_in_fp32 else self.config.dtype, + kernel_axes=("embed", "vocab"), + name="logits_dense", + matmul_precision=self.config.matmul_precision, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + ) + + # Untrainable (static) positional embedding + self._static_pos_embedding = None + if self.config.use_untrainable_positional_embedding: + self._static_pos_embedding = PositionalEmbedding( + embedding_dims=self.config.base_emb_dim, + max_wavelength=_MAX_WAVELENGTH, ) - elif cfg.remat_policy == "save_qkv_proj": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", + + # Trainable position embedding + self.position_embedder = None + if self.config.trainable_position_size > 0: + self.position_embedder = Embed( + num_embeddings=self.config.trainable_position_size, + num_features=self.config.emb_dim, + dtype=self.config.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + config=self.config, + rngs=self.rngs, ) - elif cfg.remat_policy == "qkv_proj_offloaded": - policy = jax.checkpoint_policies.save_and_offload_only_these_names( + + def get_remat_policy(self)-> Callable[..., bool]|None: + cfg = self.config + policy_name = cfg.remat_policy + + if policy_name == "none" or policy_name == "full": + return None + + static_policies : dict[str,Any] = { + "minimal": jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + "save_dot_with_context_except_mlp": jax.checkpoint_policies.save_only_these_names( + "query_proj", "value_proj", "key_proj", "qkv_proj", "context", "out_proj" + ), + "save_dot_except_mlpwi": jax.checkpoint_policies.save_only_these_names( + "query_proj", "value_proj", "key_proj", "qkv_proj", "out_proj", "mlpwo" + ), + "save_dot_except_mlp": jax.checkpoint_policies.save_only_these_names( + "query_proj", "value_proj", "key_proj", "qkv_proj", "out_proj" + ), + "save_qkv_proj": jax.checkpoint_policies.save_only_these_names( + "query_proj", "value_proj", "key_proj", "qkv_proj" + ), + "save_out_proj": jax.checkpoint_policies.save_only_these_names("out_proj"), + "minimal_flash": jax.checkpoint_policies.save_from_both_policies( + jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + jax.checkpoint_policies.save_only_these_names("context") + ), + } + + dynamic_policies : dict[str,Callable] = { + "qkv_proj_offloaded": lambda: jax.checkpoint_policies.save_and_offload_only_these_names( names_which_can_be_saved=[], names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], offload_src="device", - offload_dst="pinned_host", - ) - elif cfg.remat_policy == "minimal_offloaded": - policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host") - elif cfg.remat_policy == "custom": - policy = jax.checkpoint_policies.save_and_offload_only_these_names( + offload_dst="pinned_host" + ), + "minimal_offloaded": lambda: jax.checkpoint_policies.offload_dot_with_no_batch_dims( + offload_src="device", + offload_dst="pinned_host" + ), + "custom": lambda: jax.checkpoint_policies.save_and_offload_only_these_names( names_which_can_be_saved=cfg.tensors_on_device, names_which_can_be_offloaded=cfg.tensors_to_offload, offload_src="device", - offload_dst="pinned_host", - ) - elif cfg.remat_policy == "minimal_flash": - policy = jax.checkpoint_policies.save_from_both_policies( - 
jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, - jax.checkpoint_policies.save_only_these_names( - "context", - ), - ) - elif cfg.remat_policy == "save_out_proj": - policy = jax.checkpoint_policies.save_only_these_names( - "out_proj", - ) - else: - assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" - policy = None - return policy - - def get_decoder_layers(self): - """Retrieves a list of decoder layer classes based on the `decoder_block` config. - - Returns: - A list containing one or more `nn.Module` classes for the decoder. - """ - match self.config.decoder_block: - case DecoderBlockType.DEFAULT: - return [DecoderLayer] - case DecoderBlockType.LLAMA2: - return [llama2.LlamaDecoderLayer] - case DecoderBlockType.MISTRAL: - # TODO(ranran): update to Mistral with sliding window attention - return [mistral.MistralDecoderLayer] - case DecoderBlockType.MIXTRAL: - return [mixtral.MixtralDecoderLayer] - case DecoderBlockType.DEEPSEEK: - return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] - case DecoderBlockType.GEMMA: - return [gemma.GemmaDecoderLayer] - case DecoderBlockType.GEMMA2: - return [gemma2.Gemma2DecoderLayer] - case DecoderBlockType.GEMMA3: - return [gemma3.Gemma3DecoderLayer] - case DecoderBlockType.GPT3: - return [gpt3.Gpt3DecoderLayer] - case DecoderBlockType.QWEN3: - return [qwen3.Qwen3DecoderLayer] - case DecoderBlockType.QWEN3_MOE: - return [qwen3.Qwen3MoeDecoderLayer] - case DecoderBlockType.SIMPLE: - return [simple_layer.SimpleDecoderLayer] - case DecoderBlockType.SIMPLE_MLP: - return [simple_layer.SimpleMlpDecoderLayer] - case DecoderBlockType.LLAMA4: - return [llama4.Llama4ScannableBlock] if self.config.scan_layers else [llama4.Llama4DecoderLayer] - case _: - # Default case to handle any unknown decoder block types. 
- raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") - - def set_remat_policy(self, block_layers, policy): - """Set remat policy""" + offload_dst="pinned_host" + ), + } + + if policy_name in static_policies: + return static_policies[policy_name] + elif policy_name in dynamic_policies: + return dynamic_policies[policy_name]() + raise ValueError(f"Remat policy needs to be on list of remat policies, get : '{policy_name}'") + + def get_decoder_layers(self)->list[Type[nnx.Module]]: + # TODO(ranran): update to Mistral with sliding window attention + decoder_layer_map = { + DecoderBlockType.DEFAULT: [DecoderLayer], + DecoderBlockType.LLAMA2: [llama2.LlamaDecoderLayer], + DecoderBlockType.MISTRAL: [mistral.MistralDecoderLayer], + DecoderBlockType.MIXTRAL: [mixtral.MixtralDecoderLayer], + DecoderBlockType.DEEPSEEK: [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer], + DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer], + DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer], + DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer], + DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer], + DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer], + DecoderBlockType.SIMPLE: [simple_layer.SimpleDecoderLayer], + DecoderBlockType.SIMPLE_MLP: [simple_layer.SimpleMlpDecoderLayer], + DecoderBlockType.LLAMA4: ( + [llama4.Llama4ScannableBlock] + if self.config.scan_layers + else [llama4.Llama4DecoderLayer] + ), + } + + decoder_type = self.config.decoder_block + if decoder_type in decoder_layer_map: + return decoder_layer_map[decoder_type] + raise ValueError(f"Incorrect decoder_block name: {decoder_type.value}") + + def set_remat_policy(self, block_layers, policy : Callable[..., bool]|None = None)->list[Type[nnx.Module]]: RemattedBlockLayers = [] for block_layer in block_layers: if self.config.parameter_memory_host_offload: - # Define parameter movement with mesh-based sharding def move_to_device(variables): """Move parameters to device with proper sharding.""" - def map_fn(path, value): max_logging.log(f"models.py: Moving parameter {path} to device") return jax.device_put(value, max_utils.device_space()) return jax.tree_util.tree_map_with_path(map_fn, variables) - + # Transform layer class before remat - block_layer = nn.map_variables(block_layer, ["params"], move_to_device, mutable=True) + graphdef, params = nnx.split(block_layer, nnx.Param) + params = move_to_device(params) + block_layer = nnx.merge(graphdef, params) # Apply remat policy to layer - layer = nn.remat( + layer = nnx.remat( block_layer, prevent_cse=not self.config.scan_layers, policy=policy, @@ -399,9 +443,8 @@ def map_fn(path, value): RemattedBlockLayers.append(layer) return RemattedBlockLayers - def get_norm_layer(self, num_features: int): - """get normalization layer (return type inherits from nn.Module)""" - if self.config.decoder_block in ( + def get_norm_layer(self, num_features: int)-> Callable[...,Any]: + if self.config.decoder_block in { DecoderBlockType.DEFAULT, DecoderBlockType.LLAMA2, DecoderBlockType.MISTRAL, @@ -415,19 +458,271 @@ def get_norm_layer(self, num_features: int): DecoderBlockType.SIMPLE, DecoderBlockType.SIMPLE_MLP, DecoderBlockType.LLAMA4, - ): - return functools.partial(rms_norm, num_features=num_features) + }: + return functools.partial(RMSNorm, num_features=num_features,rngs=self.rngs) elif self.config.decoder_block == DecoderBlockType.GPT3: - return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True) + return functools.partial( + 
gpt3.Gpt3LayerNorm, + num_features=num_features, + reductions_in_fp32=False, + use_bias=True, + rngs=self.rngs, + ) + raise ValueError(f"Incorrect config decoder_block name : {self.config.decoder_block.value}") + + def _make_scan_runner(self, layer_ctor_or_fn, length:int, name:str, **layer_kwargs): + """Return a callable: run(y, *broadcast_args) -> y that executes a prebuilt scan.""" + cfg = self.config + mesh = self.mesh + + def build_scan_fn(): + # we bake in_axes based on the 4-tuple broadcast args we use everywhere: + in_axes_tuple = (nn.broadcast,) * self.BROADCAST_ARGS_LENGTH + return self.scan_decoder_layers( + cfg, layer_ctor_or_fn, length, name, mesh, in_axes_tuple, + model_mode=self.model_mode, **layer_kwargs + ) + + scan_fn = build_scan_fn() + + def run(y, *broadcast_args, **kwargs): + y, _ = scan_fn(y, *broadcast_args, **kwargs) + return y + + return run + + def _calculate_partition_spec(self, y, decoder_segment_ids, decoder_positions, deterministic, model_mode): + return ( + None if not self.config.pipeline_fsdp_ag_once + else + self.pipeline_module.get_weight_sharding( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode + ) + ) + + def _build_exec_deepseek_pipeline(self): + cfg = self.config + mesh = self.mesh + if len(self.rematted_layer_classes) != 2: + raise ValueError( + f"Scanned layers must have a length of 2 when using DeepSeek, " + f"but got {len(self.rematted_layer_classes)}." + ) + dense_cls, moe_cls = self.rematted_layer_classes + + num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers + num_moe_layers_outside_pp = num_moe_layers - cfg.pipeline_parallel_layers + + # Prebuild scan runners (outside pipeline) + self.dense_layers = self._make_scan_runner( + dense_cls, cfg.first_num_dense_layers, "dense_layers" + ) + self.moe_layers = None + if num_moe_layers_outside_pp > 0: + self.moe_layers = self._make_scan_runner( + moe_cls, num_moe_layers_outside_pp, "moe_layers" + ) + + logical_axis_rules_pp_as_dp = maxtext_utils.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + + def exec_run(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=None, slot=None, page_state=None, bidirectional_mask=None): + # Execute scans outside pipeline under adjusted logical axis rules + with mesh, nn_logical_axis_rules(logical_axis_rules_pp_as_dp): + y = self.dense_layers(y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + if self.moe_layers is not None: + y = self.moe_layers(y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + + # Optionally compute weight sharding once (shape-dependent) + partition_spec = self._calculate_partition_spec( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + + # Pipeline proper (stage module was built in __init__) + y = self.pipeline_module(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + partition_spec=partition_spec) + return y + + self._exec = exec_run + + def _build_exec_standard_pipeline(self): + """Pipeline (non-DeepSeek). 
After pipeline, possibly run remaining scanned layers.""" + cfg = self.config + mesh = self.mesh + # Remaining layers after pipeline, if any, are scanned with the single base layer + remaining_layers = cfg.num_decoder_layers - cfg.pipeline_parallel_layers + self.layers_outside_pipeline = None + if remaining_layers > 0: + self.layers_outside_pipeline = self._make_scan_runner( + self.rematted_layer_classes[0], remaining_layers, "layers_outside_pipeline" + ) + logical_axis_rules_pp_as_dp = maxtext_utils.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + + def exec_run(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=None, slot=None, page_state=None, bidirectional_mask=None): + + # Optionally compute weight sharding once (shape-dependent) + partition_spec = self._calculate_partition_spec( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + + y = self.pipeline_module(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + partition_spec=partition_spec) + + if self.layers_outside_pipeline is not None: + with mesh, nn_logical_axis_rules(logical_axis_rules_pp_as_dp): + y = self.layers_outside_pipeline(y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + return y + + self._exec = exec_run + + def _build_exec_scanned(self): + """No pipeline, scanned execution. Handle DeepSeek / Gemma3 / Llama4 / others.""" + cfg = self.config + mesh = self.mesh + + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + if len(self.rematted_layer_classes) != 2: + raise ValueError( + f"Scanned layers must have a length of 2 when using DeepSeek, " + f"but got {len(self.rematted_layer_classes)}." + ) + dense_cls, moe_cls = self.rematted_layer_classes + self.dense_layers = self._make_scan_runner(dense_cls, cfg.first_num_dense_layers, "dense_layers") + num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers + self.moe_layers = self._make_scan_runner(moe_cls, num_moe_layers, "moe_layers") + + def exec_run(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=None, slot=None, page_state=None, bidirectional_mask=None): + y = self.dense_layers( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=previous_chunk,page_state=page_state,slot=slot + ) + if self.moe_layers is not None: + y = self.moe_layers(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=previous_chunk,page_state=page_state,slot=slot + ) + return y + + self._exec = exec_run + return + + if cfg.decoder_block == DecoderBlockType.GEMMA3: + attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) + scan_length = cfg.num_decoder_layers // attention_pattern_length + RemattedGemma3Block = self.set_remat_policy([gemma3.Gemma3ScannableBlock], self.get_remat_policy())[0] + + # main scan (full patterns) + self.layers = None + if scan_length > 0: + self.layers = self._make_scan_runner( + RemattedGemma3Block, scan_length, "layers", num_of_layers=attention_pattern_length + ) + + # remainder block (module instance) + self.layers_remainder = None + num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length + if num_remaining_layers > 0: + self.layers_remainder = RemattedGemma3Block( + config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, + name="layers_remainder", num_of_layers=num_remaining_layers + ) + + def exec_run(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=None, slot=None, 
page_state=None, bidirectional_mask=None): + if self.layers is not None: + y = self.layers(y, decoder_segment_ids, decoder_positions, deterministic, model_mode) + if self.layers_remainder is not None: + # Call remainder with extra kwargs + y, _ = self.layers_remainder( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=previous_chunk, + page_state=page_state, + slot=slot, + bidirectional_mask=bidirectional_mask, + ) + return y + + self._exec = exec_run + return + + # All other scanned (including LLAMA4 scanned scannable block) + RemattedBlockLayer = self.rematted_layer_classes[0] + scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) + # For LLAMA4 scanned, the layer itself needs kwargs — bake them into the scan builder + llama4_kwargs = {} + if cfg.decoder_block == DecoderBlockType.LLAMA4: + llama4_kwargs = { + "nope_layer_interval": cfg.nope_layer_interval, + "interleave_moe_layer_step": cfg.interleave_moe_layer_step, + } + self.layers = self._make_scan_runner(RemattedBlockLayer, scan_length, "layers", **llama4_kwargs) + + def exec_run(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=None, slot=None, page_state=None, bidirectional_mask=None): + if self.layers is not None: + y = self.layers(y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode + ) + return y + + self._exec = exec_run + + def _build_exec_unscanned(self): + """No pipeline, unscanned (instantiate all per-layer modules now).""" + cfg = self.config + mesh = self.mesh + self._unscanned_layers: list[nnx.Module] = [] + + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + if len(self.rematted_layer_classes) != 2: + raise ValueError( + f"Unscanned layers must have a length of 2 when using DeepSeek, but got {len(self.rematted_layer_classes)}." + ) + dense_cls, moe_cls = self.rematted_layer_classes + # Instantiate all layers now with unique names. 
+ for idx in range(cfg.first_num_dense_layers): + self._unscanned_layers.append( + dense_cls(config=cfg, mesh=mesh, name=f"dense_layers_{idx}", quant=self.quant, model_mode=self.model_mode) + ) + for idx in range(cfg.num_decoder_layers - cfg.first_num_dense_layers): + self._unscanned_layers.append( + moe_cls(config=cfg, mesh=mesh, name=f"moe_layers_{idx}", quant=self.quant, model_mode=self.model_mode) + ) + else: - raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + base_cls = self.rematted_layer_classes[0] + for idx in range(cfg.num_decoder_layers): + layer_kwargs = {} + if cfg.decoder_block == DecoderBlockType.GEMMA3: + layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=idx)} + elif cfg.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "is_nope_layer": llama4.determine_is_nope_layer(idx, cfg.nope_layer_interval), + "is_moe_layer": llama4.determine_is_moe_layer(idx, cfg.interleave_moe_layer_step), + } + self._unscanned_layers.append( + base_cls(config=cfg, mesh=mesh, name=f"layers_{idx}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs) + ) - def scan_decoder_layers( - self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, model_mode, **kwargs - ): - """scan decoder layers, calls `flax.linen.transforms.scan`""" - initializing = self.is_mutable_collection("params") - params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) + def exec_run(y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=None, slot=None, page_state=None, bidirectional_mask=None): + for layer in self._unscanned_layers: + y = layer( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode, + previous_chunk=previous_chunk, page_state=page_state, slot=slot, bidirectional_mask=bidirectional_mask + ) + return y + + self._exec = exec_run + + def scan_decoder_layers(self, cfg:Config, decoder_layer: Callable, length:int, metadata_axis_name:str, mesh:Mesh, in_axes_tuple:Any, model_mode:str, **kwargs): + params_spec = cfg.param_scan_axis cache_spec = 0 scan_fn = nn.scan( decoder_layer, @@ -446,22 +741,31 @@ def scan_decoder_layers( length=length, metadata_params={nn.PARTITION_NAME: metadata_axis_name}, ) - return scan_fn(config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, model_mode=model_mode, **kwargs) - def get_pipeline_stage_module(self, decoder_blocks): + return scan_fn( + config=cfg, + mesh=mesh, + name=metadata_axis_name, + quant=self.quant, + model_mode=model_mode, + **kwargs + ) + + + def get_pipeline_stage_module(self, decoder_blocks:list[Type[nnx.Module]]) -> nnx.Module: """get pipeline stage module""" - def get_layer_to_pipeline(blocks, cfg): + def get_layer_to_pipeline(blocks: list[Type[nnx.Module]], cfg:Config)->Callable[..., nnx.Module]: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: return blocks[1] # return the sparse block - else: - return blocks[0] + return blocks[0] cfg = self.config base_stage = get_layer_to_pipeline(decoder_blocks, cfg) if cfg.set_remat_policy_on_layers_per_stage: policy = self.get_remat_policy() base_stage = self.set_remat_policy([base_stage], policy)[0] + if cfg.num_layers_per_pipeline_stage == 1: stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode) elif cfg.scan_layers_per_stage: @@ -471,7 +775,7 @@ def get_layer_to_pipeline(blocks, cfg): cfg.num_layers_per_pipeline_stage, "layers_per_stage", self.mesh, - in_axes_tuple=(nn.broadcast,) * 4, + 
in_axes_tuple=(nn.broadcast,) * self.BROADCAST_ARGS_LENGTH, model_mode=self.model_mode, ) else: @@ -483,383 +787,132 @@ def get_layer_to_pipeline(blocks, cfg): quant=self.quant, model_mode=self.model_mode, ) - return stage_module - - @nn.compact + return stage_module + def _apply_embedding( self, - decoder_input_tokens, - decoder_positions, - deterministic, - model_mode, - image_embeddings=None, + decoder_input_tokens:Array, + decoder_positions:Array, + deterministic:bool, + model_mode:str, + image_embeddings: np.ndarray | Array | None=None, bidirectional_mask=None, - ): - """Applies token and positional embeddings to the input tokens.""" + )->Array: cfg = self.config y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) - + # Merge the image embeddings with the text embeddings for multimodal models if image_embeddings is not None and cfg.use_multimodal: - if cfg.model_name in ["gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e"]: - y = multimodal_utils.merge_mm_embeddings( - text_embeddings=y, - vision_embeddings=image_embeddings, - mask=bidirectional_mask, - ) # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed - else: + if cfg.model_name not in {"gemma3-4b", "gemma3-12b", "gemma3-27b", "llama4-17b-16e", "llama4-17b-128e"}: raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") - - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + + y = multimodal_utils.merge_mm_embeddings( + text_embeddings=y, + vision_embeddings=image_embeddings, + mask=bidirectional_mask, + ) + + y = self.embedding_dropout(y, deterministic=deterministic) y = y.astype(cfg.dtype) if cfg.use_untrainable_positional_embedding: - y = positional_embedding_as_linen(embedding_dims=cfg.base_emb_dim)(y, decoder_positions) - - if cfg.trainable_position_size > 0: - y += embed_as_linen( - num_embeddings=cfg.trainable_position_size, - num_features=cfg.emb_dim, - dtype=cfg.dtype, - embedding_init=nn.initializers.normal(stddev=1.0), - name="position_embedder", - config=cfg, - )(decoder_positions, model_mode=model_mode) - return y + y = self.static_pos_embedding(y, decoder_positions) - @nn.compact - def _apply_output_head(self, y, deterministic, model_mode): - """Applies final normalization and projects hidden states to logits.""" + if cfg.trainable_position_size > 0 and self.position_embedder: + y += self.position_embedder(decoder_positions, model_mode=model_mode) + return y + def _apply_output_head(self, y:Array, deterministic:bool, model_mode:str)->Array: cfg = self.config - y = self.get_norm_layer(num_features=y.shape[-1])( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="decoder_norm", - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=("norm",), - parameter_memory_host_offload=cfg.parameter_memory_host_offload, - )(y) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) - # [batch, length, emb_dim] -> [batch, length, vocab_size] + # Use the pre-instantiated norm layer + y = self.decoder_norm(y) + y = self.output_head_dropout(y, deterministic=deterministic) + if cfg.logits_via_embedding: - # Use the transpose of embedding matrix for logit transform. 
embedding_table = self.shared_embedding.variables["params"]["embedding"] if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): embedding_table = embedding_table.unbox() attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config) - if self.config.normalize_embedding_logits: - # Correctly normalize pre-softmax logits for this shared case. logits = logits / jnp.sqrt(y.shape[-1]) if cfg.final_logits_soft_cap: logits = logits / cfg.final_logits_soft_cap logits = jnp.tanh(logits) * cfg.final_logits_soft_cap else: - logits = linears.dense_general( - inputs_shape=y.shape, - out_features_shape=cfg.vocab_size, - weight_dtype=cfg.weight_dtype, - dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability - kernel_axes=("embed", "vocab"), - name="logits_dense", - matmul_precision=self.config.matmul_precision, - parameter_memory_host_offload=cfg.parameter_memory_host_offload, - )( - y - ) # We do not quantize the logits matmul. - if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): - logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab")) - else: - logits = nn.with_logical_constraint( - logits, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab") - ) + logits = self.logits_dense(y) + + logical_axis_resource = ( + (None, None, "activation_vocab") + if model_mode in {MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE} + else ("activation_embed_and_logits_batch", "activation_length", "activation_vocab") + ) + logits = nn.with_logical_constraint(logits, logical_axis_resource) if self.config.cast_logits_to_fp32: logits = logits.astype(jnp.float32) return logits - @nn.compact def __call__( self, - decoder_input_tokens, - decoder_positions, - decoder_segment_ids=None, - deterministic=False, - model_mode=MODEL_MODE_TRAIN, - previous_chunk=None, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - bidirectional_mask: None | Any = None, - image_embeddings: None | jnp.ndarray = None, - ): - cfg = self.config - mesh = self.mesh - assert decoder_input_tokens.ndim == 2 # [batch, len] - - # [batch, length] -> [batch, length, emb_dim] - y = self._apply_embedding( - decoder_input_tokens, decoder_positions, deterministic, model_mode, image_embeddings, bidirectional_mask - ) - - policy = self.get_remat_policy() - RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) - # scan does not support kwargs in layer call, passing broadcast_args as positional arg - broadcast_args = ( - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) - if cfg.using_pipeline_parallelism: - if cfg.pipeline_fsdp_ag_once: - partition_spec = self.pipeline_module.get_weight_sharding( - y, decoder_segment_ids, decoder_positions, deterministic, model_mode + decoder_input_tokens: Array, + decoder_positions: Array, + decoder_segment_ids: Array|None=None, + deterministic:bool=False, + model_mode:str=MODEL_MODE_TRAIN, + previous_chunk: Array | None =None, + slot: int | None = None, + page_state: page_manager.PageState | None = None, + bidirectional_mask: Array | None = None, + image_embeddings: Array | None = None, + )->tuple[Array,Array]: + + if decoder_input_tokens.ndim != 2: + raise ValueError( + f"`decoder_input_tokens` must have shape [batch, length], " + f"but got array with shape {decoder_input_tokens.shape}." 
) - else: - partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." - dense_layer = RemattedBlockLayers[0] - moe_layer = RemattedBlockLayers[1] - num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers - logical_axis_rules_pp_as_dp = maxtext_utils.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) - # We chose not to pipeline the dense layers, only sparse for SPMD. - with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): - y, _ = self.scan_decoder_layers( - cfg, - dense_layer, - cfg.first_num_dense_layers, - "dense_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) - if num_moe_layers_outside_pp > 0: - y, _ = self.scan_decoder_layers( - cfg, - moe_layer, - num_moe_layers_outside_pp, - "moe_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) - else: # Not DeepSeek - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) - remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers - if remaining_layers > 0: - logical_axis_rules_pp_as_dp = maxtext_utils.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) - with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): - y, _ = self.scan_decoder_layers( - cfg, - RemattedBlockLayers[0], - remaining_layers, - "layers_outside_pipeline", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) - else: - if cfg.scan_layers: - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
- layer_call_kwargs = { - "page_state": page_state, - "previous_chunk": previous_chunk, - "slot": slot, - } - dense_layer = RemattedBlockLayers[0] - dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) - y, _ = self.scan_decoder_layers( - cfg, - dense_layer, - cfg.first_num_dense_layers, - "dense_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) - moe_layer = RemattedBlockLayers[1] - moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) - num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - y, _ = self.scan_decoder_layers( - cfg, - moe_layer, - num_moe_layers, - "moe_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) - elif cfg.decoder_block == DecoderBlockType.GEMMA3: - y = self._apply_gemma3_scanned_blocks( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - bidirectional_mask, - previous_chunk, - page_state, - slot, - ) - else: - RemattedBlockLayer = RemattedBlockLayers[0] - scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) - layer_kwargs = {} - if cfg.decoder_block == DecoderBlockType.LLAMA4: - layer_kwargs = { - "nope_layer_interval": self.config.nope_layer_interval, - "interleave_moe_layer_step": self.config.interleave_moe_layer_step, - } - broadcast_args += (bidirectional_mask,) - y, _ = self.scan_decoder_layers( - cfg, - RemattedBlockLayer, - scan_length, - "layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - **layer_kwargs, - )(y, *broadcast_args) - else: - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." - dense_layer = RemattedBlockLayers[0] - moe_layer = RemattedBlockLayers[1] - - layers = [dense_layer, moe_layer] - layer_prefixes = ["dense_layers", "moe_layers"] - num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - num_layers_list = [cfg.first_num_dense_layers, num_moe_layers] - # Iterate over the two layer groups (dense and MoE) and apply layer transformation - for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes): - for index in range(num_layers): - y = layer( - config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode - )( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=previous_chunk, - page_state=page_state, - slot=slot, - ) - else: - for lyr in range(cfg.num_decoder_layers): - RemattedBlockLayer = RemattedBlockLayers[0] - layer_kwargs = {} - layer_call_kwargs = {} - if cfg.decoder_block == DecoderBlockType.GEMMA3: - # Gemma3 uses both global and sliding window attention depending on the layer index. 
- layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} - layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} - if cfg.decoder_block == DecoderBlockType.LLAMA4: - layer_kwargs = { - "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), - "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), - } - layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} - layer = RemattedBlockLayer( - config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs - ) - y = layer( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=previous_chunk, - page_state=page_state, - slot=slot, - **layer_call_kwargs, - ) - - assert isinstance(y, jax.Array) - - # After the final transformer layer, `y` holds the raw, un-normalized hidden state. - hidden_state = y - - logits = self._apply_output_head(hidden_state, deterministic, model_mode) - - # The API of the Decoder is now a tuple, providing both the main output - # and the raw hidden state needed for auxiliary tasks. - return logits, hidden_state - - def _apply_gemma3_scanned_blocks( - self, - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - bidirectional_mask, - previous_chunk, - page_state, - slot, - ): - """Applies Gemma3 scanned decoder blocks, handling main scan and remainders.""" - - cfg = self.config - mesh = self.mesh - - # Define the repeating pattern length and calculate how many full blocks to scan - attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) - scan_length = cfg.num_decoder_layers // attention_pattern_length - - policy = self.get_remat_policy() - RemattedGemma3Block = self.set_remat_policy([gemma3.Gemma3ScannableBlock], policy)[0] - - layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} - layer_kwargs = {"num_of_layers": attention_pattern_length} - # Apply the main scan over the full blocks - if scan_length > 0: - broadcast_args = ( - decoder_segment_ids, + y = self._apply_embedding( + decoder_input_tokens, decoder_positions, deterministic, model_mode, + image_embeddings, + bidirectional_mask, ) - y, _ = self.scan_decoder_layers( - cfg, - RemattedGemma3Block, - scan_length, - "layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - **layer_kwargs, - )(y, *broadcast_args, **layer_call_kwargs) - - # Apply any remaining layers that did not fit into a full scanned block - num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length - if num_remaining_layers > 0: - # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions - rem_layer_kwargs = {"num_of_layers": num_remaining_layers} - layer = RemattedGemma3Block( - config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs - ) - y, _ = layer( + + y = self._exec( y, decoder_segment_ids, decoder_positions, deterministic, model_mode, previous_chunk=previous_chunk, - page_state=page_state, slot=slot, - **layer_call_kwargs, + page_state=page_state, + bidirectional_mask=bidirectional_mask, ) - return y + + if not isinstance(y, jax.Array): + raise TypeError(f"Expected `y` to be a jax.Array, but got {type(y).__name__}.") + + hidden_state = y + logits = self._apply_output_head(hidden_state, deterministic, model_mode) + return logits, hidden_state + + @property + def pipeline_module(self) -> pipeline.Pipeline: + if 
self._pipeline_module is None:
+      raise RuntimeError("Pipeline module is not initialized. Set 'ici_pipeline_parallelism' or 'dcn_pipeline_parallelism' to a value larger than 1 in the config to enable pipeline parallelism.")
+    return self._pipeline_module
+
+  @property
+  def static_pos_embedding(self) -> PositionalEmbedding:
+    if self._static_pos_embedding is None:
+      raise RuntimeError("Static positional embedding is not initialized. Set 'use_untrainable_positional_embedding' in the config to enable it.")
+    return self._static_pos_embedding
diff --git a/MaxText/layers/embeddings.py b/MaxText/layers/embeddings.py
index e73dd38e6..2e2a545bc 100644
--- a/MaxText/layers/embeddings.py
+++ b/MaxText/layers/embeddings.py
@@ -686,22 +686,6 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     output = output.astype(self.fprop_dtype)
     return output
-
-def positional_embedding_as_linen(*, embedding_dims: int, max_wavelength: int = _MAX_WAVELENGTH):
-  """Initializes the PositionalEmbedding module and returns it as a Linen module.
-
-  Args:
-    embedding_dims: The dimension of the embeddings.
-    max_wavelength: The maximum wavelength for the sinusoidal positional embeddings.
-  """
-  return nnx_wrappers.to_linen(
-      PositionalEmbedding,
-      embedding_dims=embedding_dims,
-      max_wavelength=max_wavelength,
-      metadata_fn=variable_to_logically_partitioned,
-  )
-
-
 @dataclasses.dataclass(repr=False)
 class PositionalEmbedding(nnx.Module):
   """A layer that adds sinusoidal positional embeddings to the input.
diff --git a/MaxText/tests/decoders_test.py b/MaxText/tests/decoders_test.py
new file mode 100644
index 000000000..84ca46bc0
--- /dev/null
+++ b/MaxText/tests/decoders_test.py
@@ -0,0 +1,269 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from flax import linen as nn +from flax import nnx +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +import numpy as np +import os +import pytest + +from MaxText import maxtext_utils +from MaxText import pyconfig +from MaxText.globals import PKG_DIR +from MaxText.common_types import Config, MODEL_MODE_TRAIN +from MaxText.layers import embeddings +from MaxText.layers.decoders import Decoder + +class ModelHarness: + """A wrapper providing a unified interface for testing Linen and NNX models.""" + def __init__(self, model, variables_or_state, framework): + self._model = model + self._variables_or_state = variables_or_state + self._framework = framework + + def apply(self, *args, **kwargs): + """Executes the forward pass for either Linen or NNX.""" + if self._framework == 'linen': + rngs = {'dropout': jax.random.PRNGKey(1)} if 'deterministic' in kwargs and not kwargs['deterministic'] else None + return self._model.apply(self._variables_or_state, *args, rngs=rngs, **kwargs) + elif self._framework == 'nnx': + return self._model(*args, **kwargs) + raise TypeError(f"Unsupported model type: {type(self._model)}") + +def get_mesh(config: Config) -> Mesh: + """Provides a JAX device mesh for sharding.""" + devices_array = maxtext_utils.create_device_mesh(config) + return Mesh(devices_array, config.mesh_axes) + +def get_config_with_overrides(**overrides): + argv = [None, os.path.join(PKG_DIR, "configs", "base.yml")] + init_kwargs = { + 'run_name': 'test', + 'skip_jax_distributed_system': True, + } | overrides + + return pyconfig.initialize( + argv, + **init_kwargs + ) + +@pytest.fixture(scope="module") +def detected_framework(): + """ + Inspects the imported Decoder class and returns its framework type. + This runs only once per test session. + """ + if issubclass(Decoder, nnx.Module): + return 'nnx' + # Check for Linen last as NNX modules might have nn.Module in their MRO + if issubclass(Decoder, nn.Module): + return 'linen' + raise TypeError( + "Imported 'Decoder' is not a recognized subclass of flax.linen.Module or flax.nnx.Module" + ) + +@pytest.fixture +def harness_factory(detected_framework): + """ + Returns a factory function that can create a model harness for the + automatically detected framework. 
+ """ + def _create_harness(config, mesh): + framework = detected_framework # The factory "remembers" the framework + key = jax.random.PRNGKey(0) + batch_size, seq_len = int(config.per_device_batch_size), 16 + decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + decoder_positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, -1) + + if framework == 'linen': + shared_embedding = embeddings.embed_as_linen( + num_embeddings=config.vocab_size, + num_features=config.emb_dim, + dtype=config.dtype, + attend_dtype=jnp.float32 if config.logits_dot_in_fp32 else config.dtype, # for logit training stability + embedding_init=nn.initializers.normal(stddev=1.0), + name="token_embedder", + config=config, + ) + model = Decoder(config=config, shared_embedding=shared_embedding, mesh=mesh) + variables = model.init( + {'params': key, 'dropout': key}, + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + return ModelHarness(model=model, variables_or_state=variables, framework='linen') + + elif framework == 'nnx': + rngs = nnx.Rngs(params=0, dropout=1) + shared_embedding = embeddings.Embed( + num_embeddings=config.vocab_size, + num_features=config.emb_dim, + dtype=config.dtype, + attend_dtype=jnp.float32 if config.logits_dot_in_fp32 else config.dtype, # for logit training stability + embedding_init=nn.initializers.normal(stddev=1.0), + config=config, + rngs=rngs, + ) + model = Decoder(config=config, shared_embedding=shared_embedding, mesh=mesh, rngs=rngs) + return ModelHarness(model=model, variables_or_state=model, framework='nnx') + + return _create_harness + + +@pytest.fixture +def base_config(): + """Provides a default, immutable config for tests.""" + return get_config_with_overrides( + dropout_rate=0.5, + enable_dropout=True, + per_device_batch_size=4, + ici_tensor_parallelism=1, + scan_layers=False + ) + +class TestUnifiedDecoder: + """A single test class for both Linen and NNX Decoder implementations.""" + + def test_forward_pass_shape(self, harness_factory, base_config): + """Tests the forward pass shape and dtype.""" + mesh = get_mesh(config=base_config) + harness = harness_factory(base_config, mesh) + batch_size, seq_len = int(base_config.per_device_batch_size), 16 + decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + decoder_positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, -1) + + logits, hidden_state = harness.apply( + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + + assert logits.shape == (batch_size, seq_len, base_config.vocab_size) + assert hidden_state.shape == (batch_size, seq_len, base_config.emb_dim) + assert hidden_state.dtype == base_config.dtype + + @pytest.mark.parametrize('use_untrainable', [True, False]) + @pytest.mark.parametrize('trainable_size', [0, 32]) + def test_embedding_logic(self, harness_factory, use_untrainable, trainable_size): + """Tests that enabling positional embeddings changes the output.""" + config_base = get_config_with_overrides( + use_untrainable_positional_embedding=False, trainable_position_size=0 + ) + harness_base = harness_factory(config_base, get_mesh(config_base)) + + config_custom = get_config_with_overrides( + use_untrainable_positional_embedding=use_untrainable, trainable_position_size=trainable_size + ) + harness_custom = harness_factory(config_custom, get_mesh(config_custom)) + + batch_size, seq_len = 
int(config_base.per_device_batch_size), 16 + tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, -1) + apply_kwargs = dict( + decoder_input_tokens=tokens, decoder_positions=positions, + deterministic=True, model_mode=MODEL_MODE_TRAIN + ) + + _, hidden_state_base = harness_base.apply(**apply_kwargs) + _, hidden_state_custom = harness_custom.apply(**apply_kwargs) + + if not use_untrainable and trainable_size == 0: + np.testing.assert_allclose(hidden_state_base, hidden_state_custom, atol=1e-6) + else: + assert not np.allclose(hidden_state_base, hidden_state_custom) + + @pytest.mark.parametrize('logits_via_embedding', [True, False]) + def test_output_head_logic(self, harness_factory, logits_via_embedding): + """Tests switching between tied and separate output logits layer.""" + config = get_config_with_overrides(logits_via_embedding=logits_via_embedding) + harness = harness_factory(config, get_mesh(config)) + + batch_size, seq_len = int(config.per_device_batch_size), 16 + tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, -1) + + logits, _ = harness.apply( + decoder_input_tokens=tokens, decoder_positions=positions, + deterministic=True, model_mode=MODEL_MODE_TRAIN + ) + assert logits.shape == (batch_size, seq_len, config.vocab_size) + + def test_deterministic_mode_for_dropout(self, harness_factory, base_config): + """Ensures dropout is active only when deterministic is False.""" + config = base_config + mesh = get_mesh(config=config) + + batch_size, seq_len = int(config.per_device_batch_size), 16 + tokens = jax.random.randint(jax.random.PRNGKey(10), (batch_size, seq_len), 0, config.vocab_size) + positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, -1) + + harness1 = harness_factory(config, mesh) + logits_det1, _ = harness1.apply( + decoder_input_tokens=tokens, decoder_positions=positions, + deterministic=True, model_mode=MODEL_MODE_TRAIN + ) + + harness2 = harness_factory(config, mesh) + logits_det2, _ = harness2.apply( + decoder_input_tokens=tokens, decoder_positions=positions, + deterministic=True, model_mode=MODEL_MODE_TRAIN + ) + np.testing.assert_allclose(logits_det1, logits_det2, atol=1e-7, rtol=1e-7) + + harness3 = harness_factory(config, mesh) + logits_nondet, _ = harness3.apply( + decoder_input_tokens=tokens, decoder_positions=positions, + deterministic=False, model_mode=MODEL_MODE_TRAIN + ) + assert not np.allclose(logits_det1, logits_nondet) + + @pytest.mark.parametrize('scan_layers', [True, False]) + def test_scan_layers_flag(self, harness_factory, scan_layers): + """Tests that the model works with and without layer scanning.""" + config = get_config_with_overrides(scan_layers=scan_layers) + harness = harness_factory(config, get_mesh(config)) + + batch_size, seq_len = int(config.per_device_batch_size), 16 + tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, -1) + + logits, hidden_state = harness.apply( + decoder_input_tokens=tokens, decoder_positions=positions, + deterministic=True, model_mode=MODEL_MODE_TRAIN + ) + assert logits.shape == (batch_size, seq_len, config.vocab_size) + assert hidden_state.shape == (batch_size, seq_len, config.emb_dim) + + def test_input_shape_validation(self, harness_factory, base_config): + """Tests that the model raises AssertionError for incorrect input dimensions.""" + mesh = get_mesh(config=base_config) + harness = 
harness_factory(base_config, mesh)
+
+    batch_size, seq_len = int(base_config.per_device_batch_size), 16
+    bad_tokens = jnp.ones((batch_size, seq_len, 1), dtype=jnp.int32)
+    positions = jnp.arange(seq_len, dtype=jnp.int32).reshape(1, -1)
+
+    with pytest.raises((AssertionError, ValueError)):
+      harness.apply(
+          decoder_input_tokens=bad_tokens,
+          decoder_positions=positions,
+          deterministic=True,
+          model_mode=MODEL_MODE_TRAIN
+      )
\ No newline at end of file