
Commit 6426085

Update remat rules
1 parent e1a5788 commit 6426085

5 files changed: +45 -17 lines changed

MaxText/configs/base.yml

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ set_remat_policy_on_pipeline_iterations: True
 set_remat_policy_on_layers_per_stage: False


-# Choose 'remat_policy' between 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
+# Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
 # 'save_qkv_proj', 'qkv_proj_offloaded', 'custom', 'minimal_offloaded', 'save_out_proj' and 'full'.
 # These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest)
 remat_policy: 'full'
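For context (not part of this commit): a remat_policy ultimately selects a jax.checkpoint policy, which decides which tagged intermediates are kept in HBM for the backward pass and which are recomputed. Below is a minimal, hypothetical sketch of that trade-off; the toy layer and weight shapes are made up, while jax.checkpoint, jax.ad_checkpoint.checkpoint_name, jax.checkpoint_policies.save_only_these_names, and the tensor names "query_proj", "context", "out_proj" are real.

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    def toy_layer(x, wq, wo):
      # Tag intermediates with the names that name-based remat policies refer to.
      q = checkpoint_name(jnp.dot(x, wq), "query_proj")
      ctx = checkpoint_name(jax.nn.relu(q), "context")  # stand-in for the attention output
      return checkpoint_name(jnp.dot(ctx, wo), "out_proj")

    # 'full': recompute every intermediate on the backward pass (lowest HBM, slowest).
    full_remat = jax.checkpoint(toy_layer)

    # Name-based policy: keep the listed tensors in HBM, recompute the rest (faster, more HBM).
    minimal_with_context = jax.checkpoint(
        toy_layer,
        policy=jax.checkpoint_policies.save_only_these_names("query_proj", "out_proj", "context"),
    )

    x, wq, wo = jnp.ones((4, 8)), jnp.ones((8, 8)), jnp.ones((8, 8))
    grads = jax.grad(lambda x, wq, wo: minimal_with_context(x, wq, wo).sum(), argnums=(1, 2))(x, wq, wo)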

MaxText/configs/models/gpu/llama2_7b.yml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ max_target_length: 4096
 model_name: "llama2-7b"
 enable_checkpointing: False
 attention: "cudnn_flash_te"
-remat_policy: "minimal_flash"
+remat_policy: "minimal_with_context"
 use_iota_embed: True
 scan_layers: False
 dataset_type: "synthetic"

MaxText/configs/models/gpu/llama3_8b.yml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ steps: 30
 per_device_batch_size: 12
 max_target_length: 8192
 attention: "cudnn_flash_te"
-remat_policy: "minimal_flash"
+remat_policy: "minimal_with_context"
 use_iota_embed: True
 dataset_type: "synthetic"
 reuse_example_batch: 1

MaxText/configs/models/gpu/mixtral_8x7b.yml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ steps: 30
 per_device_batch_size: 12
 max_target_length: 4096
 attention: "cudnn_flash_te"
-remat_policy: "minimal_flash"
+remat_policy: "minimal_with_context"
 use_iota_embed: True
 dataset_type: "synthetic"
 reuse_example_batch: 1

MaxText/layers/decoders.py

Lines changed: 41 additions & 13 deletions
@@ -259,13 +259,34 @@ def setup(self):
           config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
       )

+  def minimal_policy(self, with_context=False):
+    """Helper for creating minimal checkpoint policies."""
+    names = [
+        "query_proj",
+        "value_proj",
+        "key_proj",
+        "qkv_proj",
+        "out_proj",
+        "mlpwi_0",
+        "mlpwi_1",
+        "mlpwi",
+        "mlpwo",
+    ]
+    if with_context:
+      names.append("context")
+    return jax.checkpoint_policies.save_only_these_names(*names)
+
   def get_remat_policy(self):
     """Get remat policy"""
     policy = None
     cfg = self.config
     if cfg.remat_policy != "none":
-      if cfg.remat_policy == "minimal":
-        policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
+      if cfg.remat_policy == "minimal_with_context":
+        # save all
+        policy = self.minimal_policy(with_context=True)
+      elif cfg.remat_policy == "minimal":
+        # save all except context
+        policy = self.minimal_policy()
       elif cfg.remat_policy == "save_dot_with_context_except_mlp":
         policy = jax.checkpoint_policies.save_only_these_names(
             "query_proj",
@@ -307,21 +328,30 @@ def get_remat_policy(self):
             offload_dst="pinned_host",
         )
       elif cfg.remat_policy == "minimal_offloaded":
-        policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host")
+        # offload all except context
+        policy = jax.checkpoint_policies.save_and_offload_only_these_names(
+            names_which_can_be_saved=[],
+            names_which_can_be_offloaded=[
+                "query_proj",
+                "value_proj",
+                "key_proj",
+                "qkv_proj",
+                "out_proj",
+                "mlpwi_0",
+                "mlpwi_1",
+                "mlpwi",
+                "mlpwo",
+            ],
+            offload_src="device",
+            offload_dst="pinned_host",
+        )
       elif cfg.remat_policy == "custom":
         policy = jax.checkpoint_policies.save_and_offload_only_these_names(
             names_which_can_be_saved=cfg.tensors_on_device,
            names_which_can_be_offloaded=cfg.tensors_to_offload,
             offload_src="device",
             offload_dst="pinned_host",
         )
-      elif cfg.remat_policy == "minimal_flash":
-        policy = jax.checkpoint_policies.save_from_both_policies(
-            jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
-            jax.checkpoint_policies.save_only_these_names(
-                "context",
-            ),
-        )
       elif cfg.remat_policy == "save_out_proj":
         policy = jax.checkpoint_policies.save_only_these_names(
             "out_proj",
@@ -742,9 +772,7 @@ def __call__(
       # Iterate over the two layer groups (dense and MoE) and apply layer transformation
       for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
        for index in range(num_layers):
-          y = layer(
-              config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
-          )(
+          y = layer(config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode)(
               y,
               decoder_segment_ids,
               decoder_positions,
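Usage note (not from this diff): a policy built by the new minimal_policy helper only takes effect once it is attached to a rematerialized module, and the names it lists only match tensors that were tagged with jax.ad_checkpoint.checkpoint_name inside the layer. A self-contained sketch under those assumptions; SimpleLayer and its sizes are made up, while flax.linen.remat, checkpoint_name, and save_only_these_names are real APIs.

    import flax.linen as nn
    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    class SimpleLayer(nn.Module):
      features: int = 16

      @nn.compact
      def __call__(self, x):
        # Tag the projection output so a name-based policy can choose to keep it.
        h = checkpoint_name(nn.Dense(self.features)(x), "out_proj")
        return nn.relu(h)

    policy = jax.checkpoint_policies.save_only_these_names("out_proj")
    RematLayer = nn.remat(SimpleLayer, policy=policy)

    x = jnp.ones((2, 16))
    variables = RematLayer(features=16).init(jax.random.PRNGKey(0), x)
    y = RematLayer(features=16).apply(variables, x)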

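For the offload variants (the minimal_offloaded branch above), the same name mechanism routes activations to host memory instead of keeping them in HBM. A hypothetical toy example of that API, not MaxText code; the function, tensor shapes, and the choice of "mlpwo" as the only offloaded name are illustrative, and actually running it with "pinned_host" as the destination assumes a backend with host-offload support.

    from functools import partial

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    offload_policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=[],             # keep nothing in HBM
        names_which_can_be_offloaded=["mlpwo"],  # move this activation to host memory
        offload_src="device",
        offload_dst="pinned_host",
    )

    @partial(jax.checkpoint, policy=offload_policy)
    def toy_mlp(x, w):
      return checkpoint_name(jnp.dot(jax.nn.relu(x), w), "mlpwo")

    x, w = jnp.ones((4, 8)), jnp.ones((8, 8))
    grads = jax.grad(lambda x, w: toy_mlp(x, w).sum(), argnums=1)(x, w)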