
Conversation


@bvandermoon (Collaborator) commented on Aug 14, 2025

Description

Migrate LlamaDecoderLayer to NNX instead of Linen

  • Change module type to NNX
  • Initialize submodules in __init__ instead of __call__
  • Pass self.model_mode to module constructors instead of the model_mode passed to __call__ (see the sketch after this list)
  • Update to_linen_class to work for pipelining (thanks @cgarciae for this)
  • New init/apply functions in Transformer module to ensure the proper model_mode value is passed around
  • Add a new config flag enable_nnx to enable/disable NNX for this and future models (off by default)
    • We will remove this flag and have everything in NNX when the inference memory issue is resolved
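
A minimal sketch of the NNX pattern described above (the class and submodule names here are illustrative, not the actual MaxText modules):

from flax import nnx


class LlamaDecoderLayerSketch(nnx.Module):

  def __init__(self, config, model_mode: str, *, rngs: nnx.Rngs):
    # Submodules are created at construction time rather than inside __call__,
    # and they see self.model_mode instead of a call-time model_mode argument.
    self.model_mode = model_mode
    self.pre_norm = nnx.RMSNorm(config.emb_dim, rngs=rngs)
    self.mlp = nnx.Linear(config.emb_dim, config.emb_dim, rngs=rngs)

  def __call__(self, inputs):
    return inputs + self.mlp(self.pre_norm(inputs))

In the actual change the NNX layer is wrapped back into Linen (the to_linen_class path above) so it still composes with the existing Linen Transformer and pipelining code, and the whole path sits behind the enable_nnx flag.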

New PR due to some git issues. Addressed comments from #2155

Note: Continuing to investigate increased KVCache memory during inference. Considering whether to merge this PR anyway to unblock others waiting on this migration, and continue the investigation in parallel.

Tests

  • Llama2-7B training gives the same memory/perf before/after on a TPU VM. Can also continue training from a locally-generated checkpoint (with load_full_state_path) and from an existing checkpoint with load_parameters_path:
python3 -m MaxText.train MaxText/configs/base.yml \
    run_name=<run_name> \
    base_output_directory=gs://<gcs_bucket> \
    dataset_type=synthetic \
    steps=10 \
    model_name=llama2-7b
  • Exact same peak memory allocation for jit_train_step in the memory viewer before/after this change.
  • Successfully ran Llama3.1-8B decode on TPU VM. Can also run this from a checkpoint:
python3 -m MaxText.decode MaxText/configs/base.yml \
    model_name=llama2-7b \
    tokenizer_path=assets/tokenizer_llama3.tiktoken \
    tokenizer_type=tiktoken \
    scan_layers=false \
    per_device_batch_size=1 \
    ici_fsdp_parallelism=1 \
    ici_autoregressive_parallelism=-1 \
    max_prefill_predict_length=128 \
    max_target_length=256 \
    prompt="I love to" \
    attention=dot_product

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@bvandermoon force-pushed the bvandermoon-llama branch 5 times, most recently from 6756c3e to 49d6cb0 on August 15, 2025 at 01:13
@gagika (Collaborator) left a comment

Thanks, a few minor comments.

Comment on lines 93 to 96
if self.model_mode == MODEL_MODE_PREFILL:
  inputs = nn.with_logical_constraint(inputs, logical_axis_names)
else:
  inputs = nn.with_logical_constraint(inputs, logical_axis_names)
Collaborator

Since you are touching this part, could you also fix it?

The if/else blocks are doing the same thing, so we can just call (without the if/else):
inputs = nn.with_logical_constraint(inputs, logical_axis_names)

Collaborator Author

Updated, thanks @gagika

Comment on lines 108 to 111
if self.model_mode == MODEL_MODE_PREFILL:
  lnx = nn.with_logical_constraint(lnx, logical_axis_names)
else:
  lnx = nn.with_logical_constraint(lnx, logical_axis_names)
Collaborator

Same as above, no need for the if/else.

Collaborator Author

Updated, thanks @gagika

Comment on lines 57 to 65
def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
  """Initializes the model."""
  module = self.clone(model_mode=model_mode)
  return nn.Module.init(module, *args, **kwargs)

def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
  """Applies the model."""
  module = self.clone(model_mode=model_mode)
  return nn.Module.apply(module, *args, **kwargs)
Collaborator

Could you add a comment explaining why these functions with clones are needed?

Collaborator Author

Added a code comment for these. They are needed to ensure the same model_mode is passed to __init__ and __call__.
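
For illustration, a self-contained toy version of the clone-based init/apply pattern (ToyTransformer and its arguments are made up; only the cloning mechanics mirror the snippet above):

import jax
import jax.numpy as jnp
from flax import linen as nn

MODEL_MODE_TRAIN = "train"
MODEL_MODE_PREFILL = "prefill"


class ToyTransformer(nn.Module):
  model_mode: str = MODEL_MODE_TRAIN

  def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    # Clone with the requested mode so the attribute seen during __call__
    # matches the model_mode passed to this init call.
    return nn.Module.init(self.clone(model_mode=model_mode), *args, **kwargs)

  def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
    return nn.Module.apply(self.clone(model_mode=model_mode), *args, **kwargs)

  @nn.compact
  def __call__(self, x):
    # self.model_mode here agrees with the model_mode given to init/apply.
    return nn.Dense(4)(x)


model = ToyTransformer()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)), model_mode=MODEL_MODE_PREFILL)
out = model.apply(params, jnp.ones((1, 4)), model_mode=MODEL_MODE_PREFILL)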

@@ -61,7 +143,6 @@ def __call__(
      previous_chunk=None,
  ):
    cfg = self.config
    mesh = self.mesh

    if model_mode == MODEL_MODE_PREFILL:
Collaborator

Since your PR made model_mode a static argument, passing it in both __init__ and __call__ can be confusing. You could move the activation_axis_names initialization into __init__ (or derive model_mode from the config).

if model_mode == MODEL_MODE_PREFILL:
  activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
else:
  activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

Collaborator

Feel free to keep it as is if we expect model_mode to change during runtime, e.g. going from prefill to autoregressive decode.

cc @cgarciae

Collaborator Author

Thanks @gagika. Updated so that activation_axis_names is now set in __init__. As a P1, we will want to remove the model_mode passed to __call__ later.
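
For reference, a minimal sketch of what moving this into __init__ can look like (a plain illustrative class, not the real MaxText layer):

MODEL_MODE_TRAIN = "train"
MODEL_MODE_PREFILL = "prefill"


class DecoderLayerAxisNamesSketch:

  def __init__(self, model_mode: str = MODEL_MODE_TRAIN):
    # model_mode is fixed at construction, so the axis names are chosen once
    # here instead of on every __call__.
    if model_mode == MODEL_MODE_PREFILL:
      self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
    else:
      self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")

  def __call__(self, inputs):
    # The real layer constrains inputs with
    # nn.with_logical_constraint(inputs, self.activation_axis_names) here.
    return inputs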

@bvandermoon mentioned this pull request on Aug 19, 2025
else:
  seq_len = config.max_target_length

dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)

@dubstack

I believe you want to use a shape that is sharded along the batch dim?

Collaborator

@dubstack can you clarify what you mean about "sharded shape"?
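
One possible reading of the suggestion, as a hedged sketch (the mesh axes, sizes, and dtype below are assumptions, not taken from this PR): build the dummy input as an abstract value whose sharding splits the batch dimension across the data mesh axis, so the shapes used for initialization match the layout of real inputs.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 1-D mesh over all local devices, named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

batch_size, seq_len, emb_dim = 8, 128, 4096  # placeholder sizes
dummy_inputs = jax.ShapeDtypeStruct(
    (batch_size, seq_len, emb_dim),
    jnp.float32,
    sharding=NamedSharding(mesh, P("data", None, None)),  # shard only the batch dim
)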

@bvandermoon force-pushed the bvandermoon-llama branch 2 times, most recently from 642930d to d3c5290 on August 20, 2025 at 22:30