@@ -2692,8 +2692,7 @@ def load_state_dict(self,
                         load_from_fp32_weights=False,
                         checkpoint_folder=None,
                         load_serial=None,
-                        param_shapes=None,
-                        ignore_missing_optim_state: bool = False):
+                        param_shapes=None):
         r"""Loading a ZeRO checkpoint
         Arguments:
             state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.
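With `ignore_missing_optim_state` removed from the signature, callers no longer opt into tolerating a missing optimizer state; the loader decides for itself (see the later hunks). A hedged sketch of a call against the new signature, where `optimizer` stands for a ZeRO stage-3 optimizer instance and the argument values are illustrative:

```python
# Universal-checkpoint path: the ignore_missing_optim_state flag is gone.
# Whether optimizer state is restored is now inferred from the checkpoint
# folder contents rather than from a caller-supplied argument.
optimizer.load_state_dict(state_dict_list=None,
                          load_optimizer_states=True,
                          load_from_fp32_weights=False,
                          checkpoint_folder="/path/to/universal/checkpoint",
                          param_shapes=param_shapes)
```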
@@ -2724,7 +2723,7 @@ def load_state_dict(self,
 
         if checkpoint_folder:
             self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
-                                            param_shapes, ignore_missing_optim_state=ignore_missing_optim_state)
+                                            param_shapes)
         else:
             self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)],
                                         load_optimizer_states=load_optimizer_states)
@@ -2746,19 +2745,20 @@ def load_state_dict(self,
         # self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather
 
     def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
-                                   param_shapes, ignore_missing_optim_state):
-        self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes, ignore_missing_optim_state)
+                                   param_shapes):
+        self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)
 
-    def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes, ignore_missing_optim_state):
+    def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):
         """ Load optimizer and model states from the checkpoint directory. """
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        if not ignore_missing_optim_state:
-            assert os.path.isfile(
-                optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
             optim_sd = torch.load(optim_state_path, weights_only=False)
             self._load_global_state_stage3(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True
 
 
         key_list = ["fp32", "exp_avg", "exp_avg_sq"]
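The net effect of this hunk: a missing `optimizer_state.pt` downgrades from a hard assert to a warning, and `ignore_missing_optim_state` becomes a local derived from the file check instead of a caller-supplied argument. A minimal standalone sketch of that fallback pattern (the helper name is hypothetical and the `torch.load` call is stubbed out so the snippet runs on its own):

```python
import logging
import os

logger = logging.getLogger(__name__)

def load_global_optim_state(checkpoint_dir):
    # Hypothetical helper mirroring the fallback above: the presence of
    # optimizer_state.pt decides the behavior, not a caller-supplied flag.
    optim_state_path = os.path.join(checkpoint_dir, "zero", "optimizer_state.pt")
    if os.path.isfile(optim_state_path):
        # The real code calls torch.load(optim_state_path, weights_only=False)
        # and passes the result to self._load_global_state_stage3(...).
        return optim_state_path, False  # (state source, ignore_missing_optim_state)
    logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
    return None, True
```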
@@ -2777,6 +2777,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
             self.optimizer_swapper.purge_state()
 
+        if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
             timer_names = set()
             self._partition_all_parameters()
@@ -2813,7 +2814,6 @@ def load_hp_checkpoint_state(self, folder, key):
         local_rank = dist.get_local_rank()
 
         # Load tensors from files and reshape them to flat vectors
-
         loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False)
         if isinstance(loaded_state, dict):
             loaded_checkpoint_state = loaded_state['param'].view(-1)
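For context on the last hunk: `load_hp_checkpoint_state` reads one `{key}.pt` file per entry of `key_list` ("fp32", "exp_avg", "exp_avg_sq") and flattens it to a vector. A hedged sketch of that reshape step; the bare-tensor fallback branch is an assumption, since the diff only shows the dict case:

```python
import torch

def flatten_hp_state(loaded_state):
    # Dict layout as shown in the diff: the tensor lives under 'param'.
    if isinstance(loaded_state, dict):
        return loaded_state['param'].view(-1)
    # Assumed fallback for checkpoints that store the tensor directly.
    return loaded_state.view(-1)
```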