Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ model_name: "default" # override config settings to match a specific model. othe
override_model_config: False # When set to true allows overriding model parameters via CLI for the purpose of debugging/testing.
normalization_layer_epsilon: 1.e-05

# Temporary flag to determine whether to use NNX implementation for certain modules.
# This flag mainly targets modules at the decoder layer of the pytree and higher.
# Modules below that will be NNX regardless.
# TODO: Remove this flag when NNX inference memory is optimized
enable_nnx: False

################################## CHECKPOINTING ##################################
# Checkpointing makes the following choices in the following order, starting with (1):
# (1) If there is already a checkpoint for this run_name, we load the latest entire checkpoint.
Expand Down
5 changes: 3 additions & 2 deletions MaxText/inference/kvcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,9 @@ def __init__(
self.model_mode = model_mode
self.use_chunked_prefill = use_chunked_prefill

self._initialize_prefill_caches(model_mode)
self._initialize_ar_cache_vars(model_mode)
if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
self._initialize_prefill_caches(model_mode)
self._initialize_ar_cache_vars(model_mode)

@property
def prefill_key_vars(self):
Expand Down
82 changes: 45 additions & 37 deletions MaxText/layers/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,12 @@ def __call__(
):
cfg = self.config
mesh = self.mesh
if model_mode == MODEL_MODE_PREFILL:
if self.model_mode == MODEL_MODE_PREFILL:
logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
else:
logical_axis_names = ("activation_batch", "activation_length", "activation_embed")

if model_mode == MODEL_MODE_PREFILL:
inputs = nn.with_logical_constraint(inputs, logical_axis_names)
else:
inputs = nn.with_logical_constraint(inputs, logical_axis_names)
inputs = nn.with_logical_constraint(inputs, logical_axis_names)

inputs = checkpoint_name(inputs, "decoder_layer_input")
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
Expand All @@ -105,10 +102,7 @@ def __call__(
epsilon=cfg.normalization_layer_epsilon,
kernel_axes=("norm",),
)(inputs)
if model_mode == MODEL_MODE_PREFILL:
lnx = nn.with_logical_constraint(lnx, logical_axis_names)
else:
lnx = nn.with_logical_constraint(lnx, logical_axis_names)
lnx = nn.with_logical_constraint(lnx, logical_axis_names)

attention_layer = attention_as_linen(
config=self.config,
Expand All @@ -133,7 +127,7 @@ def __call__(
ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
reshape_q=cfg.reshape_q,
model_mode=model_mode,
model_mode=self.model_mode,
)

attention_lnx = attention_layer(
Expand All @@ -142,10 +136,10 @@ def __call__(
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
model_mode=self.model_mode,
)

if model_mode == MODEL_MODE_PREFILL:
if self.model_mode == MODEL_MODE_PREFILL:
attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names)
else:
attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names)
Expand All @@ -159,11 +153,11 @@ def __call__(
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="mlp",
model_mode=model_mode,
model_mode=self.model_mode,
config=cfg,
quant=self.quant,
)(lnx, deterministic=deterministic)
if model_mode == MODEL_MODE_PREFILL:
if self.model_mode == MODEL_MODE_PREFILL:
mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
else:
mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
Expand All @@ -175,7 +169,7 @@ def __call__(
)

layer_output = next_layer_addition_dropped_out + inputs
if model_mode == MODEL_MODE_PREFILL:
if self.model_mode == MODEL_MODE_PREFILL:
layer_output = nn.with_logical_constraint(
layer_output,
logical_axis_names,
Expand Down Expand Up @@ -221,13 +215,13 @@ def __call__(
) -> jnp.ndarray:
for lyr in range(self.num_decoder_layers):
inputs = self.decoder_layer(
config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode
config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode
)(
inputs,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
self.model_mode,
slot=slot,
page_state=page_state,
)
Expand Down Expand Up @@ -337,6 +331,13 @@ def get_decoder_layers(self):
Returns:
A list containing one or more `nn.Module` classes for the decoder.
"""
if self.config.enable_nnx:
match self.config.decoder_block:
case DecoderBlockType.MISTRAL:
return [mistral.MistralDecoderLayerToLinen]
case _:
raise ValueError(f"decoder_block name {self.config.decoder_block.value=} not yet supported with enable_nnx=True")

match self.config.decoder_block:
case DecoderBlockType.DEFAULT:
return [DecoderLayer]
Expand Down Expand Up @@ -446,7 +447,14 @@ def scan_decoder_layers(
length=length,
metadata_params={nn.PARTITION_NAME: metadata_axis_name},
)
return scan_fn(config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, model_mode=model_mode, **kwargs)
return scan_fn(
config=cfg,
mesh=mesh,
name=metadata_axis_name,
quant=self.quant,
model_mode=self.model_mode,
**kwargs
)

def get_pipeline_stage_module(self, decoder_blocks):
"""get pipeline stage module"""
Expand Down Expand Up @@ -498,7 +506,7 @@ def _apply_embedding(
"""Applies token and positional embeddings to the input tokens."""
cfg = self.config

y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)
y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=self.model_mode)

# Merge the image embeddings with the text embeddings for multimodal models
if image_embeddings is not None and cfg.use_multimodal:
Expand Down Expand Up @@ -526,7 +534,7 @@ def _apply_embedding(
embedding_init=nn.initializers.normal(stddev=1.0),
name="position_embedder",
config=cfg,
)(decoder_positions, model_mode=model_mode)
)(decoder_positions, model_mode=self.model_mode)
return y

@nn.compact
Expand Down Expand Up @@ -572,7 +580,7 @@ def _apply_output_head(self, y, deterministic, model_mode):
)(
y
) # We do not quantize the logits matmul.
if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab"))
else:
logits = nn.with_logical_constraint(
Expand Down Expand Up @@ -604,7 +612,7 @@ def __call__(

# [batch, length] -> [batch, length, emb_dim]
y = self._apply_embedding(
decoder_input_tokens, decoder_positions, deterministic, model_mode, image_embeddings, bidirectional_mask
decoder_input_tokens, decoder_positions, deterministic, self.model_mode, image_embeddings, bidirectional_mask
)

policy = self.get_remat_policy()
Expand All @@ -614,12 +622,12 @@ def __call__(
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
self.model_mode,
)
if cfg.using_pipeline_parallelism:
if cfg.pipeline_fsdp_ag_once:
partition_spec = self.pipeline_module.get_weight_sharding(
y, decoder_segment_ids, decoder_positions, deterministic, model_mode
y, decoder_segment_ids, decoder_positions, deterministic, self.model_mode
)
else:
partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
Expand All @@ -639,7 +647,7 @@ def __call__(
"dense_layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
model_mode=self.model_mode,
)(y, *broadcast_args)
if num_moe_layers_outside_pp > 0:
y, _ = self.scan_decoder_layers(
Expand All @@ -649,7 +657,7 @@ def __call__(
"moe_layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
model_mode=self.model_mode,
)(y, *broadcast_args)
y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
else: # Not DeepSeek
Expand All @@ -665,7 +673,7 @@ def __call__(
"layers_outside_pipeline",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
model_mode=self.model_mode,
)(y, *broadcast_args)
else:
if cfg.scan_layers:
Expand All @@ -685,7 +693,7 @@ def __call__(
"dense_layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
model_mode=self.model_mode,
)(y, *broadcast_args)
moe_layer = RemattedBlockLayers[1]
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
Expand All @@ -697,15 +705,15 @@ def __call__(
"moe_layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
model_mode=self.model_mode,
)(y, *broadcast_args)
elif cfg.decoder_block == DecoderBlockType.GEMMA3:
y = self._apply_gemma3_scanned_blocks(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
self.model_mode,
bidirectional_mask,
previous_chunk,
page_state,
Expand All @@ -728,7 +736,7 @@ def __call__(
"layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
model_mode=self.model_mode,
**layer_kwargs,
)(y, *broadcast_args)
else:
Expand All @@ -751,7 +759,7 @@ def __call__(
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
self.model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
Expand Down Expand Up @@ -779,7 +787,7 @@ def __call__(
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
self.model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
Expand All @@ -791,7 +799,7 @@ def __call__(
# After the final transformer layer, `y` holds the raw, un-normalized hidden state.
hidden_state = y

logits = self._apply_output_head(hidden_state, deterministic, model_mode)
logits = self._apply_output_head(hidden_state, deterministic, self.model_mode)

# The API of the Decoder is now a tuple, providing both the main output
# and the raw hidden state needed for auxiliary tasks.
Expand Down Expand Up @@ -830,7 +838,7 @@ def _apply_gemma3_scanned_blocks(
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
self.model_mode,
)
y, _ = self.scan_decoder_layers(
cfg,
Expand All @@ -839,7 +847,7 @@ def _apply_gemma3_scanned_blocks(
"layers",
mesh,
in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
model_mode=model_mode,
model_mode=self.model_mode,
**layer_kwargs,
)(y, *broadcast_args, **layer_call_kwargs)

Expand All @@ -856,7 +864,7 @@ def _apply_gemma3_scanned_blocks(
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
self.model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
Expand Down
Loading
Loading