@@ -259,13 +259,34 @@ def setup(self):
259
259
config = self .config , mesh = self .mesh , layers = pipeline_stage_module , remat_policy = remat_policy
260
260
)
261
261
262
+ def minimal_policy (self , with_context = False ):
263
+ """Return a save_only_these_names checkpoint policy for the names listed below, plus "context" when with_context is true."""
264
+ names = [
265
+ "query_proj" ,
266
+ "value_proj" ,
267
+ "key_proj" ,
268
+ "qkv_proj" ,
269
+ "out_proj" ,
270
+ "mlpwi_0" ,
271
+ "mlpwi_1" ,
272
+ "mlpwi" ,
273
+ "mlpwo" ,
274
+ ]
275
+ if with_context :
276
+ names .append ("context" )
277
+ return jax .checkpoint_policies .save_only_these_names (* names )
278
+
262
279
def get_remat_policy (self ):
263
280
"""Get remat policy"""
264
281
policy = None
265
282
cfg = self .config
266
283
if cfg .remat_policy != "none" :
267
- if cfg .remat_policy == "minimal" :
268
- policy = jax .checkpoint_policies .checkpoint_dots_with_no_batch_dims
284
+ if cfg .remat_policy == "minimal_with_context" :
285
+ # save all minimal-policy names, including context
286
+ policy = self .minimal_policy (with_context = True )
287
+ elif cfg .remat_policy == "minimal" :
288
+ # save all except context
289
+ policy = self .minimal_policy ()
269
290
elif cfg .remat_policy == "save_dot_with_context_except_mlp" :
270
291
policy = jax .checkpoint_policies .save_only_these_names (
271
292
"query_proj" ,
@@ -307,21 +328,30 @@ def get_remat_policy(self):
307
328
offload_dst = "pinned_host" ,
308
329
)
309
330
elif cfg .remat_policy == "minimal_offloaded" :
310
- policy = jax .checkpoint_policies .offload_dot_with_no_batch_dims (offload_src = "device" , offload_dst = "pinned_host" )
331
+ # offload all except context
332
+ policy = jax .checkpoint_policies .save_and_offload_only_these_names (
333
+ names_which_can_be_saved = [],
334
+ names_which_can_be_offloaded = [
335
+ "query_proj" ,
336
+ "value_proj" ,
337
+ "key_proj" ,
338
+ "qkv_proj" ,
339
+ "out_proj" ,
340
+ "mlpwi_0" ,
341
+ "mlpwi_1" ,
342
+ "mlpwi" ,
343
+ "mlpwo" ,
344
+ ],
345
+ offload_src = "device" ,
346
+ offload_dst = "pinned_host" ,
347
+ )
311
348
elif cfg .remat_policy == "custom" :
312
349
policy = jax .checkpoint_policies .save_and_offload_only_these_names (
313
350
names_which_can_be_saved = cfg .tensors_on_device ,
314
351
names_which_can_be_offloaded = cfg .tensors_to_offload ,
315
352
offload_src = "device" ,
316
353
offload_dst = "pinned_host" ,
317
354
)
318
- elif cfg .remat_policy == "minimal_flash" :
319
- policy = jax .checkpoint_policies .save_from_both_policies (
320
- jax .checkpoint_policies .checkpoint_dots_with_no_batch_dims ,
321
- jax .checkpoint_policies .save_only_these_names (
322
- "context" ,
323
- ),
324
- )
325
355
elif cfg .remat_policy == "save_out_proj" :
326
356
policy = jax .checkpoint_policies .save_only_these_names (
327
357
"out_proj" ,
@@ -742,9 +772,7 @@ def __call__(
742
772
# Iterate over the two layer groups (dense and MoE) and apply layer transformation
743
773
for layer , num_layers , layer_prefix in zip (layers , num_layers_list , layer_prefixes ):
744
774
for index in range (num_layers ):
745
- y = layer (
746
- config = cfg , mesh = mesh , name = f"{ layer_prefix } _{ index } " , quant = self .quant , model_mode = self .model_mode
747
- )(
775
+ y = layer (config = cfg , mesh = mesh , name = f"{ layer_prefix } _{ index } " , quant = self .quant , model_mode = self .model_mode )(
748
776
y ,
749
777
decoder_segment_ids ,
750
778
decoder_positions ,
0 commit comments