Commit 7561844

Update remat rules
1 parent 5d600de commit 7561844

File tree: 5 files changed, +44 -20 lines

MaxText/configs/base.yml

Lines changed: 1 addition & 1 deletion

@@ -240,7 +240,7 @@ set_remat_policy_on_pipeline_iterations: True
 set_remat_policy_on_layers_per_stage: False
 
 
-# Choose 'remat_policy' between 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
+# Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
 # 'save_qkv_proj', 'qkv_proj_offloaded', 'custom', 'minimal_offloaded', 'save_out_proj' and 'full'.
 # These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest)
 remat_policy: 'full'
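
Each policy name here selects a jax.checkpoint_policies rule built in MaxText/layers/decoders.py (see the decoders.py diff further down this page); 'minimal_with_context' is 'minimal' with the attention "context" activation added to the saved set. Below is a minimal sketch of the underlying mechanism, assuming only the public jax.checkpoint / jax.ad_checkpoint.checkpoint_name API and a made-up toy layer rather than MaxText's real modules:

# Illustrative sketch, not MaxText code: save_only_these_names keeps only the
# activations tagged with a matching checkpoint_name label and rematerializes
# everything else on the backward pass.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

policy = jax.checkpoint_policies.save_only_these_names("context", "out_proj")

def toy_layer(params, x):
  w_qkv, w_out = params
  qkv = checkpoint_name(x @ w_qkv, "qkv_proj")               # not listed -> recomputed
  context = checkpoint_name(jax.nn.softmax(qkv), "context")  # listed -> saved
  return checkpoint_name(context @ w_out, "out_proj")        # listed -> saved

remat_layer = jax.checkpoint(toy_layer, policy=policy)

def loss(params, x):
  return jnp.sum(remat_layer(params, x) ** 2)

x = jnp.ones((4, 8))
params = (jnp.ones((8, 8)), jnp.ones((8, 8)))
grads = jax.grad(loss)(params, x)  # residuals limited to the tagged "context"/"out_proj"

Anything not on a policy's name list is recomputed during the backward pass, which is the speed-versus-HBM trade-off the comment above refers to.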

MaxText/configs/models/gpu/llama2_7b.yml

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ max_target_length: 4096
 model_name: "llama2-7b"
 enable_checkpointing: False
 attention: "cudnn_flash_te"
-remat_policy: "minimal_flash"
+remat_policy: "minimal_with_context"
 use_iota_embed: True
 scan_layers: False
 dataset_type: "synthetic"

MaxText/configs/models/gpu/llama3_8b.yml

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ steps: 30
 per_device_batch_size: 12
 max_target_length: 8192
 attention: "cudnn_flash_te"
-remat_policy: "minimal_flash"
+remat_policy: "minimal_with_context"
 use_iota_embed: True
 dataset_type: "synthetic"
 reuse_example_batch: 1

MaxText/configs/models/gpu/mixtral_8x7b.yml

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ steps: 30
 per_device_batch_size: 12
 max_target_length: 4096
 attention: "cudnn_flash_te"
-remat_policy: "minimal_flash"
+remat_policy: "minimal_with_context"
 use_iota_embed: True
 dataset_type: "synthetic"
 reuse_example_batch: 1

MaxText/layers/decoders.py

Lines changed: 40 additions & 16 deletions

@@ -264,8 +264,29 @@ def get_remat_policy(self):
     policy = None
     cfg = self.config
     if cfg.remat_policy != "none":
-      if cfg.remat_policy == "minimal":
-        policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
+      if cfg.remat_policy == "minimal_with_context":
+        policy = jax.checkpoint_policies.save_only_these_names(
+            "query_proj",
+            "value_proj",
+            "key_proj",
+            "qkv_proj",
+            "context",
+            "out_proj",
+            "mlpwi_0",
+            "mlpwi_1",
+            "mlpwo",
+        )
+      elif cfg.remat_policy == "minimal":
+        policy = jax.checkpoint_policies.save_only_these_names(
+            "query_proj",
+            "value_proj",
+            "key_proj",
+            "qkv_proj",
+            "out_proj",
+            "mlpwi_0",
+            "mlpwi_1",
+            "mlpwo",
+        )
       elif cfg.remat_policy == "save_dot_with_context_except_mlp":
         policy = jax.checkpoint_policies.save_only_these_names(
             "query_proj",

@@ -307,21 +328,28 @@ def get_remat_policy(self):
             offload_dst="pinned_host",
         )
       elif cfg.remat_policy == "minimal_offloaded":
-        policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host")
+        policy = jax.checkpoint_policies.save_and_offload_only_these_names(
+            names_which_can_be_saved=[],
+            names_which_can_be_offloaded=[
+                "query_proj",
+                "value_proj",
+                "key_proj",
+                "qkv_proj",
+                "out_proj",
+                "mlpwi_0",
+                "mlpwi_1",
+                "mlpwo",
+            ],
+            offload_src="device",
+            offload_dst="pinned_host",
+        )
       elif cfg.remat_policy == "custom":
         policy = jax.checkpoint_policies.save_and_offload_only_these_names(
             names_which_can_be_saved=cfg.tensors_on_device,
             names_which_can_be_offloaded=cfg.tensors_to_offload,
             offload_src="device",
             offload_dst="pinned_host",
         )
-      elif cfg.remat_policy == "minimal_flash":
-        policy = jax.checkpoint_policies.save_from_both_policies(
-            jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
-            jax.checkpoint_policies.save_only_these_names(
-                "context",
-            ),
-        )
       elif cfg.remat_policy == "save_out_proj":
         policy = jax.checkpoint_policies.save_only_these_names(
             "out_proj",

@@ -422,9 +450,7 @@ def get_norm_layer(self, num_features: int):
     else:
       raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}")
 
-  def scan_decoder_layers(
-      self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, model_mode, **kwargs
-  ):
+  def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, model_mode, **kwargs):
     """scan decoder layers, calls `flax.linen.transforms.scan`"""
     initializing = self.is_mutable_collection("params")
     params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis)

@@ -744,9 +770,7 @@ def __call__(
     # Iterate over the two layer groups (dense and MoE) and apply layer transformation
     for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
      for index in range(num_layers):
-        y = layer(
-            config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
-        )(
+        y = layer(config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode)(
            y,
            decoder_segment_ids,
            decoder_positions,
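
In the 'minimal_offloaded' branch above, save_and_offload_only_these_names moves the named residuals to pinned host memory instead of keeping them in HBM. Below is a minimal sketch of that pattern, again using a made-up toy block rather than MaxText code, and assuming a TPU or GPU backend that supports the "pinned_host" memory space:

# Illustrative sketch, not MaxText code: "context" residuals stay on device,
# "mlpwo" residuals are copied to pinned host memory, and untagged values are
# rematerialized on the backward pass. Requires a backend with pinned_host support.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=["context"],
    names_which_can_be_offloaded=["mlpwo"],
    offload_src="device",
    offload_dst="pinned_host",
)

def toy_block(params, x):
  w_attn, w_mlp = params
  context = checkpoint_name(jax.nn.softmax(x @ w_attn), "context")  # saved in HBM
  return checkpoint_name(context @ w_mlp, "mlpwo")                  # offloaded to host

remat_block = jax.checkpoint(toy_block, policy=policy)

x = jnp.ones((4, 8))
params = (jnp.ones((8, 8)), jnp.ones((8, 8)))
grads = jax.grad(lambda p: jnp.sum(remat_block(p, x) ** 2))(params)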
