
Commit a79921b

Migrate LlamaDecoderLayer to NNX
1 parent 5d600de commit a79921b

File tree: 7 files changed, +301 −117 lines changed

MaxText/configs/base.yml

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,12 @@ model_name: "default" # override config settings to match a specific model. othe
 override_model_config: False # When set to true allows overriding model parameters via CLI for the purpose of debugging/testing.
 normalization_layer_epsilon: 1.e-05

+# Temporary flag to determine whether to use NNX implementation for certain modules.
+# This flag mainly targets modules at the decoder layer of the pytree and higher.
+# Modules below that will be NNX regardless.
+# TODO: Remove this flag when NNX inference memory is optimized
+enable_nnx: False
+
 ################################## CHECKPOINTING ##################################
 # Checkpointing makes the following choices in the following order, starting with (1):
 # (1) If there is already a checkpoint for this run_name, we load the latest entire checkpoint.
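A minimal sketch of how a consumer of this flag could dispatch between the Linen and the NNX-backed layer classes. The Config dataclass and the two placeholder layer classes below are illustrative stand-ins, not the actual MaxText types; the real dispatch is in get_decoder_layers in MaxText/layers/decoders.py (see the diff further down).

from dataclasses import dataclass

@dataclass
class Config:
  enable_nnx: bool = False  # mirrors the new base.yml flag


class LlamaDecoderLayerLinen:
  """Placeholder for the existing Linen decoder layer."""


class LlamaDecoderLayerNNXToLinen:
  """Placeholder for the NNX-backed wrapper."""


def select_llama_decoder_layer(cfg: Config):
  # NNX-backed layer when the temporary flag is on, otherwise the Linen layer.
  return LlamaDecoderLayerNNXToLinen if cfg.enable_nnx else LlamaDecoderLayerLinen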

MaxText/inference/kvcache.py

Lines changed: 3 additions & 2 deletions
@@ -310,8 +310,9 @@ def __init__(
     self.model_mode = model_mode
     self.use_chunked_prefill = use_chunked_prefill

-    self._initialize_prefill_caches(model_mode)
-    self._initialize_ar_cache_vars(model_mode)
+    if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
+      self._initialize_prefill_caches(model_mode)
+      self._initialize_ar_cache_vars(model_mode)

   @property
   def prefill_key_vars(self):
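A rough sketch of the guard this hunk introduces, with assumed simplified names: KV caches are only allocated when the object is constructed for prefill or autoregressive decoding, so a training-mode instantiation skips the allocation. The MODEL_MODE_* string values and the cache placeholders are assumptions for illustration, not the real kvcache.py internals.

MODEL_MODE_TRAIN = "train"                    # assumed value for illustration
MODEL_MODE_PREFILL = "prefill"                # assumed value for illustration
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"  # assumed value for illustration


class KVCacheSketch:
  """Illustrative stand-in for the real KVCache module."""

  def __init__(self, model_mode: str):
    self.model_mode = model_mode
    # Only the decode paths ever read the caches; training never does.
    if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
      self._initialize_prefill_caches(model_mode)
      self._initialize_ar_cache_vars(model_mode)

  def _initialize_prefill_caches(self, model_mode):
    self.prefill_cache = {}  # placeholder for real prefill cache variables

  def _initialize_ar_cache_vars(self, model_mode):
    self.ar_cache = {}  # placeholder for real autoregressive cache variables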

MaxText/layers/decoders.py

Lines changed: 45 additions & 37 deletions
@@ -85,15 +85,12 @@ def __call__(
   ):
     cfg = self.config
     mesh = self.mesh
-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
       logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
     else:
       logical_axis_names = ("activation_batch", "activation_length", "activation_embed")

-    if model_mode == MODEL_MODE_PREFILL:
-      inputs = nn.with_logical_constraint(inputs, logical_axis_names)
-    else:
-      inputs = nn.with_logical_constraint(inputs, logical_axis_names)
+    inputs = nn.with_logical_constraint(inputs, logical_axis_names)

     inputs = checkpoint_name(inputs, "decoder_layer_input")
     # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
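The recurring change in this file is sketched below: model_mode becomes an attribute fixed at module construction, so __call__ reads self.model_mode instead of a call argument, and the previously duplicated if/else branches around nn.with_logical_constraint collapse to a single call. This is a simplified illustration, not the actual DecoderLayer; the MODEL_MODE_PREFILL value is assumed.

import flax.linen as nn

MODEL_MODE_PREFILL = "prefill"  # assumed value for illustration


class DecoderLayerSketch(nn.Module):
  model_mode: str = "train"  # fixed when the module is constructed

  @nn.compact
  def __call__(self, inputs):
    if self.model_mode == MODEL_MODE_PREFILL:
      logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
    else:
      logical_axis_names = ("activation_batch", "activation_length", "activation_embed")
    # One constraint call replaces the duplicated prefill/else branches.
    return nn.with_logical_constraint(inputs, logical_axis_names)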
@@ -105,10 +102,7 @@ def __call__(
         epsilon=cfg.normalization_layer_epsilon,
         kernel_axes=("norm",),
     )(inputs)
-    if model_mode == MODEL_MODE_PREFILL:
-      lnx = nn.with_logical_constraint(lnx, logical_axis_names)
-    else:
-      lnx = nn.with_logical_constraint(lnx, logical_axis_names)
+    lnx = nn.with_logical_constraint(lnx, logical_axis_names)

     attention_layer = attention_as_linen(
         config=self.config,
@@ -133,7 +127,7 @@ def __call__(
         ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
         compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
         reshape_q=cfg.reshape_q,
-        model_mode=model_mode,
+        model_mode=self.model_mode,
     )

     attention_lnx = attention_layer(
@@ -142,10 +136,10 @@ def __call__(
         decoder_positions,
         decoder_segment_ids=decoder_segment_ids,
         deterministic=deterministic,
-        model_mode=model_mode,
+        model_mode=self.model_mode,
     )

-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
       attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names)
     else:
       attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names)
@@ -159,11 +153,11 @@ def __call__(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="mlp",
-        model_mode=model_mode,
+        model_mode=self.model_mode,
         config=cfg,
         quant=self.quant,
     )(lnx, deterministic=deterministic)
-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
       mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
     else:
       mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
@@ -175,7 +169,7 @@ def __call__(
     )

     layer_output = next_layer_addition_dropped_out + inputs
-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
       layer_output = nn.with_logical_constraint(
           layer_output,
           logical_axis_names,
@@ -221,13 +215,13 @@ def __call__(
   ) -> jnp.ndarray:
     for lyr in range(self.num_decoder_layers):
       inputs = self.decoder_layer(
-          config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode
+          config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode
       )(
           inputs,
           decoder_segment_ids,
           decoder_positions,
           deterministic,
-          model_mode,
+          self.model_mode,
           slot=slot,
           page_state=page_state,
       )
@@ -337,6 +331,13 @@ def get_decoder_layers(self):
     Returns:
       A list containing one or more `nn.Module` classes for the decoder.
     """
+    if self.config.enable_nnx:
+      match self.config.decoder_block:
+        case DecoderBlockType.LLAMA2:
+          return [llama2.LlamaDecoderLayerNNXToLinen]
+        case _:
+          raise ValueError(f"decoder_block name {self.config.decoder_block.value=} not yet supported with enable_nnx=True")
+
     match self.config.decoder_block:
       case DecoderBlockType.DEFAULT:
         return [DecoderLayer]
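For context, a rough sketch of how an NNX module can be exposed to Linen call sites like the ones above, assuming Flax's nnx bridge helpers behave as documented; the actual llama2.LlamaDecoderLayerNNXToLinen wrapper in MaxText may be built differently, and the toy module below is purely illustrative.

from flax import nnx
from flax.nnx import bridge


class TinyDecoderBlockNNX(nnx.Module):
  """Toy NNX module standing in for a real decoder layer."""

  def __init__(self, features: int, *, rngs: nnx.Rngs):
    self.proj = nnx.Linear(features, features, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)


# bridge.to_linen wraps the NNX class so it can be composed from Linen code.
TinyDecoderBlockLinen = bridge.to_linen(TinyDecoderBlockNNX, features=128)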
@@ -446,7 +447,14 @@ def scan_decoder_layers(
         length=length,
         metadata_params={nn.PARTITION_NAME: metadata_axis_name},
     )
-    return scan_fn(config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, model_mode=model_mode, **kwargs)
+    return scan_fn(
+        config=cfg,
+        mesh=mesh,
+        name=metadata_axis_name,
+        quant=self.quant,
+        model_mode=self.model_mode,
+        **kwargs
+    )

   def get_pipeline_stage_module(self, decoder_blocks):
     """get pipeline stage module"""
@@ -498,7 +506,7 @@ def _apply_embedding(
     """Applies token and positional embeddings to the input tokens."""
     cfg = self.config

-    y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)
+    y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=self.model_mode)

     # Merge the image embeddings with the text embeddings for multimodal models
     if image_embeddings is not None and cfg.use_multimodal:
@@ -526,7 +534,7 @@ def _apply_embedding(
           embedding_init=nn.initializers.normal(stddev=1.0),
           name="position_embedder",
           config=cfg,
-      )(decoder_positions, model_mode=model_mode)
+      )(decoder_positions, model_mode=self.model_mode)
     return y

   @nn.compact
@@ -572,7 +580,7 @@ def _apply_output_head(self, y, deterministic, model_mode):
     )(
         y
     )  # We do not quantize the logits matmul.
-    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
+    if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
       logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab"))
     else:
       logits = nn.with_logical_constraint(
@@ -604,7 +612,7 @@ def __call__(

     # [batch, length] -> [batch, length, emb_dim]
     y = self._apply_embedding(
-        decoder_input_tokens, decoder_positions, deterministic, model_mode, image_embeddings, bidirectional_mask
+        decoder_input_tokens, decoder_positions, deterministic, self.model_mode, image_embeddings, bidirectional_mask
     )

     policy = self.get_remat_policy()
@@ -614,12 +622,12 @@ def __call__(
         decoder_segment_ids,
         decoder_positions,
         deterministic,
-        model_mode,
+        self.model_mode,
     )
     if cfg.using_pipeline_parallelism:
       if cfg.pipeline_fsdp_ag_once:
         partition_spec = self.pipeline_module.get_weight_sharding(
-            y, decoder_segment_ids, decoder_positions, deterministic, model_mode
+            y, decoder_segment_ids, decoder_positions, deterministic, self.model_mode
         )
       else:
         partition_spec = None  # This partition spec is only used for the fsdp_ag_once feature.
@@ -639,7 +647,7 @@ def __call__(
             "dense_layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-            model_mode=model_mode,
+            model_mode=self.model_mode,
         )(y, *broadcast_args)
         if num_moe_layers_outside_pp > 0:
           y, _ = self.scan_decoder_layers(
@@ -649,7 +657,7 @@ def __call__(
             "moe_layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-            model_mode=model_mode,
+            model_mode=self.model_mode,
         )(y, *broadcast_args)
         y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
       else:  # Not DeepSeek
@@ -665,7 +673,7 @@ def __call__(
             "layers_outside_pipeline",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-            model_mode=model_mode,
+            model_mode=self.model_mode,
         )(y, *broadcast_args)
     else:
       if cfg.scan_layers:
@@ -685,7 +693,7 @@ def __call__(
             "dense_layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-            model_mode=model_mode,
+            model_mode=self.model_mode,
         )(y, *broadcast_args)
         moe_layer = RemattedBlockLayers[1]
         moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
@@ -697,15 +705,15 @@ def __call__(
             "moe_layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-            model_mode=model_mode,
+            model_mode=self.model_mode,
         )(y, *broadcast_args)
       elif cfg.decoder_block == DecoderBlockType.GEMMA3:
         y = self._apply_gemma3_scanned_blocks(
             y,
             decoder_segment_ids,
             decoder_positions,
             deterministic,
-            model_mode,
+            self.model_mode,
             bidirectional_mask,
             previous_chunk,
             page_state,
@@ -728,7 +736,7 @@ def __call__(
             "layers",
             mesh,
             in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-            model_mode=model_mode,
+            model_mode=self.model_mode,
             **layer_kwargs,
         )(y, *broadcast_args)
       else:
@@ -751,7 +759,7 @@ def __call__(
             decoder_segment_ids,
             decoder_positions,
             deterministic,
-            model_mode,
+            self.model_mode,
             previous_chunk=previous_chunk,
             page_state=page_state,
             slot=slot,
@@ -779,7 +787,7 @@ def __call__(
             decoder_segment_ids,
             decoder_positions,
             deterministic,
-            model_mode,
+            self.model_mode,
             previous_chunk=previous_chunk,
             page_state=page_state,
             slot=slot,
@@ -791,7 +799,7 @@ def __call__(
     # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
     hidden_state = y

-    logits = self._apply_output_head(hidden_state, deterministic, model_mode)
+    logits = self._apply_output_head(hidden_state, deterministic, self.model_mode)

     # The API of the Decoder is now a tuple, providing both the main output
     # and the raw hidden state needed for auxiliary tasks.
@@ -830,7 +838,7 @@ def _apply_gemma3_scanned_blocks(
         decoder_segment_ids,
         decoder_positions,
         deterministic,
-        model_mode,
+        self.model_mode,
     )
     y, _ = self.scan_decoder_layers(
         cfg,
@@ -839,7 +847,7 @@ def _apply_gemma3_scanned_blocks(
         "layers",
         mesh,
         in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-        model_mode=model_mode,
+        model_mode=self.model_mode,
         **layer_kwargs,
     )(y, *broadcast_args, **layer_call_kwargs)

@@ -856,7 +864,7 @@ def _apply_gemma3_scanned_blocks(
         decoder_segment_ids,
         decoder_positions,
         deterministic,
-        model_mode,
+        self.model_mode,
         previous_chunk=previous_chunk,
         page_state=page_state,
         slot=slot,