16 changes: 8 additions & 8 deletions docs/checkpoint.md
@@ -5,26 +5,26 @@ You may want to enable checkpointing in `torchtitan` for better fault tolerance
## A general guide to use checkpoints during training

1. ENABLE CHECKPOINTING
-In your `torchtitan` training config, ensure that `enable_checkpoint` is set to True.
+In your `torchtitan` training config, ensure that under `[checkpoint]`, `enable` is set to True.
```
[checkpoint]
-enable_checkpoint = true
+enable = true
folder = "checkpoint"
interval = 500
```
2. SAVE MODEL ONLY
By setting `last_save_model_only` to `True`, the checkpoint will only contain the model and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size.
```
[checkpoint]
-enable_checkpoint = true
+enable = true
last_save_model_only = true
```

3. CHOOSE DESIRED EXPORT PRECISION
The default model states are in `float32`. You can choose to export the checkpoint in a lower precision format such as `bfloat16`.
```
[checkpoint]
-enable_checkpoint = true
+enable = true
last_save_model_only = true
export_dtype = "bfloat16"
```
@@ -34,15 +34,15 @@ In some cases, you may want to partially load from a previously trained checkpoint
This parameter takes a list of strings that should be excluded from loading.
```
[checkpoint]
-enable_checkpoint = true
+enable = true
exclude_from_loading = ["data_loader", "lr_scheduler"]
```
When used in command line, the parameter should be a comma-separated list of strings. For example: `--checkpoint.exclude_from_loading data_loader,lr_scheduler`.

5. EXAMPLE CHECKPOINT CONFIGURATION
```
[checkpoint]
-enable_checkpoint = true
+enable = true
folder = "checkpoint"
interval = 10
load_step = 5
@@ -60,7 +60,7 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
To create a seed checkpoint, use the same model config as you use for training.
e.g.
```bash
-NGPU=1 CONFIG_FILE=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
+NGPU=1 CONFIG_FILE=<path_to_model_config> ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
```

## Conversion support
@@ -86,7 +86,7 @@ This guide will walk you through the steps required to convert a checkpoint from
1. CHECKPOINT CONFIGURATION
```
[checkpoint]
-enable_checkpoint = true
+enable = true
folder = "checkpoint"
interval = 10
last_save_model_only = true
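The rest of the conversion guide is collapsed in this view. As a rough companion, here is a minimal sketch (not part of the diff) of the Hugging Face save/reload round trip that `tests/integration_tests.py` exercises below with the renamed flag; the NGPU value, dump folder, and step number are illustrative only.

```bash
# Sketch only: save the final checkpoint in HF format, then reload it model-only.
# Flags mirror the overrides used in tests/integration_tests.py; paths are placeholders.
NGPU=8 CONFIG_FILE=<path_to_model_config> ./run_train.sh \
  --checkpoint.enable \
  --checkpoint.folder hf_checkpoint \
  --checkpoint.last_save_model_only \
  --checkpoint.last_save_in_hf

NGPU=8 CONFIG_FILE=<path_to_model_config> ./run_train.sh \
  --checkpoint.enable \
  --checkpoint.initial_load_path <dump_folder>/hf_checkpoint/step-10/ \
  --checkpoint.initial_load_model_only \
  --checkpoint.initial_load_in_hf
```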
2 changes: 1 addition & 1 deletion docs/debugging.md
@@ -100,7 +100,7 @@ For multiple experimental runs with different parallelism configs, we need to us


```bash
NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
```

**Note**: Using a seed checkpoint will only make sure a model has the same initial weights when configs change, but the training process may not be the same even after setting the seed and the `deterministic` mode, e.g. due to tensor shape change, data precision change, usage of randomness in model code, etc.
2 changes: 1 addition & 1 deletion docs/evaluation.md
@@ -9,7 +9,7 @@ Below is an example validation config:

```toml
[validation]
-enabled = true
+enable = true
dataset = "c4_validation"
freq = 500
steps = -1 # consumes the entire validation set
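As a quick cross-reference with the renamed flag, here is a minimal sketch (not part of the diff) of enabling validation from the command line, using torchtitan's `--<section>.<field>` override pattern. `--validation.enable` and `--validation.dataset` appear in the integration tests below; the `--validation.freq` override is assumed from the config field of the same name.

```bash
# Sketch only: CLI overrides roughly equivalent to the [validation] TOML above.
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --validation.enable \
  --validation.dataset c4_validation \
  --validation.freq 500
```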
26 changes: 13 additions & 13 deletions tests/integration_tests.py
@@ -108,10 +108,10 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
],
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--training.steps 20",
],
],
@@ -121,13 +121,13 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--checkpoint.folder hf_checkpoint",
"--checkpoint.last_save_model_only",
"--checkpoint.last_save_in_hf",
],
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--checkpoint.initial_load_path artifacts-to-be-uploaded/model_only_hf_checkpoint/hf_checkpoint/step-10/",
"--checkpoint.initial_load_model_only",
"--checkpoint.initial_load_in_hf",
@@ -139,7 +139,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--checkpoint.last_save_model_only",
],
],
@@ -149,7 +149,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--checkpoint.last_save_model_only",
"--checkpoint.export_dtype bfloat16",
],
@@ -244,14 +244,14 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--parallelism.pipeline_parallel_degree 2",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
],
[
"--training.steps 20",
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--parallelism.pipeline_parallel_degree 2",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
@@ -443,7 +443,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--parallelism.tensor_parallel_degree=2",
"--parallelism.context_parallel_degree=2",
"--training.enable_cpu_offload",
@@ -474,7 +474,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
],
[
# placeholder for the generation script's generate step
@@ -497,13 +497,13 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--training.steps 10",
],
# Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be
# excluded during loading to avoid errors caused by mismatched dp_degree.
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
"--parallelism.tensor_parallel_degree 2",
"--training.steps 20",
@@ -542,7 +542,7 @@ def build_test_list():
OverrideDefinitions(
[
[
"--validation.enabled",
"--validation.enable",
"--validation.dataset c4_test",
"--parallelism.tensor_parallel_degree=2",
"--parallelism.context_parallel_degree=2",
2 changes: 1 addition & 1 deletion tests/integration_tests_ft.py
@@ -32,7 +32,7 @@ def build_test_list():
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
["--training.steps 10", "--checkpoint.enable_checkpoint"],
["--training.steps 10", "--checkpoint.enable"],
],
"Default TorchFT integration test",
"default_torchft",
4 changes: 2 additions & 2 deletions tests/unit_tests/test_checkpoint.py
@@ -83,7 +83,7 @@ class DummyJobConfig:
def __init__(self, job):
self.job = job
self.checkpoint = CheckpointConfig(
-enable_checkpoint=True,
+enable=True,
async_mode="disabled",
folder="",
interval=1,
@@ -114,7 +114,7 @@ def setUp(self):
self.ft_manager = DummyFTManager()

ckpt_cfg = CheckpointConfig(
-enable_checkpoint=True,
+enable=True,
async_mode="DISABLED",
folder="",
interval=1,
12 changes: 6 additions & 6 deletions torchtitan/components/checkpoint.py
@@ -186,7 +186,7 @@ def __init__(
base_folder: str = "",
ft_manager: FTManager | None = None,
) -> None:
-self.enable_checkpoint = checkpoint_config.enable_checkpoint
+self.enable = checkpoint_config.enable

self.ft_manager = (
ft_manager.manager if ft_manager and ft_manager.enabled else None
@@ -216,10 +216,10 @@ def load_state_dict(state_dict):

async_mode = checkpoint_config.async_mode.lower()
self.enable_staging = (
-self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
) or self.ft_manager

-if not self.enable_checkpoint and self.ft_manager is None:
+if not self.enable and self.ft_manager is None:
return

self.states = states
@@ -305,7 +305,7 @@ def __del__(self):
self.close()

def close(self):
if hasattr(self, "enable_checkpoint") and self.enable_checkpoint:
if hasattr(self, "enable") and self.enable:
if hasattr(self, "mp") and self.mp and self.mp.is_alive():
self.mp_queue_send.put(Terminate())
self.mp.join()
@@ -517,7 +517,7 @@ def load(self, step: int = -1) -> bool:
if self.ft_manager:
self._ft_load()

-if not self.enable_checkpoint:
+if not self.enable:
return False

model_only = False
@@ -739,7 +739,7 @@ def _save_last_step(self, curr_step: int) -> None:
)

def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
-if not self.enable_checkpoint:
+if not self.enable:
return False

if curr_step == 1 and self.enable_first_step_checkpoint:
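For context on the `enable_staging` condition above: staging is only active when checkpointing is enabled and `async_mode` is `ASYNC_WITH_PINNED_MEM` (or when TorchFT is in use). A minimal sketch (not part of the diff) of turning this on from the CLI, assuming the field accepts the lowercase enum name as its string value:

```bash
# Sketch only: enable checkpointing with async staging via pinned host memory.
# The async_mode value is assumed to be the lowercase name of AsyncMode.ASYNC_WITH_PINNED_MEM.
CONFIG_FILE=<path_to_model_config> ./run_train.sh \
  --checkpoint.enable \
  --checkpoint.folder checkpoint \
  --checkpoint.async_mode async_with_pinned_mem
```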
16 changes: 5 additions & 11 deletions torchtitan/components/validate.py
@@ -8,7 +8,6 @@

import torch
import torch.nn as nn
-from torch.distributed.fsdp import FSDPModule
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torchtitan.components.dataloader import BaseDataLoader
from torchtitan.components.loss import LossFunction
@@ -82,8 +81,8 @@ def validate(
step: int,
) -> None:
# Set model to eval mode
-model = model_parts[0]
-model.eval()
+for model in model_parts:
+    model.eval()

parallel_dims = self.parallel_dims

@@ -148,7 +147,7 @@ def validate(
with self.validation_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
-predictions = model(inputs)
+predictions = model_parts[0](inputs)
loss = self.loss_fn(predictions, labels)

accumulated_losses.append(loss.detach())
@@ -167,14 +166,9 @@

self.metrics_processor.log_validation(loss=global_avg_loss, step=step)

-# Reshard after run forward pass
-# This is to ensure the model weights are sharded the same way for checkpoint saving.
-for module in model.modules():
-    if isinstance(module, FSDPModule):
-        module.reshard()
-
# Set model back to train mode
-model.train()
+for model in model_parts:
+    model.train()


def build_validator(
6 changes: 3 additions & 3 deletions torchtitan/config/job_config.py
@@ -398,13 +398,13 @@ class Parallelism:

@dataclass
class Checkpoint:
-enable_checkpoint: bool = False
+enable: bool = False
[Inline review comment from a Contributor]: nice, I was thinking to do this as well.
"""Whether to enable checkpoint"""

folder: str = "checkpoint"
"""
The folder to store the checkpoints.
-When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
+When enable is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
"""

interval: int = 500
@@ -710,7 +710,7 @@ class Experimental:

@dataclass
class Validation:
-enabled: bool = False
+enable: bool = False
"""Enable validation to default run validation after each training loop"""

dataset: str = "c4_validation"
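Taken together, the dataclass changes above boil down to two flag renames. A minimal sketch (not part of the diff) of how CLI overrides read before and after this change, using torchtitan's `--<section>.<field>` pattern:

```bash
# Before this PR:
#   --checkpoint.enable_checkpoint   and   --validation.enabled
# After this PR:
#   --checkpoint.enable              and   --validation.enable
CONFIG_FILE=<path_to_model_config> ./run_train.sh \
  --checkpoint.enable \
  --validation.enable
```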
@@ -56,7 +56,7 @@ context_parallel_degree = 1
# expert_parallel_degree = 2 set in custom_args

[checkpoint]
-enable_checkpoint = false
+enable = false
folder = "checkpoint"
interval = 10
model_weights_only = false
2 changes: 1 addition & 1 deletion torchtitan/experiments/flux/inference/run_infer.sh
@@ -18,5 +18,5 @@ PYTORCH_ALLOC_CONF="expandable_segments:True" \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-m torchtitan.experiments.flux.inference.infer --job.config_file ${CONFIG_FILE} \
---checkpoint.enable_checkpoint \
+--checkpoint.enable \
--checkpoint.exclude_from_loading=lr_scheduler,dataloader,optimizer "$@"
8 changes: 4 additions & 4 deletions torchtitan/experiments/flux/tests/integration_tests.py
@@ -44,10 +44,10 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
],
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--training.steps 20",
],
],
@@ -57,15 +57,15 @@ def build_test_list():
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.enable",
"--checkpoint.last_save_model_only",
],
],
"Checkpoint Integration Test - Save Model Only fp32",
"last_save_model_only_fp32",
),
OverrideDefinitions(
[["--validation.enabled"]], "Flux Validation Test", "validation"
[["--validation.enable"]], "Flux Validation Test", "validation"
),
# Parallelism tests.
OverrideDefinitions(