-
Notifications
You must be signed in to change notification settings - Fork 198
avoid save_to_memory running too fast then causing save_to_storage to fail #1649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -85,6 +85,11 @@ class CheckpointEvent: | |
| global_shard_num: int = 0 | ||
|
|
||
|
|
||
| @dataclass | ||
| class CheckpointNotifyEvent: | ||
| step: int = 0 | ||
|
|
||
|
|
||
| @dataclass | ||
| class TensorMeta: | ||
| shape: Tuple[int] = None # type: ignore | ||
|
|
@@ -451,7 +456,11 @@ def __init__( | |
| self._latest_step = 0 | ||
| qname = CheckpointSharedObjPrefix.SAVE_STEP_QNAME + str(0) | ||
| self._event_queue = SharedQueue(name=qname, create=True) | ||
| self._notify_queues = [] | ||
| for i in range(self.local_shard_num): | ||
| self._notify_queues.append( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use a function to generate queue name |
||
| SharedQueue(name=CheckpointSharedObjPrefix.SAVE_STEP_QNAME + "_notify_" + str(i), | ||
| create=True)) | ||
| self._shm_handlers.append(SharedMemoryHandler(i)) | ||
| lock_name = CheckpointSharedObjPrefix.SHM_LOCK_NAME + str(i) | ||
| self._shm_locks.append(SharedLock(name=lock_name, create=True)) | ||
|
|
@@ -707,6 +716,14 @@ def _save_shard( | |
| return False | ||
| finally: | ||
| shm_lock.release() | ||
| try: | ||
| self._notify_queues[local_shard_id].put( | ||
| CheckpointNotifyEvent(step=step), block=False | ||
| ) | ||
| except Exception as e: | ||
| logger.warning( | ||
| f"Skip notify for shard {local_shard_id}, step {step} due to: {e}" | ||
| ) | ||
|
|
||
| def _dist_make_dir(self, path, timeout=30): | ||
| if self._node_rank == 0: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -89,6 +89,10 @@ def save_to_memory(self, step, state_dict, paths): | |
| ["model_states", "optim_states"] of the state dict and | ||
| the value is the path of storage to save. | ||
| """ | ||
| if self._checkpoint_event_step > 0: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be basic impl from parent class. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also i don't get the point why not using 'last_step' but use '_checkpoint_event_step' instead? |
||
| notify_event = self._notify_queue.get() | ||
| assert notify_event.step == self._checkpoint_event_step | ||
| self._checkpoint_event_step = -1 | ||
| conf = CheckpointConfig(step=step, paths=paths) | ||
| success = self.save_state_dict_to_memory(state_dict, conf) | ||
| return success | ||
|
|
@@ -120,7 +124,9 @@ def save_to_storage(self, step, state_dict, paths): | |
| if self._local_rank == 0 and success: | ||
| event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step) | ||
| self._event_queue.put(event) | ||
| # All ranks should expect a notify to drain their local shard queue | ||
| if success: | ||
| self._checkpoint_event_step = step | ||
| self.latest_step = step | ||
| return success | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -469,7 +469,7 @@ def get_saving_ranks(self): | |
| return None | ||
|
|
||
| @timer | ||
| def save_to_memory(self, step, state_dict, paths: Dict[str, str]): | ||
| def _save_to_memory(self, step, state_dict, paths: Dict[str, str]): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should not be modified, as it is inherently a public abstract method. |
||
| """ | ||
| Synchronously Saves the state dict into the shared memory with the main | ||
| process. If the agent in the main process is saving the shared memory | ||
|
|
@@ -487,6 +487,11 @@ def save_to_memory(self, step, state_dict, paths: Dict[str, str]): | |
| if self._local_rank != self.local_shard_id: | ||
| return False | ||
|
|
||
| if self._checkpoint_event_step > 0: | ||
| notify_event = self._notify_queue.get() | ||
| assert notify_event.step == self._checkpoint_event_step | ||
| self._checkpoint_event_step = -1 | ||
|
|
||
| acquired = self._shm_lock.acquire(blocking=False) | ||
| all_rank_ready = check_all_rank_ready(self._saver_group, acquired) | ||
| if not all_rank_ready: | ||
|
|
@@ -549,6 +554,7 @@ def save_to_storage(self, step, state_dict, paths: Dict[str, str]): | |
| event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step) | ||
| self._event_queue.put(event) | ||
| if success: | ||
| self._checkpoint_event_step = step | ||
| self.latest_step = step | ||
|
|
||
| def get_saver_class(self): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not used?