
Commit e9ac6f4

saumishr authored and facebook-github-bot committed
[DCP][OSS] Rank local checkpointing in DCP without collectives (pytorch#147758)
Summary:
X-link: pytorch/tnt#991

Context: DCP metadata collectives become prohibitively expensive as the job scale grows. This PR introduces rank-local checkpointing, which saves and loads the checkpoint without any collective. The trade-off for now is that dedupe and re-sharding are not supported; support for both will be introduced soon.

Test Plan: E2E UTs.

Save and load test with internal DCP components:
https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-textray-pretrain_mlm-lv5d7qcfmnqzkd

Save and load test with OSS DCP components:
https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-textray-pretrain_mlm-z1vz46vkkgtcld
https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-textray-pretrain_mlm-njvvbn07rv5ckd

Reviewed By: meetv18

Differential Revision: D70112642
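
For context, a minimal usage sketch of the new rank-local mode (illustrative only, not code from this commit: it assumes the `no_dist` and `use_collectives` flags are exposed on the top-level `save`/`load` entry points the same way they are threaded through the executor signatures in this diff, and the checkpoint path is hypothetical):

    import torch
    import torch.distributed.checkpoint as dcp
    from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter

    CKPT_DIR = "/tmp/rank_local_ckpt"  # hypothetical path

    state_dict = {"weight": torch.randn(4, 4)}

    # Save without any collective: each rank writes its own data files and its own
    # __{rank}.metadata file; no dedupe or re-sharding across ranks is performed yet.
    dcp.save(
        state_dict,
        storage_writer=FileSystemWriter(CKPT_DIR),
        no_dist=True,
        use_collectives=False,
    )

    # Load reads this rank's metadata and shards back, again without collectives.
    dcp.load(
        state_dict,
        storage_reader=FileSystemReader(CKPT_DIR),
        no_dist=True,
        use_collectives=False,
    )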
1 parent 8d1cf52 commit e9ac6f4

9 files changed: +207 −32 lines

test/distributed/checkpoint/test_checkpoint.py

Lines changed: 4 additions & 2 deletions
@@ -2,7 +2,7 @@
 
 import os
 import sys
-from typing import cast, Optional, Union
+from typing import Any, cast, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -170,7 +170,9 @@ def __init__(self, fail_conf):
     def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
         return
 
-    def set_up_storage_writer(self, is_coordinator: bool) -> None:
+    def set_up_storage_writer(
+        self, is_coordinator: bool, *args: Any, **kwargs: Any
+    ) -> None:
         self._fail_rank("fail_set_up_storage_writer")
 
     def prepare_local_plan(self, plan: SavePlan) -> SavePlan:

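The signature change above matters for any custom StorageWriter: `set_up_storage_writer` now also receives `*args`/`**kwargs`, and the rank-local path passes `rank` and `use_collectives` through them. A minimal sketch of how a third-party writer might pick these up (the class is hypothetical and its remaining abstract methods are omitted; only the kwarg names mirror the FileSystemWriter change later in this diff):

    from typing import Any, Optional

    from torch.distributed.checkpoint.storage import StorageWriter


    class MyStorageWriter(StorageWriter):  # hypothetical; other abstract methods omitted
        def set_up_storage_writer(
            self, is_coordinator: bool, *args: Any, **kwargs: Any
        ) -> None:
            # When collectives are disabled, the caller is expected to pass its rank so
            # the writer can prefix its own data and metadata files.
            self.rank: Optional[int] = kwargs.get("rank", None)
            self.use_collectives: bool = kwargs.get("use_collectives", True)
            self.is_coordinator = is_coordinator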
torch/distributed/checkpoint/_async_executor.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@ def execute_save(
         storage_writer: Optional[StorageWriter] = None,
         planner: Optional[SavePlanner] = None,
         process_group: Optional[dist.ProcessGroup] = None,
+        no_dist: bool = False,
+        use_collectives: bool = True,
     ) -> Future:
         """
         Execute the checkpoint save request asynchronously.

torch/distributed/checkpoint/_async_process_executor.py

Lines changed: 20 additions & 0 deletions
@@ -44,6 +44,8 @@ class _AsyncCheckpointRequest:
     checkpoint_request_id: _CheckpointRequestIdentifier
     storage_writer: Optional[StorageWriter] = None
     planner: Optional[SavePlanner] = None
+    no_dist: bool = False
+    use_collectives: bool = True
 
 
 @dataclass(init=False)
@@ -150,6 +152,8 @@ def save(
         checkpoint_id: Union[str, os.PathLike, None] = None,
         storage_writer: Optional[StorageWriter] = None,
         planner: Optional[SavePlanner] = None,
+        no_dist: bool = False,
+        use_collectives: bool = True,
     ) -> Metadata:
         # Create a unique identifier to locate requests/responses
         # from the checkpoint daemon process.
@@ -159,6 +163,8 @@ def save(
             checkpoint_request_id=checkpoint_request_id,
             storage_writer=storage_writer,
             planner=planner,
+            no_dist=no_dist,
+            use_collectives=use_collectives,
         )
         self._send(async_cp_request)
         result = self._wait_for_response()
@@ -172,6 +178,8 @@ def _execute_save(
         checkpoint_request_id: _CheckpointRequestIdentifier,
         storage_writer: Optional[StorageWriter] = None,
         planner: Optional[SavePlanner] = None,
+        no_dist: bool = False,
+        use_collectives: bool = True,
     ) -> Metadata:
         from torch.distributed.checkpoint.state_dict_saver import save
 
@@ -180,6 +188,8 @@ def _execute_save(
             checkpoint_id=checkpoint_request_id.checkpoint_id,
             storage_writer=storage_writer,
             planner=planner,
+            no_dist=no_dist,
+            use_collectives=use_collectives,
         )
         return metadata
 
@@ -239,6 +249,8 @@ def _checkpointing_subprocess(
                     checkpoint_request_id=obj.checkpoint_request_id,
                     storage_writer=obj.storage_writer,
                     planner=obj.planner,
+                    no_dist=obj.no_dist,
+                    use_collectives=obj.use_collectives,
                 )
                 parent_conn.send(response)
                 logger.info(
@@ -272,6 +284,8 @@ def _execute_save_impl(
         storage_writer: Optional[StorageWriter] = None,
         planner: Optional[SavePlanner] = None,
         process_group: Optional[dist.ProcessGroup] = None,
+        no_dist: bool = False,
+        use_collectives: bool = True,
     ) -> Metadata:
         global _CHECKPOINT_PROCESS
         if _CHECKPOINT_PROCESS is None:
@@ -299,6 +313,8 @@ def create_checkpoint_daemon_process() -> None:
             checkpoint_id=checkpoint_id,
             storage_writer=storage_writer,
             planner=planner,
+            no_dist=no_dist,
+            use_collectives=use_collectives,
        )
 
     def execute_save(
@@ -309,6 +325,8 @@ def execute_save(
         storage_writer: Optional[StorageWriter] = None,
         planner: Optional[SavePlanner] = None,
         process_group: Optional[dist.ProcessGroup] = None,
+        no_dist: bool = False,
+        use_collectives: bool = True,
     ) -> Future:
         """
         NOTE:
@@ -339,6 +357,8 @@ def execute_save(
             checkpoint_id=checkpoint_id,
             storage_writer=storage_writer,
             planner=planner,
+            no_dist=no_dist,
+            use_collectives=use_collectives,
         )
         f.add_done_callback(lambda f: self._executor.shutdown(wait=False))

torch/distributed/checkpoint/_async_thread_executor.py

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,8 @@ def save_wrapper(
     storage_writer: Optional[StorageWriter] = None,
     planner: Optional[SavePlanner] = None,
     process_group: Optional[dist.ProcessGroup] = None,
+    no_dist: bool = False,
+    use_collectives: bool = True,
 ) -> Future:
     from torch.distributed.checkpoint.state_dict_saver import save
 
@@ -32,6 +34,8 @@ def save_wrapper(
         storage_writer=storage_writer,
         planner=planner,
         process_group=process_group,
+        no_dist=no_dist,
+        use_collectives=use_collectives,
     )
 
 
@@ -49,6 +53,8 @@ def execute_save(
         storage_writer: Optional[StorageWriter] = None,
         planner: Optional[SavePlanner] = None,
         process_group: Optional[dist.ProcessGroup] = None,
+        no_dist: bool = False,
+        use_collectives: bool = True,
     ) -> Future:
         f: Future = self._executor.submit(
             save_wrapper,
@@ -57,6 +63,8 @@ def execute_save(
             storage_writer=storage_writer,
             planner=planner,
             process_group=process_group,
+            no_dist=no_dist,
+            use_collectives=use_collectives,
         )
         f.add_done_callback(lambda f: self._executor.shutdown(wait=False))

torch/distributed/checkpoint/default_planner.py

Lines changed: 1 addition & 1 deletion
@@ -654,7 +654,7 @@ def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool:
 
             # Check whether combined chunk cover the whole tensor
             tensor_volume = reduce(operator.mul, value.size, 1)
-            if chunks_volume != tensor_volume:
+            if len(global_plan) > 1 and chunks_volume != tensor_volume:
                 logger.warning(
                     """
                     key:%s invalid fill tensor-volume:

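Why the extra `len(global_plan) > 1` guard: with collectives disabled, each rank validates a single rank-local SavePlan that only covers its own shard of a sharded tensor, so the chunk-volume check against the full tensor volume would warn spuriously. A small illustrative sketch with hypothetical sizes (not code from this commit):

    import operator
    from functools import reduce

    # Hypothetical sharded tensor: global size 4 x 1024, each of 4 ranks owns one 1 x 1024 slice.
    full_size = (4, 1024)
    tensor_volume = reduce(operator.mul, full_size, 1)  # 4096, as computed above

    # A single rank-local plan only covers that rank's slice ...
    chunks_volume = 1 * 1024  # 1024

    # ... so the coverage check is only meaningful when plans from all ranks are combined,
    # which is exactly what the len(global_plan) > 1 condition encodes.
    assert chunks_volume != tensor_volume
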
torch/distributed/checkpoint/filesystem.py

Lines changed: 60 additions & 16 deletions
@@ -620,32 +620,55 @@ def __init__(
         self.overwrite = overwrite
         self.transforms = _StorageWriterTransforms(_extensions)
         self.serialization_format = serialization_format
+        self.rank: Optional[int] = None
+        self.use_collectives: bool = True
 
     def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
         if checkpoint_id:
             self.path = self.fs.init_path(checkpoint_id)
         self.save_id = _generate_uuid()
 
-    def set_up_storage_writer(self, is_coordinator: bool) -> None:
-        pass
+    def set_up_storage_writer(
+        self, is_coordinator: bool, *args: Any, **kwargs: Any
+    ) -> None:
+        self.rank = kwargs.get("rank", None)
+        self.use_collectives = kwargs.get("use_collectives", True)
+
+    def _metadata_exists(self) -> bool:
+        if self.use_collectives:
+            # A global checkpoint metadata file
+            metadata_path = self._get_metadata_path(rank=None)
+        else:
+            # A rank 0 specific metadata file if every rank has written its own metadata
+            # Just looking for lowest rank metadata file is sufficient
+            metadata_path = self._get_metadata_path(rank=0)
+
+        return self.fs.exists(metadata_path)
 
     def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
         self.fs.mkdir(self.path)
-        if self.fs.exists(self.metadata_path):
+        if self._metadata_exists():
            if self.overwrite:
                warnings.warn(
-                    f"Detected an existing checkpoint in {self.metadata_path}, overwriting since {self.overwrite=}."
+                    f"Detected an existing checkpoint in {self.path}, overwriting since {self.overwrite=}."
                    " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to"
                    " maintain this functionality or False to raise when an existing checkpoint is found."
                )
            else:
                raise RuntimeError(f"Checkpoint already exists and {self.overwrite=}.")
 
+        if self.rank is not None and not self.use_collectives:
+            plan = dataclasses.replace(
+                plan, storage_data=_StoragePrefix(f"__{self.rank}_")
+            )
+
         return plan
 
     def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
         new_plans = [
             dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
+            if plan.storage_data is None
+            else plan
             for i, plan in enumerate(plans)
         ]
         return new_plans
@@ -737,8 +760,12 @@ def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
         metadata.storage_data = storage_md
 
         metadata.storage_meta = self.storage_meta()
-
-        tmp_path = cast(Path, self.fs.concat_path(self.path, f"{_metadata_fn}.tmp"))
+        tmp_filename = (
+            f"__{self.rank}{_metadata_fn}.tmp"
+            if not self.use_collectives and self.rank is not None
+            else f"{_metadata_fn}.tmp"
+        )
+        tmp_path = cast(Path, self.fs.concat_path(self.path, tmp_filename))
         with self.fs.create_stream(tmp_path, "wb") as metadata_file:
             pickle.dump(metadata, metadata_file)
             if self.sync_files:
@@ -748,17 +775,22 @@ def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
                     os.sync()
 
         # delete in-case other checkpoints were present.
-        if self.fs.exists(self.metadata_path):
-            self.fs.rm_file(self.metadata_path)
+        if not self.use_collectives and self.rank is not None:
+            metadata_path = self._get_metadata_path(self.rank)
+        else:
+            metadata_path = self._get_metadata_path()
 
-        self.fs.rename(tmp_path, self.metadata_path)
+        if self.fs.exists(metadata_path):
+            self.fs.rm_file(metadata_path)
+
+        self.fs.rename(tmp_path, metadata_path)
 
     def storage_meta(self) -> Optional[StorageMeta]:
         return StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id)
 
-    @property
-    def metadata_path(self) -> Union[str, os.PathLike]:
-        return cast(Path, self.fs.concat_path(self.path, _metadata_fn))
+    def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike:
+        filename = f"{_metadata_fn}" if rank is None else f"__{rank}{_metadata_fn}"
+        return cast(Path, self.fs.concat_path(self.path, filename))
 
     @property
     def checkpoint_id(self) -> Union[str, os.PathLike]:
@@ -810,6 +842,8 @@ def __init__(
         self.storage_data: dict[Any, Any] = {}
         self.load_id = _generate_uuid()
         self.transforms = _StorageReaderTransforms(_extension_registry)
+        self.rank = None
+        self.use_collectives = True
 
     def _slice_file(self, file, sinfo: _StorageInfo) -> IO[bytes]:
         return cast(IO[bytes], _create_file_view(file, sinfo.offset, sinfo.length))
@@ -879,9 +913,14 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
         fut.set_result(None)
         return fut
 
+    def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike:
+        filename = f"{_metadata_fn}" if rank is None else f"__{rank}{_metadata_fn}"
+        return cast(Path, self.fs.concat_path(self.path, filename))
+
     # Implementing the abstract function in StorageReader
-    def read_metadata(self) -> Metadata:
-        path = self.fs.concat_path(self.path, ".metadata")
+    def read_metadata(self, *args: Any, **kwargs: Any) -> Metadata:
+        rank = kwargs.get("rank", None)
+        path = self._get_metadata_path(rank)
         with self.fs.create_stream(path, "rb") as metadata_file:
             metadata = pickle.load(metadata_file)
 
@@ -891,8 +930,12 @@ def read_metadata(self) -> Metadata:
 
         return metadata
 
-    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
+    def set_up_storage_reader(
+        self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any
+    ) -> None:
         self.storage_data = metadata.storage_data
+        self.rank = kwargs.get("rank", None)
+        self.use_collectives = kwargs.get("use_collectives", True)
         assert self.storage_data is not None
 
     def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
@@ -923,7 +966,8 @@ class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
     * File creation is atomic
 
     The checkpoint consist of one file per write request plus
-    a `.metadata` file with the serialized metadata.
+    a global `.metadata` file with the serialized metadata if rank coordination is enabled.
+    a rank local `__{rank}.metadata` file with the serialized metadata if rank coordination is NOT enabled.
 
     """

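To summarize the naming scheme introduced above, a small sketch of the metadata filename logic (it mirrors `_get_metadata_path`; the two-rank directory listing in the comment is a hypothetical example, and the data-file names only indicate the `__{rank}_` storage prefix):

    from typing import Optional

    _metadata_fn = ".metadata"  # same constant name as in filesystem.py


    def metadata_filename(rank: Optional[int] = None) -> str:
        # rank=None -> ".metadata"      (single global file, collectives enabled)
        # rank=1    -> "__1.metadata"   (rank-local file, collectives disabled)
        return _metadata_fn if rank is None else f"__{rank}{_metadata_fn}"


    # Hypothetical layout of a two-rank checkpoint saved with use_collectives=False:
    #   checkpoint_dir/
    #     __0_<data files>   <- rank 0 writes with _StoragePrefix("__0_")
    #     __1_<data files>   <- rank 1 writes with _StoragePrefix("__1_")
    #     __0.metadata       <- rank 0 metadata
    #     __1.metadata       <- rank 1 metadata
    assert metadata_filename() == ".metadata"
    assert metadata_filename(1) == "__1.metadata"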