Migrate LlamaDecoderLayer to NNX #2178
Conversation
Force-pushed from 6756c3e to 49d6cb0
Force-pushed from 49d6cb0 to cc04158
Force-pushed from cc04158 to e2d20ff
Thanks, a few minor comments.
MaxText/layers/decoders.py
Outdated
if self.model_mode == MODEL_MODE_PREFILL:
  inputs = nn.with_logical_constraint(inputs, logical_axis_names)
else:
  inputs = nn.with_logical_constraint(inputs, logical_axis_names)
Since you are touching this part, could you also fix it: the if/else branches do the same thing, so we can just call (without the if/else):
inputs = nn.with_logical_constraint(inputs, logical_axis_names)
Updated, thanks @gagika
MaxText/layers/decoders.py
Outdated
if self.model_mode == MODEL_MODE_PREFILL:
  lnx = nn.with_logical_constraint(lnx, logical_axis_names)
else:
  lnx = nn.with_logical_constraint(lnx, logical_axis_names)
Same as above, no need for the if/else.
Updated, thanks @gagika
MaxText/layers/models.py
Outdated
def init(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
  """Initializes the model."""
  module = self.clone(model_mode=model_mode)
  return nn.Module.init(module, *args, **kwargs)

def apply(self, *args, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
  """Applies the model."""
  module = self.clone(model_mode=model_mode)
  return nn.Module.apply(module, *args, **kwargs)
Could you comment on why these functions with clones are needed?
Added a code comment for these. They are needed to ensure the same model_mode is passed to __init__ and __call__.
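For context, a minimal usage sketch of how these overrides keep the two call sites in sync; the constructor arguments and inputs below are hypothetical placeholders, not the PR's exact code:

# Sketch: init/apply clone the module with the requested model_mode, so the
# same mode reaches both __init__ and __call__. All names here are assumptions.
import jax
import jax.numpy as jnp

tokens = jnp.ones((4, 16), dtype=jnp.int32)            # hypothetical token ids
positions = jnp.broadcast_to(jnp.arange(16), (4, 16))  # hypothetical positions

model = Transformer(config=cfg, mesh=mesh, quant=None)  # hypothetical constructor args
variables = model.init(jax.random.PRNGKey(0), tokens, positions, model_mode=MODEL_MODE_PREFILL)
logits = model.apply(variables, tokens, positions, model_mode=MODEL_MODE_PREFILL)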
MaxText/layers/llama2.py
Outdated
@@ -61,7 +143,6 @@ def __call__(
       previous_chunk=None,
   ):
     cfg = self.config
-    mesh = self.mesh

     if model_mode == MODEL_MODE_PREFILL:
Since your PR made model_mode a static argument, passing it in both __init__ and __call__ can be confusing. You could move the activation_axis_names initialization into the init function (or derive model_mode from the config):
if model_mode == MODEL_MODE_PREFILL:
activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
else:
activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")
Feel free to keep it as is if we expect model_mode to change during runtime, e.g. going from prefill to autoregressive decode.
cc @cgarciae
Thanks @gagika. Updated so that activation_axis_names is set in __init__ now. As a P1, we will want to remove the model_mode that is passed to __call__ later.
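For reference, a rough sketch of what the constructor-time selection might look like; the axis-name tuples come from the review suggestion above, while the signature and surrounding class are assumptions:

# Sketch: choose the activation axis names once at construction time.
# Only the axis-name tuples come from the review suggestion above; the
# signature and surrounding class are assumptions, not the PR's exact code.
def __init__(self, config, model_mode: str = MODEL_MODE_TRAIN, **kwargs):
  super().__init__(**kwargs)
  self.config = config
  if model_mode == MODEL_MODE_PREFILL:
    self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
  else:
    self.activation_axis_names = ("activation_batch", "activation_norm_length", "activation_embed")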
else:
  seq_len = config.max_target_length

dummy_inputs_shape = (batch_size, seq_len, config.emb_dim)
I believe you want to use a shape that is sharded along the batch dim?
@dubstack can you clarify what you mean by "sharded shape"?
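For what it's worth, one common reading of "sharded along the batch dim" in JAX is to materialize the dummy input with a NamedSharding that splits dim 0 across the device mesh; a self-contained sketch with assumed mesh, axis names, and sizes:

# Sketch: build dummy inputs sharded along the batch dimension instead of
# replicated. The mesh construction, axis names, and sizes are assumptions.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

batch_size, seq_len, emb_dim = 8, 1024, 2048  # placeholder sizes
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("data",))
sharding = NamedSharding(mesh, P("data", None, None))  # split dim 0 (batch) across "data"
dummy_inputs = jax.device_put(jnp.zeros((batch_size, seq_len, emb_dim), jnp.bfloat16), sharding)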
Force-pushed from 642930d to d3c5290
Force-pushed from d3c5290 to a79921b
Description
Migrate LlamaDecoderLayer to NNX instead of Linen:
- Create submodules in __init__ instead of __call__
- Pass self.model_mode to module constructors instead of the one passed to call
- Update to_linen_class to work for pipelining (thanks @cgarciae for this)
- Ensure the same model_mode value is passed around
- Add an enable_nnx flag to enable/disable NNX for this and future models, off by default (a sketch of how this flag might gate the implementations follows this section)

New PR due to some git issues. Addressed comments from #2155.
Note: Continuing to investigate increased KVCache memory during inference. Considering whether we should still merge this PR to unblock others waiting on this migration, and continue investigating in parallel.
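As promised above, a hedged sketch of how the enable_nnx flag might select between implementations; only the flag name comes from this description, and the helper and class names are illustrative assumptions:

# Sketch: gate the NNX vs. Linen decoder layer on the config flag.
# This helper and both class names are hypothetical, for illustration only.
def get_llama_decoder_layer(config):
  if getattr(config, "enable_nnx", False):
    return LlamaDecoderLayerNNX  # hypothetical NNX-based class
  return LlamaDecoderLayer       # hypothetical Linen-based class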
Tests
- Trained with a full state checkpoint (load_full_state_path) and an existing checkpoint with load_parameters_path:
- Compared jit_train_step in memory viewer before/after this change:

Checklist
Before submitting this PR, please make sure (put X in square brackets):