
Commit 662a297

Schwidola0607 committed
remove ignore_missing_optim config from zero ds_config
Signed-off-by: Schwidola0607 <[email protected]>
1 parent 2930f2a commit 662a297
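
For context, the flag removed by this commit lived in the ZeRO section of the DeepSpeed config (the "zero_optimization" block of ds_config, backed by DeepSpeedZeroConfig below). A minimal sketch of the user-facing change; every value other than ignore_missing_optim_state is illustrative:

# Before this commit (illustrative ds_config; only "ignore_missing_optim_state"
# is the field being removed here):
ds_config = {
    "train_batch_size": 8,  # illustrative
    "zero_optimization": {
        "stage": 3,  # illustrative
        "ignore_missing_optim_state": True,  # removed by this commit
    },
}
# After this commit the knob no longer exists: the loaders below simply check
# whether <checkpoint_dir>/zero/optimizer_state.pt is present and warn if not.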

File tree

deepspeed/runtime/base_optimizer.py
deepspeed/runtime/engine.py
deepspeed/runtime/state_dict_factory.py
deepspeed/runtime/zero/config.py
deepspeed/runtime/zero/stage3.py
deepspeed/runtime/zero/stage_1_and_2.py

6 files changed (+35, -49 lines)


deepspeed/runtime/base_optimizer.py

Lines changed: 6 additions & 6 deletions
@@ -17,16 +17,16 @@ class DeepSpeedOptimizer(object):
 
 class ZeROOptimizer(DeepSpeedOptimizer):
 
-    def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str, ignore_missing_optim_state: bool = False) -> None:
+    def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
-        optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        if not ignore_missing_optim_state:
-            assert os.path.isfile(
-                optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-
+        optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
             optim_sd = torch.load(optim_state_path, weights_only=False)
             self._load_global_state(optim_sd)
         else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True
             optim_sd = {}
 
         tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
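
The net effect of the hunk above: callers no longer pass ignore_missing_optim_state; the loader derives it from whether optimizer_state.pt exists and degrades to a warning instead of an assertion. A simplified, standalone sketch of that pattern (not the actual ZeROOptimizer method, which goes on to restore the per-group state):

import logging
import os

import torch

logger = logging.getLogger(__name__)


def load_global_optim_state(checkpoint_dir: str) -> dict:
    # Mirror of the new control flow: presence of the file decides the behavior.
    optim_state_path = os.path.join(checkpoint_dir, "zero", "optimizer_state.pt")
    if os.path.isfile(optim_state_path):
        # Global optimizer state is available; load it exactly as before.
        return torch.load(optim_state_path, weights_only=False)
    # Missing state is no longer fatal: warn and continue with an empty dict.
    logger.warning(f"{optim_state_path} containing optimizer global state is missing!")
    return {}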

deepspeed/runtime/engine.py

Lines changed: 12 additions & 21 deletions
@@ -2952,24 +2952,16 @@ def load_checkpoint(self,
         if self._optimizer_has_ckpt_event_prologue():
             # Prepare for checkpoint load by ensuring all parameters are partitioned
             self.optimizer.checkpoint_event_prologue()
-
-        if not self.zero_ignore_missing_optim_state():
-            # Temporary skip this path for HF-based UCP
-            load_path, client_states = self._load_checkpoint(load_dir,
-                                                             tag,
-                                                             load_module_strict=load_module_strict,
-                                                             load_optimizer_states=load_optimizer_states,
-                                                             load_lr_scheduler_states=load_lr_scheduler_states,
-                                                             load_module_only=load_module_only,
-                                                             custom_load_fn=custom_load_fn)
-
-            load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
-
-        else:
-            # What should load_path and client_states be?
-            load_path, client_states = None, {}
-            load_zero_checkpoint = (self.zero_optimization() or self.bfloat16_enabled())
-
+
+        load_path, client_states = self._load_checkpoint(load_dir,
+                                                         tag,
+                                                         load_module_strict=load_module_strict,
+                                                         load_optimizer_states=load_optimizer_states,
+                                                         load_lr_scheduler_states=load_lr_scheduler_states,
+                                                         load_module_only=load_module_only,
+                                                         custom_load_fn=custom_load_fn)
+
+        load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
         if load_zero_checkpoint:
             if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint():
                 success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)

@@ -3009,7 +3001,7 @@ def _load_checkpoint(self,
                          custom_load_fn=None):
 
         from deepspeed.runtime.state_dict_factory import SDLoaderFactory
-        logger.info(f"Loading checkpoint from {load_dir} with tag {tag}")
+
         ckpt_list = self._get_all_ckpt_names(load_dir, tag)
         sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine)
 

@@ -3167,8 +3159,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
                                           load_from_fp32_weights=self.zero_load_from_fp32_weights(),
                                           checkpoint_folder=checkpoint_folder,
                                           load_serial=load_serial,
-                                          param_shapes=param_shapes,
-                                          ignore_missing_optim_state=self.zero_ignore_missing_optim_state())
+                                          param_shapes=param_shapes)
 
         if self.load_universal_checkpoint():
             logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}')
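
Because the zero_ignore_missing_optim_state() branch is gone, load_checkpoint now always routes through _load_checkpoint and needs no extra configuration. A hedged usage sketch; engine is assumed to be the DeepSpeedEngine returned by deepspeed.initialize, and the directory and tag below are placeholders, not values from this commit:

# 'engine' is assumed to come from deepspeed.initialize(...); path and tag are placeholders.
load_path, client_state = engine.load_checkpoint("./checkpoints",
                                                 tag="global_step1000",
                                                 load_optimizer_states=True,
                                                 load_module_only=False)
# If zero/optimizer_state.pt is absent from the checkpoint, the ZeRO loaders now
# log a warning and proceed with empty optimizer global state instead of asserting.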

deepspeed/runtime/state_dict_factory.py

Lines changed: 1 addition & 0 deletions
@@ -164,6 +164,7 @@ def set_module(self, sd, module):
         return sd
 
     def check_ckpt_list(self):
+        #logger.info(f'checkpoint file list: {self.ckpt_list}')
         assert len(self.ckpt_list) > 0
 
         sd = self.checkpoint_engine.load(self.ckpt_list[0], map_location=lambda storage, loc: storage)

deepspeed/runtime/zero/config.py

Lines changed: 1 addition & 6 deletions
@@ -272,7 +272,7 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     ignore_unused_parameters: bool = True
     """
     Unused parameters in modules may be unexpected in static networks, but
-    could be normal in dynamic networks. This controls whether or not training
+    could be normal in dynamic networks. This controls whether or not training
     should terminate with an error message when unused parameters are detected.
     This is set to ``True`` by default, which means unused parameters are
     ignored and training continues. Now is just used in stage 2.

@@ -345,11 +345,6 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     """
     Whether to log warnings from trace cache, such as invalidation events.
    """
-
-    ignore_missing_optim_state: bool = False
-    """
-    Ignore missing optimizer states when loading checkpoint
-    """
 
     # Validators
     @model_validator(mode="after")

deepspeed/runtime/zero/stage3.py

Lines changed: 11 additions & 11 deletions
@@ -2692,8 +2692,7 @@ def load_state_dict(self,
                         load_from_fp32_weights=False,
                         checkpoint_folder=None,
                         load_serial=None,
-                        param_shapes=None,
-                        ignore_missing_optim_state: bool = False):
+                        param_shapes=None):
         r"""Loading a ZeRO checkpoint
         Arguments:
             state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition.

@@ -2724,7 +2723,7 @@ def load_state_dict(self,
 
         if checkpoint_folder:
             self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
-                                            param_shapes, ignore_missing_optim_state=ignore_missing_optim_state)
+                                            param_shapes)
         else:
             self._rigid_load_state_dict(state_dict_list[dist.get_rank(group=self.dp_process_group)],
                                         load_optimizer_states=load_optimizer_states)

@@ -2746,19 +2745,20 @@ def load_state_dict(self,
         # self.persistent_parameters[0].all_gather(self.persistent_parameters) # this will be done in checkpoint_event_epilogue() so remove it to prevent double all_gather
 
     def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights,
-                                   param_shapes, ignore_missing_optim_state):
-        self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes, ignore_missing_optim_state)
+                                   param_shapes):
+        self.load_hp_checkpoint_state_from_checkpoint_dir_stage3(checkpoint_folder, param_shapes)
 
-    def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes, ignore_missing_optim_state):
+    def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, param_shapes):
         """ Load optimizer and model states from the checkpoint directory. """
         checkpoint_dir = os.path.join(checkpoint_dir, "zero")
         optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
-        if not ignore_missing_optim_state:
-            assert os.path.isfile(
-                optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
-
+        if os.path.isfile(optim_state_path):
+            ignore_missing_optim_state = False
             optim_sd = torch.load(optim_state_path, weights_only=False)
             self._load_global_state_stage3(optim_sd)
+        else:
+            logger.warning(f'{optim_state_path} containing optimizer global state is missing!')
+            ignore_missing_optim_state = True
 
         key_list = ["fp32", "exp_avg", "exp_avg_sq"]
 

@@ -2777,6 +2777,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir_stage3(self, checkpoint_dir, pa
             # Purge the swapped optimizer state, it was initialized to the freshly created model and not the checkpoint
             self.optimizer_swapper.purge_state()
 
+        if self.swap_optimizer:
             # Touch all parameters to synchronize all buffers
             timer_names = set()
             self._partition_all_parameters()

@@ -2813,7 +2814,6 @@ def load_hp_checkpoint_state(self, folder, key):
         local_rank = dist.get_local_rank()
 
         # Load tensors from files and reshape them to flat vectors
-
         loaded_state = torch.load(os.path.join(folder, f"{key}.pt"), weights_only=False)
         if isinstance(loaded_state, dict):
             loaded_checkpoint_state = loaded_state['param'].view(-1)

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 4 additions & 5 deletions
@@ -2309,15 +2309,14 @@ def load_state_dict(self,
                         load_from_fp32_weights=False,
                         checkpoint_folder=None,
                         load_serial=None,
-                        param_shapes=None,
-                        ignore_missing_optim_state: bool = False):
+                        param_shapes=None):
         if checkpoint_folder:
-            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights, ignore_missing_optim_state=ignore_missing_optim_state)
+            self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights)
         else:
            self._load_legacy_checkpoint(state_dict_list, load_optimizer_states, load_from_fp32_weights)
 
-    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights, ignore_missing_optim_state: bool = False):
-        self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder, ignore_missing_optim_state=ignore_missing_optim_state)
+    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
+        self.load_hp_checkpoint_state_from_checkpoint_dir("bit16_groups", checkpoint_folder)
 
     def _load_global_state(self, sd):
         self.loss_scaler = sd.get(LOSS_SCALER, self.loss_scaler)
