2 changes: 1 addition & 1 deletion MaxText/configs/base.yml
@@ -241,7 +241,7 @@ set_remat_policy_on_pipeline_iterations: True
set_remat_policy_on_layers_per_stage: False


-# Choose 'remat_policy' between 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
+# Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp',
# 'save_qkv_proj', 'qkv_proj_offloaded', 'custom', 'minimal_offloaded', 'save_out_proj' and 'full'.
# These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest)
remat_policy: 'full'
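The name-based policies listed in this comment work together with activation tags placed in the model code. Below is a minimal, self-contained JAX sketch of the mechanism, not MaxText's actual layers: the layer, shapes, and loss are illustrative, and only the activation tagged "context" is kept for the backward pass while everything else inside the rematted function is recomputed.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x, w1, w2):
  # Tag the first matmul's output so a names-based policy can refer to it.
  ctx = checkpoint_name(jnp.dot(x, w1), "context")
  return jnp.dot(jax.nn.relu(ctx), w2)

# Save only the activation named "context"; recompute the rest on the backward pass.
policy = jax.checkpoint_policies.save_only_these_names("context")
remat_layer = jax.checkpoint(layer, policy=policy)

def loss(x, w1, w2):
  return jnp.sum(remat_layer(x, w1, w2))

x, w1, w2 = jnp.ones((8, 16)), jnp.ones((16, 32)), jnp.ones((32, 4))
grads = jax.grad(loss, argnums=(1, 2))(x, w1, w2)

Policies that save more named activations sit at the faster, higher-HBM end of the trade-off described above; 'full' rematerializes everything and sits at the slowest, lowest-HBM end.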
2 changes: 1 addition & 1 deletion MaxText/configs/models/gpu/llama2_7b.yml
@@ -8,7 +8,7 @@ max_target_length: 4096
model_name: "llama2-7b"
enable_checkpointing: False
attention: "cudnn_flash_te"
remat_policy: "minimal_flash"
remat_policy: "minimal_with_context"
use_iota_embed: True
scan_layers: False
dataset_type: "synthetic"
2 changes: 1 addition & 1 deletion MaxText/configs/models/gpu/llama3_8b.yml
@@ -23,7 +23,7 @@ steps: 30
per_device_batch_size: 12
max_target_length: 8192
attention: "cudnn_flash_te"
remat_policy: "minimal_flash"
remat_policy: "minimal_with_context"
use_iota_embed: True
dataset_type: "synthetic"
reuse_example_batch: 1
2 changes: 1 addition & 1 deletion MaxText/configs/models/gpu/mixtral_8x7b.yml
@@ -23,7 +23,7 @@ steps: 30
per_device_batch_size: 12
max_target_length: 4096
attention: "cudnn_flash_te"
remat_policy: "minimal_flash"
remat_policy: "minimal_with_context"
use_iota_embed: True
dataset_type: "synthetic"
reuse_example_batch: 1
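These three GPU configs previously used 'minimal_flash', which is removed from decoders.py below; 'minimal_with_context' is its replacement. A rough sketch of how the two constructions differ, using the name list added in this PR; whether the saved activation sets match exactly depends on which tensors the model actually tags with checkpoint_name, which is not shown in this diff.

import jax

# Old construction ('minimal_flash', removed below): whatever the
# no-batch-dims dot policy saves, plus the attention output named "context".
minimal_flash = jax.checkpoint_policies.save_from_both_policies(
    jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
    jax.checkpoint_policies.save_only_these_names("context"),
)

# New construction ('minimal_with_context'): an explicit list of named
# activations, with "context" appended.
minimal_with_context = jax.checkpoint_policies.save_only_these_names(
    "query_proj", "value_proj", "key_proj", "qkv_proj", "out_proj",
    "mlpwi_0", "mlpwi_1", "mlpwi", "mlpwo", "context",
)

The name-based version keys off explicit checkpoint_name tags rather than matmul shapes, so the saved set no longer depends on how the underlying dot operations are expressed.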
54 changes: 41 additions & 13 deletions MaxText/layers/decoders.py
@@ -259,13 +259,34 @@ def setup(self):
config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy
)

+def minimal_policy(self, with_context=False):
+"""Helper for creating minimal checkpoint policies."""
+names = [
+"query_proj",
+"value_proj",
+"key_proj",
+"qkv_proj",
+"out_proj",
+"mlpwi_0",
+"mlpwi_1",
+"mlpwi",
+"mlpwo",
+]
+if with_context:
+names.append("context")
+return jax.checkpoint_policies.save_only_these_names(*names)
+
def get_remat_policy(self):
"""Get remat policy"""
policy = None
cfg = self.config
if cfg.remat_policy != "none":
if cfg.remat_policy == "minimal":
policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
if cfg.remat_policy == "minimal_with_context":
# save all
policy = self.minimal_policy(with_context=True)
elif cfg.remat_policy == "minimal":
# save all except context
policy = self.minimal_policy()
elif cfg.remat_policy == "save_dot_with_context_except_mlp":
policy = jax.checkpoint_policies.save_only_these_names(
"query_proj",
@@ -307,21 +328,30 @@ def get_remat_policy(self):
offload_dst="pinned_host",
)
elif cfg.remat_policy == "minimal_offloaded":
-policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host")
+# offload all named activations except the attention context
+policy = jax.checkpoint_policies.save_and_offload_only_these_names(
+names_which_can_be_saved=[],
+names_which_can_be_offloaded=[
+"query_proj",
+"value_proj",
+"key_proj",
+"qkv_proj",
+"out_proj",
+"mlpwi_0",
+"mlpwi_1",
+"mlpwi",
+"mlpwo",
+],
+offload_src="device",
+offload_dst="pinned_host",
+)
elif cfg.remat_policy == "custom":
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
names_which_can_be_saved=cfg.tensors_on_device,
names_which_can_be_offloaded=cfg.tensors_to_offload,
offload_src="device",
offload_dst="pinned_host",
)
elif cfg.remat_policy == "minimal_flash":
policy = jax.checkpoint_policies.save_from_both_policies(
jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
jax.checkpoint_policies.save_only_these_names(
"context",
),
)
elif cfg.remat_policy == "save_out_proj":
policy = jax.checkpoint_policies.save_only_these_names(
"out_proj",
@@ -742,9 +772,7 @@ def __call__(
# Iterate over the two layer groups (dense and MoE) and apply layer transformation
for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes):
for index in range(num_layers):
-y = layer(
-config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode
-)(
+y = layer(config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode)(
y,
decoder_segment_ids,
decoder_positions,
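For the reworked 'minimal_offloaded' branch, the policy now offloads the same named activations to pinned host memory rather than using offload_dot_with_no_batch_dims. A small sketch of the offload variant follows; the MLP function and shapes are illustrative, only one of the PR's tags ("mlpwo") is used, and actually running the gradient requires a backend with host-offload support.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Keep nothing in HBM by name; copy the tagged activation to pinned host
# memory during the forward pass and fetch it back for the backward pass.
offload_policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["mlpwo"],
    offload_src="device",
    offload_dst="pinned_host",
)

def mlp(x, wi, wo):
  h = jax.nn.relu(jnp.dot(x, wi))
  return checkpoint_name(jnp.dot(h, wo), "mlpwo")

remat_mlp = jax.checkpoint(mlp, policy=offload_policy)
loss = lambda x, wi, wo: jnp.sum(remat_mlp(x, wi, wo))
grads = jax.grad(loss, argnums=(1, 2))(
    jnp.ones((4, 8)), jnp.ones((8, 16)), jnp.ones((16, 8))
)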