Commit 7c653cb

docs(user-guide): document use_optimized_save and use_cached_ckpt_structure (#72)
1 parent dcb28da commit 7c653cb

File tree

1 file changed: 6 additions, 2 deletions

docs/user-guide.md

Lines changed: 6 additions & 2 deletions
````diff
@@ -92,6 +92,8 @@ auto_resume = wrap_trainer_and_auto_resume_with_mlflashpoint(
     # always_save_context=False,  # Optional, defaults to False
     # write_thread_count=1,  # Optional, defaults to 1
     # initial_write_buffer_size_bytes=DESIRED_NUM_BYTES,  # Optional, defaults to 16 GB
+    # use_optimized_save=True,  # Optional, defaults to True. Uses the optimized save method to reduce write time.
+    # use_cached_ckpt_structure=True,  # Optional, defaults to False. Caches the checkpoint structure after identifying 2 consecutive save plan structures that are equal.
 )
 ```
 
@@ -148,6 +150,7 @@ memory_storage_writer = MemoryStorageWriter(...)
 # Use it to instantiate the Save Strategy
 megatron_save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
     storage_writer=memory_storage_writer,
+    # use_cached_ckpt_structure=True,  # Optional, defaults to False. Caches the checkpoint structure after identifying 2 consecutive save plan structures that are equal.
 )
 ```
 
@@ -167,7 +170,7 @@ async_request = save_local_aware_megatron_checkpoint(
 
 !!! note
 
-Make sure to specify the checkpoint ID/path when saving based on the current step using:
+    Make sure to specify the checkpoint ID/path when saving based on the current step using:
     `CheckpointContainerId.create_child(base_container, CheckpointContainerId.format_version_container(current_step))`
     where `base_container` is the base path CheckpointContainerId used for all checkpoints for the current job, e.g. `"/tmp/mlf-checkpoints/job123"`.
 
@@ -229,7 +232,7 @@ Code: See the [`ml_flashpoint.adapter.pytorch`](https://github.com/google/ml-fla
 To use directly with PyTorch DCP, use the provided `StorageWriter` and `StorageReader` implementations.
 You can use whatever `Planner` implementations work for your use case, or resort to the defaults.
 
-If your per-rank checkpoint data exceeds the default buffer size (16 GB as of this writing), you can increase it using the optional `initial_buffer_size_bytes` parameter.
+If your per-rank checkpoint data exceeds the default buffer size (16 GB as of this writing), you can increase it using the optional `initial_buffer_size_bytes` parameter.
 
 #### Imports
 ```python
@@ -262,6 +265,7 @@ memory_storage_writer = MemoryStorageWriter(
         ckpt_obj_manager=checkpoint_object_manager,
         replication_manager=replication_manager,
         # initial_buffer_size_bytes=initial_write_buffer_size_bytes,  # Optional - increase for larger checkpoint sizes per rank
+        # use_optimized_save=True,  # Optional, defaults to True. Uses the optimized save method to reduce write time.
     ),
     mp_manager=torch_mp.Manager(),
 )
````

0 commit comments
