
Migrate QwenDecoderLayer to NNX #2196

Status: Draft. Wants to merge 3 commits into base: main
5 changes: 3 additions & 2 deletions MaxText/inference/kvcache.py
@@ -310,8 +310,9 @@ def __init__(
    self.model_mode = model_mode
    self.use_chunked_prefill = use_chunked_prefill

-    self._initialize_prefill_caches(model_mode)
-    self._initialize_ar_cache_vars(model_mode)
+    if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
+      self._initialize_prefill_caches(model_mode)
+      self._initialize_ar_cache_vars(model_mode)

  @property
  def prefill_key_vars(self):
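Note on the kvcache.py change: cache buffers are now only allocated when the layer runs in a serving mode, so training no longer pays for prefill/autoregressive cache state. A minimal sketch of the gating, using a simplified stand-in class (not the real `KVCache` signature):

```python
# Simplified stand-in to illustrate the mode-gated cache allocation above.
MODEL_MODE_TRAIN = "train"
MODEL_MODE_PREFILL = "prefill"
MODEL_MODE_AUTOREGRESSIVE = "autoregressive"

class KVCacheSketch:
  def __init__(self, model_mode: str):
    self.model_mode = model_mode
    self.prefill_cache = None  # stands in for _initialize_prefill_caches
    self.ar_cache = None       # stands in for _initialize_ar_cache_vars
    # Only serving modes need KV cache buffers; training skips the allocation.
    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
      self.prefill_cache = {}
      self.ar_cache = {}

assert KVCacheSketch(MODEL_MODE_TRAIN).prefill_cache is None
assert KVCacheSketch(MODEL_MODE_PREFILL).prefill_cache is not None
```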
79 changes: 42 additions & 37 deletions MaxText/layers/decoders.py
@@ -85,12 +85,12 @@ def __call__(
  ):
    cfg = self.config
    mesh = self.mesh
-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
      logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed")
    else:
      logical_axis_names = ("activation_batch", "activation_length", "activation_embed")

-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
      inputs = nn.with_logical_constraint(inputs, logical_axis_names)
    else:
      inputs = nn.with_logical_constraint(inputs, logical_axis_names)
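For context on the `nn.with_logical_constraint` calls repeated throughout this file: they tag activations with logical axis names, which axis rules later resolve to physical mesh axes. A minimal, self-contained sketch of the mechanism (the rules shown are illustrative; MaxText derives its own from config):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Illustrative logical->physical rules; MaxText builds these from its config.
rules = (
    ("activation_batch", "data"),
    ("activation_length", None),
    ("activation_embed", None),
)

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))

x = jnp.ones((4, 8, 16))
with mesh, nn.logical_axis_rules(rules):
  # Ask the partitioner to shard the batch axis over the "data" mesh axis;
  # the other two logical axes map to None, i.e. replicated.
  x = nn.with_logical_constraint(
      x, ("activation_batch", "activation_length", "activation_embed")
  )
```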
@@ -105,7 +105,7 @@ def __call__(
        epsilon=cfg.normalization_layer_epsilon,
        kernel_axes=("norm",),
    )(inputs)
-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
      lnx = nn.with_logical_constraint(lnx, logical_axis_names)
    else:
      lnx = nn.with_logical_constraint(lnx, logical_axis_names)
@@ -133,7 +133,7 @@ def __call__(
        ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))),
        compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))),
        reshape_q=cfg.reshape_q,
-        model_mode=model_mode,
+        model_mode=self.model_mode,
    )

    attention_lnx = attention_layer(
@@ -142,10 +142,10 @@ def __call__(
        decoder_positions,
        decoder_segment_ids=decoder_segment_ids,
        deterministic=deterministic,
-        model_mode=model_mode,
+        model_mode=self.model_mode,
    )

-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
      attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names)
    else:
      attention_lnx = nn.with_logical_constraint(attention_lnx, logical_axis_names)
@@ -159,11 +159,11 @@
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="mlp",
-        model_mode=model_mode,
+        model_mode=self.model_mode,
        config=cfg,
        quant=self.quant,
    )(lnx, deterministic=deterministic)
-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
      mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
    else:
      mlp_lnx = nn.with_logical_constraint(mlp_lnx, logical_axis_names)
@@ -175,7 +175,7 @@
    )

    layer_output = next_layer_addition_dropped_out + inputs
-    if model_mode == MODEL_MODE_PREFILL:
+    if self.model_mode == MODEL_MODE_PREFILL:
      layer_output = nn.with_logical_constraint(
          layer_output,
          logical_axis_names,
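The recurring edit across decoders.py is this same one-line change: `model_mode` is no longer threaded through `__call__` but read from an attribute fixed at construction time, which matches how NNX modules carry state. A sketch of the before/after shape (class and method names here are illustrative, not the actual MaxText signatures):

```python
from flax import nnx

MODEL_MODE_PREFILL = "prefill"

class DecoderLayerSketch(nnx.Module):
  """Illustrative only: shows where model_mode now lives, not the real layer."""

  def __init__(self, model_mode: str):
    # Before: every __call__(inputs, ..., model_mode) had to pass the mode.
    # After: the mode is set once when the layer is built.
    self.model_mode = model_mode

  def logical_axis_names(self):
    if self.model_mode == MODEL_MODE_PREFILL:
      return ("activation_batch", "prefill_activation_length", "activation_embed")
    return ("activation_batch", "activation_length", "activation_embed")

layer = DecoderLayerSketch(MODEL_MODE_PREFILL)
assert layer.logical_axis_names()[1] == "prefill_activation_length"
```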
@@ -221,13 +221,13 @@ def __call__(
  ) -> jnp.ndarray:
    for lyr in range(self.num_decoder_layers):
      inputs = self.decoder_layer(
-          config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode
+          config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode
      )(
          inputs,
          decoder_segment_ids,
          decoder_positions,
          deterministic,
-          model_mode,
+          self.model_mode,
          slot=slot,
          page_state=page_state,
      )
@@ -341,7 +341,7 @@ def get_decoder_layers(self):
      case DecoderBlockType.DEFAULT:
        return [DecoderLayer]
      case DecoderBlockType.LLAMA2:
-        return [llama2.LlamaDecoderLayer]
+        return [llama2.LlamaDecoderLayerToLinen]
      case DecoderBlockType.MISTRAL:
        # TODO(ranran): update to Mistral with sliding window attention
        return [mistral.MistralDecoderLayer]
@@ -358,9 +358,9 @@ def get_decoder_layers(self):
      case DecoderBlockType.GPT3:
        return [gpt3.Gpt3DecoderLayer]
      case DecoderBlockType.QWEN3:
-        return [qwen3.Qwen3DecoderLayer]
+        return [qwen3.Qwen3DecoderLayerToLinen]
      case DecoderBlockType.QWEN3_MOE:
-        return [qwen3.Qwen3MoeDecoderLayer]
+        return [qwen3.Qwen3MoeDecoderLayerToLinen]
      case DecoderBlockType.SIMPLE:
        return [simple_layer.SimpleDecoderLayer]
      case DecoderBlockType.SIMPLE_MLP:
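The new `*ToLinen` names indicate the Qwen3 and Llama2 layers are now NNX modules wrapped for use inside the still-Linen `Decoder`. A sketch of how such a wrapper can be produced with Flax's bridge API (assuming `flax.nnx.bridge.to_linen`; the toy layer below is a stand-in, not the real `Qwen3DecoderLayer`, which takes config, mesh, quant, and model_mode):

```python
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge

class ToyDecoderLayer(nnx.Module):
  """Stand-in NNX layer to illustrate the wrapping, nothing more."""

  def __init__(self, features: int, *, rngs: nnx.Rngs):
    self.proj = nnx.Linear(features, features, rngs=rngs)

  def __call__(self, x):
    return self.proj(x)

# to_linen wraps the NNX class so Linen machinery (setup, nn.scan, nn.remat)
# can keep treating it like an ordinary Linen module.
toy_layer_linen = bridge.to_linen(ToyDecoderLayer, 16)

x = jnp.ones((2, 16))
variables = toy_layer_linen.init(jax.random.key(0), x)
y = toy_layer_linen.apply(variables, x)
```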
@@ -382,9 +382,7 @@ def move_to_device(variables):

  def map_fn(path, value):
    max_logging.log(f"models.py: Moving parameter {path} to device")
-    return jax.device_put(
-        value, max_utils.device_space()
-    )
+    return jax.device_put(value, max_utils.device_space())

  return jax.tree_util.tree_map_with_path(map_fn, variables)

@@ -446,7 +444,14 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me
        length=length,
        metadata_params={nn.PARTITION_NAME: metadata_axis_name},
    )
-    return scan_fn(config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, model_mode=model_mode, **kwargs)
+    return scan_fn(
+        config=cfg,
+        mesh=mesh,
+        name=metadata_axis_name,
+        quant=self.quant,
+        model_mode=self.model_mode,
+        **kwargs
+    )
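`scan_decoder_layers` itself is built on Linen's `nn.scan`, which stacks `length` copies of a layer along a new parameter axis and names that axis (via `nn.PARTITION_NAME`) so it can be sharded. A minimal sketch of that mechanism with a toy layer, mirroring the broadcast-argument pattern used above:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyLayer(nn.Module):
  @nn.compact
  def __call__(self, x, _):
    # Carry x through a Dense layer; the second argument is broadcast, unused.
    return nn.Dense(x.shape[-1])(x), None

ScannedToy = nn.scan(
    ToyLayer,
    variable_axes={"params": 0},                    # stack params on axis 0
    split_rngs={"params": True},                    # fresh init RNG per layer
    in_axes=(nn.broadcast,),                        # second arg broadcast to all steps
    length=4,                                       # four stacked layers
    metadata_params={nn.PARTITION_NAME: "layers"},  # name the stacked axis
)

x = jnp.ones((2, 8))
variables = ScannedToy().init(jax.random.key(0), x, None)
y, _ = ScannedToy().apply(variables, x, None)  # runs all 4 layers in sequence
```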

def get_pipeline_stage_module(self, decoder_blocks):
"""get pipeline stage module"""
@@ -498,7 +503,7 @@ def _apply_embedding(
    """Applies token and positional embeddings to the input tokens."""
    cfg = self.config

-    y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode)
+    y = self.shared_embedding(decoder_input_tokens.astype("int32"), model_mode=self.model_mode)

    # Merge the image embeddings with the text embeddings for multimodal models
    if image_embeddings is not None and cfg.use_multimodal:
@@ -526,7 +531,7 @@ def _apply_embedding(
        embedding_init=nn.initializers.normal(stddev=1.0),
        name="position_embedder",
        config=cfg,
-    )(decoder_positions, model_mode=model_mode)
+    )(decoder_positions, model_mode=self.model_mode)
    return y

  @nn.compact
@@ -572,7 +577,7 @@ def _apply_output_head(self, y, deterministic, model_mode):
    )(
        y
    )  # We do not quantize the logits matmul.
-    if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
+    if self.model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
      logits = nn.with_logical_constraint(logits, (None, None, "activation_vocab"))
    else:
      logits = nn.with_logical_constraint(
@@ -604,7 +609,7 @@ def __call__(

    # [batch, length] -> [batch, length, emb_dim]
    y = self._apply_embedding(
-        decoder_input_tokens, decoder_positions, deterministic, model_mode, image_embeddings, bidirectional_mask
+        decoder_input_tokens, decoder_positions, deterministic, self.model_mode, image_embeddings, bidirectional_mask
    )

    policy = self.get_remat_policy()
@@ -614,12 +619,12 @@
        decoder_segment_ids,
        decoder_positions,
        deterministic,
-        model_mode,
+        self.model_mode,
    )
    if cfg.using_pipeline_parallelism:
      if cfg.pipeline_fsdp_ag_once:
        partition_spec = self.pipeline_module.get_weight_sharding(
-            y, decoder_segment_ids, decoder_positions, deterministic, model_mode
+            y, decoder_segment_ids, decoder_positions, deterministic, self.model_mode
        )
      else:
        partition_spec = None  # This partition spec is only used for the fsdp_ag_once feature.
@@ -639,7 +644,7 @@
            "dense_layers",
            mesh,
            in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-            model_mode=model_mode,
+            model_mode=self.model_mode,
        )(y, *broadcast_args)
        if num_moe_layers_outside_pp > 0:
          y, _ = self.scan_decoder_layers(
@@ -649,7 +654,7 @@
            "moe_layers",
            mesh,
            in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-            model_mode=model_mode,
+            model_mode=self.model_mode,
        )(y, *broadcast_args)
      y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
    else:  # Not DeepSeek
@@ -665,7 +670,7 @@
            "layers_outside_pipeline",
            mesh,
            in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-            model_mode=model_mode,
+            model_mode=self.model_mode,
        )(y, *broadcast_args)
    else:
      if cfg.scan_layers:
@@ -685,7 +690,7 @@
              "dense_layers",
              mesh,
              in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-              model_mode=model_mode,
+              model_mode=self.model_mode,
          )(y, *broadcast_args)
          moe_layer = RemattedBlockLayers[1]
          moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
@@ -697,15 +702,15 @@
              "moe_layers",
              mesh,
              in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-              model_mode=model_mode,
+              model_mode=self.model_mode,
          )(y, *broadcast_args)
        elif cfg.decoder_block == DecoderBlockType.GEMMA3:
          y = self._apply_gemma3_scanned_blocks(
              y,
              decoder_segment_ids,
              decoder_positions,
              deterministic,
-              model_mode,
+              self.model_mode,
              bidirectional_mask,
              previous_chunk,
              page_state,
@@ -728,7 +733,7 @@
              "layers",
              mesh,
              in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-              model_mode=model_mode,
+              model_mode=self.model_mode,
              **layer_kwargs,
          )(y, *broadcast_args)
      else:
@@ -749,7 +754,7 @@
              decoder_segment_ids,
              decoder_positions,
              deterministic,
-              model_mode,
+              self.model_mode,
              previous_chunk=previous_chunk,
              page_state=page_state,
              slot=slot,
@@ -777,7 +782,7 @@
              decoder_segment_ids,
              decoder_positions,
              deterministic,
-              model_mode,
+              self.model_mode,
              previous_chunk=previous_chunk,
              page_state=page_state,
              slot=slot,
@@ -789,7 +794,7 @@
    # After the final transformer layer, `y` holds the raw, un-normalized hidden state.
    hidden_state = y

-    logits = self._apply_output_head(hidden_state, deterministic, model_mode)
+    logits = self._apply_output_head(hidden_state, deterministic, self.model_mode)

    # The API of the Decoder is now a tuple, providing both the main output
    # and the raw hidden state needed for auxiliary tasks.
@@ -828,7 +833,7 @@ def _apply_gemma3_scanned_blocks(
          decoder_segment_ids,
          decoder_positions,
          deterministic,
-          model_mode,
+          self.model_mode,
      )
      y, _ = self.scan_decoder_layers(
          cfg,
@@ -837,7 +842,7 @@
          "layers",
          mesh,
          in_axes_tuple=(nn.broadcast,) * len(broadcast_args),
-          model_mode=model_mode,
+          model_mode=self.model_mode,
          **layer_kwargs,
      )(y, *broadcast_args, **layer_call_kwargs)

@@ -854,7 +859,7 @@
          decoder_segment_ids,
          decoder_positions,
          deterministic,
-          model_mode,
+          self.model_mode,
          previous_chunk=previous_chunk,
          page_state=page_state,
          slot=slot,