@@ -85,15 +85,12 @@ def __call__(
85
85
):
86
86
cfg = self .config
87
87
mesh = self .mesh
88
- if model_mode == MODEL_MODE_PREFILL :
88
+ if self . model_mode == MODEL_MODE_PREFILL :
89
89
logical_axis_names = ("activation_batch" , "prefill_activation_length" , "activation_embed" )
90
90
else :
91
91
logical_axis_names = ("activation_batch" , "activation_length" , "activation_embed" )
92
92
93
- if model_mode == MODEL_MODE_PREFILL :
94
- inputs = nn .with_logical_constraint (inputs , logical_axis_names )
95
- else :
96
- inputs = nn .with_logical_constraint (inputs , logical_axis_names )
93
+ inputs = nn .with_logical_constraint (inputs , logical_axis_names )
97
94
98
95
inputs = checkpoint_name (inputs , "decoder_layer_input" )
99
96
# inputs: embedded inputs to the decoder with shape [batch, length, emb_dim]
@@ -105,10 +102,7 @@ def __call__(
105
102
epsilon = cfg .normalization_layer_epsilon ,
106
103
kernel_axes = ("norm" ,),
107
104
)(inputs )
108
- if model_mode == MODEL_MODE_PREFILL :
109
- lnx = nn .with_logical_constraint (lnx , logical_axis_names )
110
- else :
111
- lnx = nn .with_logical_constraint (lnx , logical_axis_names )
105
+ lnx = nn .with_logical_constraint (lnx , logical_axis_names )
112
106
113
107
attention_layer = attention_as_linen (
114
108
config = self .config ,
@@ -133,7 +127,7 @@ def __call__(
133
127
ar_cache_axis_order = tuple (map (int , cfg .ar_cache_axis_order .split ("," ))),
134
128
compute_axis_order = tuple (map (int , cfg .compute_axis_order .split ("," ))),
135
129
reshape_q = cfg .reshape_q ,
136
- model_mode = model_mode ,
130
+ model_mode = self . model_mode ,
137
131
)
138
132
139
133
attention_lnx = attention_layer (
@@ -142,10 +136,10 @@ def __call__(
142
136
decoder_positions ,
143
137
decoder_segment_ids = decoder_segment_ids ,
144
138
deterministic = deterministic ,
145
- model_mode = model_mode ,
139
+ model_mode = self . model_mode ,
146
140
)
147
141
148
- if model_mode == MODEL_MODE_PREFILL :
142
+ if self . model_mode == MODEL_MODE_PREFILL :
149
143
attention_lnx = nn .with_logical_constraint (attention_lnx , logical_axis_names )
150
144
else :
151
145
attention_lnx = nn .with_logical_constraint (attention_lnx , logical_axis_names )
@@ -159,11 +153,11 @@ def __call__(
159
153
dtype = cfg .dtype ,
160
154
weight_dtype = cfg .weight_dtype ,
161
155
name = "mlp" ,
162
- model_mode = model_mode ,
156
+ model_mode = self . model_mode ,
163
157
config = cfg ,
164
158
quant = self .quant ,
165
159
)(lnx , deterministic = deterministic )
166
- if model_mode == MODEL_MODE_PREFILL :
160
+ if self . model_mode == MODEL_MODE_PREFILL :
167
161
mlp_lnx = nn .with_logical_constraint (mlp_lnx , logical_axis_names )
168
162
else :
169
163
mlp_lnx = nn .with_logical_constraint (mlp_lnx , logical_axis_names )
@@ -175,7 +169,7 @@ def __call__(
175
169
)
176
170
177
171
layer_output = next_layer_addition_dropped_out + inputs
178
- if model_mode == MODEL_MODE_PREFILL :
172
+ if self . model_mode == MODEL_MODE_PREFILL :
179
173
layer_output = nn .with_logical_constraint (
180
174
layer_output ,
181
175
logical_axis_names ,
@@ -221,13 +215,13 @@ def __call__(
221
215
) -> jnp .ndarray :
222
216
for lyr in range (self .num_decoder_layers ):
223
217
inputs = self .decoder_layer (
224
- config = self .config , mesh = self .mesh , name = f"layers_{ lyr } " , quant = self .quant , model_mode = model_mode
218
+ config = self .config , mesh = self .mesh , name = f"layers_{ lyr } " , quant = self .quant , model_mode = self . model_mode
225
219
)(
226
220
inputs ,
227
221
decoder_segment_ids ,
228
222
decoder_positions ,
229
223
deterministic ,
230
- model_mode ,
224
+ self . model_mode ,
231
225
slot = slot ,
232
226
page_state = page_state ,
233
227
)
@@ -337,6 +331,13 @@ def get_decoder_layers(self):
337
331
Returns:
338
332
A list containing one or more `nn.Module` classes for the decoder.
339
333
"""
334
+ if self .config .enable_nnx :
335
+ match self .config .decoder_block :
336
+ case DecoderBlockType .LLAMA2 :
337
+ return [llama2 .LlamaDecoderLayerNNXToLinen ]
338
+ case _:
339
+ raise ValueError (f"decoder_block name { self .config .decoder_block .value = } not yet supported with enable_nnx=True" )
340
+
340
341
match self .config .decoder_block :
341
342
case DecoderBlockType .DEFAULT :
342
343
return [DecoderLayer ]
@@ -446,7 +447,14 @@ def scan_decoder_layers(
446
447
length = length ,
447
448
metadata_params = {nn .PARTITION_NAME : metadata_axis_name },
448
449
)
449
- return scan_fn (config = cfg , mesh = mesh , name = metadata_axis_name , quant = self .quant , model_mode = model_mode , ** kwargs )
450
+ return scan_fn (
451
+ config = cfg ,
452
+ mesh = mesh ,
453
+ name = metadata_axis_name ,
454
+ quant = self .quant ,
455
+ model_mode = self .model_mode ,
456
+ ** kwargs
457
+ )
450
458
451
459
def get_pipeline_stage_module (self , decoder_blocks ):
452
460
"""get pipeline stage module"""
@@ -498,7 +506,7 @@ def _apply_embedding(
498
506
"""Applies token and positional embeddings to the input tokens."""
499
507
cfg = self .config
500
508
501
- y = self .shared_embedding (decoder_input_tokens .astype ("int32" ), model_mode = model_mode )
509
+ y = self .shared_embedding (decoder_input_tokens .astype ("int32" ), model_mode = self . model_mode )
502
510
503
511
# Merge the image embeddings with the text embeddings for multimodal models
504
512
if image_embeddings is not None and cfg .use_multimodal :
@@ -526,7 +534,7 @@ def _apply_embedding(
526
534
embedding_init = nn .initializers .normal (stddev = 1.0 ),
527
535
name = "position_embedder" ,
528
536
config = cfg ,
529
- )(decoder_positions , model_mode = model_mode )
537
+ )(decoder_positions , model_mode = self . model_mode )
530
538
return y
531
539
532
540
@nn .compact
@@ -572,7 +580,7 @@ def _apply_output_head(self, y, deterministic, model_mode):
572
580
)(
573
581
y
574
582
) # We do not quantize the logits matmul.
575
- if model_mode in (MODEL_MODE_PREFILL , MODEL_MODE_AUTOREGRESSIVE ):
583
+ if self . model_mode in (MODEL_MODE_PREFILL , MODEL_MODE_AUTOREGRESSIVE ):
576
584
logits = nn .with_logical_constraint (logits , (None , None , "activation_vocab" ))
577
585
else :
578
586
logits = nn .with_logical_constraint (
@@ -604,7 +612,7 @@ def __call__(
604
612
605
613
# [batch, length] -> [batch, length, emb_dim]
606
614
y = self ._apply_embedding (
607
- decoder_input_tokens , decoder_positions , deterministic , model_mode , image_embeddings , bidirectional_mask
615
+ decoder_input_tokens , decoder_positions , deterministic , self . model_mode , image_embeddings , bidirectional_mask
608
616
)
609
617
610
618
policy = self .get_remat_policy ()
@@ -614,12 +622,12 @@ def __call__(
614
622
decoder_segment_ids ,
615
623
decoder_positions ,
616
624
deterministic ,
617
- model_mode ,
625
+ self . model_mode ,
618
626
)
619
627
if cfg .using_pipeline_parallelism :
620
628
if cfg .pipeline_fsdp_ag_once :
621
629
partition_spec = self .pipeline_module .get_weight_sharding (
622
- y , decoder_segment_ids , decoder_positions , deterministic , model_mode
630
+ y , decoder_segment_ids , decoder_positions , deterministic , self . model_mode
623
631
)
624
632
else :
625
633
partition_spec = None # This partition spec is only used for the fsdp_ag_once feature.
@@ -639,7 +647,7 @@ def __call__(
639
647
"dense_layers" ,
640
648
mesh ,
641
649
in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
642
- model_mode = model_mode ,
650
+ model_mode = self . model_mode ,
643
651
)(y , * broadcast_args )
644
652
if num_moe_layers_outside_pp > 0 :
645
653
y , _ = self .scan_decoder_layers (
@@ -649,7 +657,7 @@ def __call__(
649
657
"moe_layers" ,
650
658
mesh ,
651
659
in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
652
- model_mode = model_mode ,
660
+ model_mode = self . model_mode ,
653
661
)(y , * broadcast_args )
654
662
y = self .pipeline_module (y , * broadcast_args , partition_spec = partition_spec )
655
663
else : # Not DeepSeek
@@ -665,7 +673,7 @@ def __call__(
665
673
"layers_outside_pipeline" ,
666
674
mesh ,
667
675
in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
668
- model_mode = model_mode ,
676
+ model_mode = self . model_mode ,
669
677
)(y , * broadcast_args )
670
678
else :
671
679
if cfg .scan_layers :
@@ -685,7 +693,7 @@ def __call__(
685
693
"dense_layers" ,
686
694
mesh ,
687
695
in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
688
- model_mode = model_mode ,
696
+ model_mode = self . model_mode ,
689
697
)(y , * broadcast_args )
690
698
moe_layer = RemattedBlockLayers [1 ]
691
699
moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
@@ -697,15 +705,15 @@ def __call__(
697
705
"moe_layers" ,
698
706
mesh ,
699
707
in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
700
- model_mode = model_mode ,
708
+ model_mode = self . model_mode ,
701
709
)(y , * broadcast_args )
702
710
elif cfg .decoder_block == DecoderBlockType .GEMMA3 :
703
711
y = self ._apply_gemma3_scanned_blocks (
704
712
y ,
705
713
decoder_segment_ids ,
706
714
decoder_positions ,
707
715
deterministic ,
708
- model_mode ,
716
+ self . model_mode ,
709
717
bidirectional_mask ,
710
718
previous_chunk ,
711
719
page_state ,
@@ -728,7 +736,7 @@ def __call__(
728
736
"layers" ,
729
737
mesh ,
730
738
in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
731
- model_mode = model_mode ,
739
+ model_mode = self . model_mode ,
732
740
** layer_kwargs ,
733
741
)(y , * broadcast_args )
734
742
else :
@@ -751,7 +759,7 @@ def __call__(
751
759
decoder_segment_ids ,
752
760
decoder_positions ,
753
761
deterministic ,
754
- model_mode ,
762
+ self . model_mode ,
755
763
previous_chunk = previous_chunk ,
756
764
page_state = page_state ,
757
765
slot = slot ,
@@ -779,7 +787,7 @@ def __call__(
779
787
decoder_segment_ids ,
780
788
decoder_positions ,
781
789
deterministic ,
782
- model_mode ,
790
+ self . model_mode ,
783
791
previous_chunk = previous_chunk ,
784
792
page_state = page_state ,
785
793
slot = slot ,
@@ -791,7 +799,7 @@ def __call__(
791
799
# After the final transformer layer, `y` holds the raw, un-normalized hidden state.
792
800
hidden_state = y
793
801
794
- logits = self ._apply_output_head (hidden_state , deterministic , model_mode )
802
+ logits = self ._apply_output_head (hidden_state , deterministic , self . model_mode )
795
803
796
804
# The API of the Decoder is now a tuple, providing both the main output
797
805
# and the raw hidden state needed for auxiliary tasks.
@@ -830,7 +838,7 @@ def _apply_gemma3_scanned_blocks(
830
838
decoder_segment_ids ,
831
839
decoder_positions ,
832
840
deterministic ,
833
- model_mode ,
841
+ self . model_mode ,
834
842
)
835
843
y , _ = self .scan_decoder_layers (
836
844
cfg ,
@@ -839,7 +847,7 @@ def _apply_gemma3_scanned_blocks(
839
847
"layers" ,
840
848
mesh ,
841
849
in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
842
- model_mode = model_mode ,
850
+ model_mode = self . model_mode ,
843
851
** layer_kwargs ,
844
852
)(y , * broadcast_args , ** layer_call_kwargs )
845
853
@@ -856,7 +864,7 @@ def _apply_gemma3_scanned_blocks(
856
864
decoder_segment_ids ,
857
865
decoder_positions ,
858
866
deterministic ,
859
- model_mode ,
867
+ self . model_mode ,
860
868
previous_chunk = previous_chunk ,
861
869
page_state = page_state ,
862
870
slot = slot ,
0 commit comments