@@ -620,32 +620,55 @@ def __init__(
620
620
self .overwrite = overwrite
621
621
self .transforms = _StorageWriterTransforms (_extensions )
622
622
self .serialization_format = serialization_format
623
+ self .rank : Optional [int ] = None
624
+ self .use_collectives : bool = True
623
625
624
626
def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
    """Begin a new save session, optionally retargeting the checkpoint path.

    Args:
        checkpoint_id: When truthy, the filesystem path is re-initialized
            from this id; otherwise the current path is kept.
    """
    # Rotate the save id on every reset so distinct save attempts are
    # distinguishable in the recorded storage metadata.
    self.save_id = _generate_uuid()
    if checkpoint_id:
        self.path = self.fs.init_path(checkpoint_id)
def set_up_storage_writer(
    self, is_coordinator: bool, *args: Any, **kwargs: Any
) -> None:
    """Record optional per-rank configuration passed by the save coordinator.

    Recognized keyword arguments:
        rank: this process's rank, or ``None`` when not supplied.
        use_collectives: whether saving coordinates across ranks
            (defaults to ``True``, preserving the collective behavior).
    """
    self.rank = kwargs.get("rank")
    self.use_collectives = kwargs.get("use_collectives", True)
+ def _metadata_exists (self ) -> bool :
638
+ if self .use_collectives :
639
+ # A global checkpoint metadata file
640
+ metadata_path = self ._get_metadata_path (rank = None )
641
+ else :
642
+ # A rank 0 specific metadata file if every rank has written its own metadata
643
+ # Just looking for lowest rank metadata file is sufficient
644
+ metadata_path = self ._get_metadata_path (rank = 0 )
645
+
646
+ return self .fs .exists (metadata_path )
631
647
632
648
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
    """Create the target directory, guard against clobbering an existing
    checkpoint, and tag the plan with a rank prefix for non-collective saves.

    Raises:
        RuntimeError: if a checkpoint already exists and ``overwrite`` is False.
    """
    self.fs.mkdir(self.path)
    if self._metadata_exists():
        if not self.overwrite:
            raise RuntimeError(f"Checkpoint already exists and {self.overwrite =}.")
        warnings.warn(
            f"Detected an existing checkpoint in {self.path}, overwriting since {self.overwrite =}."
            " Past version 2.5 of PyTorch, `overwrite` will default to False. Set this variable to True to"
            " maintain this functionality or False to raise when an existing checkpoint is found."
        )

    # Without collectives each rank writes its files independently; prefix
    # them with the rank so concurrent writers do not collide.
    if self.rank is not None and not self.use_collectives:
        plan = dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{self.rank}_"))

    return plan
def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
    """Assign a unique per-plan file prefix to every plan that lacks one.

    Plans that already carry storage data (e.g. rank-local prefixes set in
    ``prepare_local_plan``) are passed through untouched.
    """
    out: list[SavePlan] = []
    for idx, candidate in enumerate(plans):
        if candidate.storage_data is None:
            candidate = dataclasses.replace(
                candidate, storage_data=_StoragePrefix(f"__{idx}_")
            )
        out.append(candidate)
    return out
@@ -737,8 +760,12 @@ def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
737
760
metadata .storage_data = storage_md
738
761
739
762
metadata .storage_meta = self .storage_meta ()
740
-
741
- tmp_path = cast (Path , self .fs .concat_path (self .path , f"{ _metadata_fn } .tmp" ))
763
+ tmp_filename = (
764
+ f"__{ self .rank } { _metadata_fn } .tmp"
765
+ if not self .use_collectives and self .rank is not None
766
+ else f"{ _metadata_fn } .tmp"
767
+ )
768
+ tmp_path = cast (Path , self .fs .concat_path (self .path , tmp_filename ))
742
769
with self .fs .create_stream (tmp_path , "wb" ) as metadata_file :
743
770
pickle .dump (metadata , metadata_file )
744
771
if self .sync_files :
@@ -748,17 +775,22 @@ def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
748
775
os .sync ()
749
776
750
777
# delete in-case other checkpoints were present.
751
- if self .fs .exists (self .metadata_path ):
752
- self .fs .rm_file (self .metadata_path )
778
+ if not self .use_collectives and self .rank is not None :
779
+ metadata_path = self ._get_metadata_path (self .rank )
780
+ else :
781
+ metadata_path = self ._get_metadata_path ()
753
782
754
- self .fs .rename (tmp_path , self .metadata_path )
783
+ if self .fs .exists (metadata_path ):
784
+ self .fs .rm_file (metadata_path )
785
+
786
+ self .fs .rename (tmp_path , metadata_path )
755
787
756
788
def storage_meta(self) -> Optional[StorageMeta]:
    """Describe this writer (checkpoint id and save id) for the metadata."""
    meta = StorageMeta(checkpoint_id=self.checkpoint_id, save_id=self.save_id)
    return meta
def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike:
    """Resolve the metadata file path inside the checkpoint directory.

    ``rank=None`` yields the global ``.metadata`` file; an explicit rank
    yields that rank's local ``__{rank}.metadata`` file.
    """
    if rank is None:
        filename = f"{_metadata_fn}"
    else:
        filename = f"__{rank}{_metadata_fn}"
    return cast(Path, self.fs.concat_path(self.path, filename))
@property
764
796
def checkpoint_id (self ) -> Union [str , os .PathLike ]:
@@ -810,6 +842,8 @@ def __init__(
810
842
self .storage_data : dict [Any , Any ] = {}
811
843
self .load_id = _generate_uuid ()
812
844
self .transforms = _StorageReaderTransforms (_extension_registry )
845
+ self .rank = None
846
+ self .use_collectives = True
813
847
814
848
def _slice_file(self, file, sinfo: _StorageInfo) -> IO[bytes]:
    """Return a byte-range view over ``file`` limited to ``sinfo``'s span."""
    view = _create_file_view(file, sinfo.offset, sinfo.length)
    return cast(IO[bytes], view)
@@ -879,9 +913,14 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
879
913
fut .set_result (None )
880
914
return fut
881
915
916
def _get_metadata_path(self, rank: Optional[int] = None) -> os.PathLike:
    """Path of the global ``.metadata`` file, or a rank-local
    ``__{rank}.metadata`` file when ``rank`` is given."""
    prefix = "" if rank is None else f"__{rank}"
    return cast(Path, self.fs.concat_path(self.path, f"{prefix}{_metadata_fn}"))
# Implementing the abstract function in StorageReader
883
- def read_metadata (self ) -> Metadata :
884
- path = self .fs .concat_path (self .path , ".metadata" )
921
+ def read_metadata (self , * args : Any , ** kwargs : Any ) -> Metadata :
922
+ rank = kwargs .get ("rank" , None )
923
+ path = self ._get_metadata_path (rank )
885
924
with self .fs .create_stream (path , "rb" ) as metadata_file :
886
925
metadata = pickle .load (metadata_file )
887
926
@@ -891,8 +930,12 @@ def read_metadata(self) -> Metadata:
891
930
892
931
return metadata
893
932
894
def set_up_storage_reader(
    self, metadata: Metadata, is_coordinator: bool, *args: Any, **kwargs: Any
) -> None:
    """Capture the checkpoint's storage layout plus optional per-rank hints.

    Recognized keyword arguments:
        rank: this process's rank, or ``None`` when not supplied.
        use_collectives: whether loading coordinates across ranks
            (defaults to ``True``).
    """
    self.storage_data = metadata.storage_data
    self.rank = kwargs.get("rank")
    self.use_collectives = kwargs.get("use_collectives", True)
    # A checkpoint written by this writer always records storage data.
    assert self.storage_data is not None
def prepare_local_plan (self , plan : LoadPlan ) -> LoadPlan :
@@ -923,7 +966,8 @@ class FileSystemWriter(_FileSystemWriter, BlockingAsyncStager):
923
966
* File creation is atomic
924
967
925
968
The checkpoint consists of one file per write request, plus
926
- a `.metadata` file with the serialized metadata.
969
+ a global `.metadata` file with the serialized metadata if rank coordination is enabled.
970
+ or a rank-local `__{rank}.metadata` file with the serialized metadata if rank coordination is NOT enabled.
927
971
928
972
"""
929
973
0 commit comments