@@ -264,8 +264,29 @@ def get_remat_policy(self):
264
264
policy = None
265
265
cfg = self .config
266
266
if cfg .remat_policy != "none" :
267
- if cfg .remat_policy == "minimal" :
268
- policy = jax .checkpoint_policies .checkpoint_dots_with_no_batch_dims
267
+ if cfg .remat_policy == "minimal_with_context" :
268
+ policy = jax .checkpoint_policies .save_only_these_names (
269
+ "query_proj" ,
270
+ "value_proj" ,
271
+ "key_proj" ,
272
+ "qkv_proj" ,
273
+ "context" ,
274
+ "out_proj" ,
275
+ "mlpwi_0" ,
276
+ "mlpwi_1" ,
277
+ "mlpwo" ,
278
+ )
279
+ elif cfg .remat_policy == "minimal" :
280
+ policy = jax .checkpoint_policies .save_only_these_names (
281
+ "query_proj" ,
282
+ "value_proj" ,
283
+ "key_proj" ,
284
+ "qkv_proj" ,
285
+ "out_proj" ,
286
+ "mlpwi_0" ,
287
+ "mlpwi_1" ,
288
+ "mlpwo" ,
289
+ )
269
290
elif cfg .remat_policy == "save_dot_with_context_except_mlp" :
270
291
policy = jax .checkpoint_policies .save_only_these_names (
271
292
"query_proj" ,
@@ -307,21 +328,28 @@ def get_remat_policy(self):
307
328
offload_dst = "pinned_host" ,
308
329
)
309
330
elif cfg .remat_policy == "minimal_offloaded" :
310
- policy = jax .checkpoint_policies .offload_dot_with_no_batch_dims (offload_src = "device" , offload_dst = "pinned_host" )
331
+ policy = jax .checkpoint_policies .save_and_offload_only_these_names (
332
+ names_which_can_be_saved = [],
333
+ names_which_can_be_offloaded = [
334
+ "query_proj" ,
335
+ "value_proj" ,
336
+ "key_proj" ,
337
+ "qkv_proj" ,
338
+ "out_proj" ,
339
+ "mlpwi_0" ,
340
+ "mlpwi_1" ,
341
+ "mlpwo" ,
342
+ ],
343
+ offload_src = "device" ,
344
+ offload_dst = "pinned_host" ,
345
+ )
311
346
elif cfg .remat_policy == "custom" :
312
347
policy = jax .checkpoint_policies .save_and_offload_only_these_names (
313
348
names_which_can_be_saved = cfg .tensors_on_device ,
314
349
names_which_can_be_offloaded = cfg .tensors_to_offload ,
315
350
offload_src = "device" ,
316
351
offload_dst = "pinned_host" ,
317
352
)
318
- elif cfg .remat_policy == "minimal_flash" :
319
- policy = jax .checkpoint_policies .save_from_both_policies (
320
- jax .checkpoint_policies .checkpoint_dots_with_no_batch_dims ,
321
- jax .checkpoint_policies .save_only_these_names (
322
- "context" ,
323
- ),
324
- )
325
353
elif cfg .remat_policy == "save_out_proj" :
326
354
policy = jax .checkpoint_policies .save_only_these_names (
327
355
"out_proj" ,
@@ -422,9 +450,7 @@ def get_norm_layer(self, num_features: int):
422
450
else :
423
451
raise ValueError (f"Incorrect decoder_block name { self .config .decoder_block .value = } " )
424
452
425
- def scan_decoder_layers (
426
- self , cfg , decoder_layer , length , metadata_axis_name , mesh , in_axes_tuple , model_mode , ** kwargs
427
- ):
453
+ def scan_decoder_layers (self , cfg , decoder_layer , length , metadata_axis_name , mesh , in_axes_tuple , model_mode , ** kwargs ):
428
454
"""scan decoder layers, calls `flax.linen.transforms.scan`"""
429
455
initializing = self .is_mutable_collection ("params" )
430
456
params_spec = cfg .param_scan_axis if initializing else ScanIn (cfg .param_scan_axis )
@@ -744,9 +770,7 @@ def __call__(
744
770
# Iterate over the two layer groups (dense and MoE) and apply layer transformation
745
771
for layer , num_layers , layer_prefix in zip (layers , num_layers_list , layer_prefixes ):
746
772
for index in range (num_layers ):
747
- y = layer (
748
- config = cfg , mesh = mesh , name = f"{ layer_prefix } _{ index } " , quant = self .quant , model_mode = self .model_mode
749
- )(
773
+ y = layer (config = cfg , mesh = mesh , name = f"{ layer_prefix } _{ index } " , quant = self .quant , model_mode = self .model_mode )(
750
774
y ,
751
775
decoder_segment_ids ,
752
776
decoder_positions ,
0 commit comments