
Commit 19bd2df

changed `enable_checkpoint` -> `enable`
1 parent 39c8249 commit 19bd2df

28 files changed: +64 -64 lines

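In user-facing terms, this commit renames the key that turns checkpointing on: a training TOML now sets `enable` under `[checkpoint]` instead of `enable_checkpoint`, and the matching CLI override becomes `--checkpoint.enable`. A minimal sketch of the new form (the `folder` and `interval` values are just the defaults that appear in the docs and `job_config.py` below):

```toml
[checkpoint]
enable = true          # previously: enable_checkpoint = true
folder = "checkpoint"
interval = 500
```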

docs/checkpoint.md

Lines changed: 8 additions & 8 deletions
@@ -5,26 +5,26 @@ You may want to enable checkpointing in `torchtitan` for better fault tolerance
## A general guide to use checkpoints during training

1. ENABLE CHECKPOINTING
-In your `torchtitan` training config, ensure that `enable_checkpoint` is set to True.
+In your `torchtitan` training config, ensure that under `[checkpoint]`, `enable` is set to True.
```
[checkpoint]
-enable_checkpoint = true
+enable = true
folder = "checkpoint"
interval = 500
```
2. SAVE MODEL ONLY
By setting `last_save_model_only` to `True`, the checkpoint will only contain the model and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size.
```
[checkpoint]
-enable_checkpoint = true
+enable = true
last_save_model_only = true
```

3. CHOOSE DESIRED EXPORT PRECISION
The default model states are in `float32`. You can choose to export the checkpoint in a lower precision format such as `bfloat16`.
```
[checkpoint]
-enable_checkpoint = true
+enable = true
last_save_model_only = true
export_dtype = "bfloat16"
```
@@ -34,15 +34,15 @@ In some cases, you may want to partially load from a previous-trained checkpoint
This parameter takes a list of string that should be excluded from loading.
```
[checkpoint]
-enable_checkpoint = true
+enable = true
exclude_from_loading = ["data_loader", "lr_scheduler"]
```
When used in command line, the parameter should be a comma-separated list of strings. For example: `--checkpoint.exclude_from_loading data_loader,lr_scheduler`.

5. EXAMPLE CHECKPOINT CONFIGURATION
```
[checkpoint]
-enable_checkpoint = true
+enable = true
folder = "checkpoint"
interval = 10
load_step = 5
@@ -60,7 +60,7 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
To create a seed checkpoint, use the same model config as you use for training.
e.g.
```bash
-NGPU=1 CONFIG_FILE=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
+NGPU=1 CONFIG_FILE=<path_to_model_config> ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
```

## Conversion support
@@ -86,7 +86,7 @@ This guide will walk you through the steps required to convert a checkpoint from
1. CHECKPOINT CONFIGURATION
```
[checkpoint]
-enable_checkpoint = true
+enable = true
folder = "checkpoint"
interval = 10
last_save_model_only = true

docs/debugging.md

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ For multiple experimental runs with different parallelism configs, we need to us


```bash
-NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
+NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --checkpoint.enable --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
```

**Note**: Using a seed checkpoint will only make sure a model has same initial weights when configs change, but the training process may not be the same even after setting the seed and the `deterministic` mode, e.g. due to tensor shape change, data precision change, usage of randomness in model code, etc.

tests/integration_tests.py

Lines changed: 12 additions & 12 deletions
@@ -108,10 +108,10 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
],
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--training.steps 20",
],
],
@@ -121,13 +121,13 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--checkpoint.folder hf_checkpoint",
"--checkpoint.last_save_model_only",
"--checkpoint.last_save_in_hf",
],
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--checkpoint.initial_load_path artifacts-to-be-uploaded/model_only_hf_checkpoint/hf_checkpoint/step-10/",
"--checkpoint.initial_load_model_only",
"--checkpoint.initial_load_in_hf",
@@ -139,7 +139,7 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--checkpoint.last_save_model_only",
],
],
@@ -149,7 +149,7 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--checkpoint.last_save_model_only",
"--checkpoint.export_dtype bfloat16",
],
@@ -244,14 +244,14 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--parallelism.pipeline_parallel_degree 2",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
],
[
"--training.steps 20",
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--parallelism.pipeline_parallel_degree 2",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
@@ -443,7 +443,7 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--parallelism.tensor_parallel_degree=2",
"--parallelism.context_parallel_degree=2",
"--training.enable_cpu_offload",
@@ -474,7 +474,7 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
],
[
# placeholder for the generation script's generate step
@@ -497,13 +497,13 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--training.steps 10",
],
# Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be
# excluded during loading to avoid errors caused by mismatched dp_degree.
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
"--parallelism.tensor_parallel_degree 2",
"--training.steps 20",

tests/integration_tests_ft.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def build_test_list():
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
-["--training.steps 10", "--checkpoint.enable_checkpoint"],
+["--training.steps 10", "--checkpoint.enable"],
],
"Default TorchFT integration test",
"default_torchft",

tests/unit_tests/test_checkpoint.py

Lines changed: 2 additions & 2 deletions
@@ -83,7 +83,7 @@ class DummyJobConfig:
def __init__(self, job):
self.job = job
self.checkpoint = CheckpointConfig(
-enable_checkpoint=True,
+enable=True,
async_mode="disabled",
folder="",
interval=1,
@@ -114,7 +114,7 @@ def setUp(self):
self.ft_manager = DummyFTManager()

ckpt_cfg = CheckpointConfig(
-enable_checkpoint=True,
+enable=True,
async_mode="DISABLED",
folder="",
interval=1,

torchtitan/components/checkpoint.py

Lines changed: 6 additions & 6 deletions
@@ -186,7 +186,7 @@ def __init__(
base_folder: str = "",
ft_manager: FTManager | None = None,
) -> None:
-self.enable_checkpoint = checkpoint_config.enable_checkpoint
+self.enable = checkpoint_config.enable

self.ft_manager = (
ft_manager.manager if ft_manager and ft_manager.enabled else None
@@ -216,10 +216,10 @@ def load_state_dict(state_dict):

async_mode = checkpoint_config.async_mode.lower()
self.enable_staging = (
-self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+self.enable and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
) or self.ft_manager

-if not self.enable_checkpoint and self.ft_manager is None:
+if not self.enable and self.ft_manager is None:
return

self.states = states
@@ -305,7 +305,7 @@ def __del__(self):
self.close()

def close(self):
-if hasattr(self, "enable_checkpoint") and self.enable_checkpoint:
+if hasattr(self, "enable") and self.enable:
if hasattr(self, "mp") and self.mp and self.mp.is_alive():
self.mp_queue_send.put(Terminate())
self.mp.join()
@@ -517,7 +517,7 @@ def load(self, step: int = -1) -> bool:
if self.ft_manager:
self._ft_load()

-if not self.enable_checkpoint:
+if not self.enable:
return False

model_only = False
@@ -739,7 +739,7 @@ def _save_last_step(self, curr_step: int) -> None:
)

def _should_save(self, curr_step: int, last_step: bool = False) -> bool:
-if not self.enable_checkpoint:
+if not self.enable:
return False

if curr_step == 1 and self.enable_first_step_checkpoint:

torchtitan/config/job_config.py

Lines changed: 2 additions & 2 deletions
@@ -398,13 +398,13 @@ class Parallelism:

@dataclass
class Checkpoint:
-enable_checkpoint: bool = False
+enable: bool = False
"""Whether to enable checkpoint"""

folder: str = "checkpoint"
"""
The folder to store the checkpoints.
-When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
+When enable is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
"""

interval: int = 500
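The field rename is mirrored wherever the config is consumed, e.g. `self.enable = checkpoint_config.enable` and the `if not self.enable` guards in `CheckpointManager` above. A stand-alone sketch of that pattern (it does not import torchtitan; the dataclass mirrors only the three fields shown in this hunk, and the interval arithmetic is illustrative rather than the exact `_should_save` logic):

```python
from dataclasses import dataclass


@dataclass
class Checkpoint:
    enable: bool = False        # renamed from `enable_checkpoint` in this commit
    folder: str = "checkpoint"
    interval: int = 500


def should_save(cfg: Checkpoint, curr_step: int) -> bool:
    # Mirrors the guard pattern: bail out early when checkpointing is disabled,
    # otherwise save on every `interval`-th step (illustrative simplification).
    if not cfg.enable:
        return False
    return curr_step % cfg.interval == 0


if __name__ == "__main__":
    print(should_save(Checkpoint(enable=True, interval=10), 20))  # True
    print(should_save(Checkpoint(), 20))                          # False: disabled by default
```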

torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ context_parallel_degree = 1
# expert_parallel_degree = 2 set in custom_args

[checkpoint]
-enable_checkpoint = false
+enable = false
folder = "checkpoint"
interval = 10
model_weights_only = false

torchtitan/experiments/flux/inference/run_infer.sh

Lines changed: 1 addition & 1 deletion
@@ -18,5 +18,5 @@ PYTORCH_ALLOC_CONF="expandable_segments:True" \
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
-m torchtitan.experiments.flux.inference.infer --job.config_file ${CONFIG_FILE} \
- --checkpoint.enable_checkpoint \
+ --checkpoint.enable \
--checkpoint.exclude_from_loading=lr_scheduler,dataloader,optimizer "$@"

torchtitan/experiments/flux/tests/integration_tests.py

Lines changed: 3 additions & 3 deletions
@@ -44,10 +44,10 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
],
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--training.steps 20",
],
],
@@ -57,7 +57,7 @@ def build_test_list():
OverrideDefinitions(
[
[
-"--checkpoint.enable_checkpoint",
+"--checkpoint.enable",
"--checkpoint.last_save_model_only",
],
],
