diff --git a/dlrover/python/elastic_agent/torch/ckpt_saver.py b/dlrover/python/elastic_agent/torch/ckpt_saver.py
index f5dc5065f..e30ba6cd5 100644
--- a/dlrover/python/elastic_agent/torch/ckpt_saver.py
+++ b/dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -85,6 +85,11 @@ class CheckpointEvent:
     global_shard_num: int = 0
 
 
+@dataclass
+class CheckpointNotifyEvent:
+    step: int = 0
+
+
 @dataclass
 class TensorMeta:
     shape: Tuple[int] = None  # type: ignore
@@ -451,7 +456,11 @@ def __init__(
         self._latest_step = 0
         qname = CheckpointSharedObjPrefix.SAVE_STEP_QNAME + str(0)
         self._event_queue = SharedQueue(name=qname, create=True)
+        self._notify_queues = []
         for i in range(self.local_shard_num):
+            self._notify_queues.append(
+                SharedQueue(name=CheckpointSharedObjPrefix.SAVE_STEP_QNAME + "_notify_" + str(i),
+                            create=True))
             self._shm_handlers.append(SharedMemoryHandler(i))
             lock_name = CheckpointSharedObjPrefix.SHM_LOCK_NAME + str(i)
             self._shm_locks.append(SharedLock(name=lock_name, create=True))
@@ -707,6 +716,14 @@ def _save_shard(
             return False
         finally:
             shm_lock.release()
+            try:
+                self._notify_queues[local_shard_id].put(
+                    CheckpointNotifyEvent(step=step), block=False
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Skip notify for shard {local_shard_id}, step {step} due to: {e}"
+                )
 
     def _dist_make_dir(self, path, timeout=30):
         if self._node_rank == 0:
diff --git a/dlrover/trainer/tests/torch/checkpoint_egine_test.py b/dlrover/trainer/tests/torch/checkpoint_egine_test.py
index 51c864216..47729379a 100644
--- a/dlrover/trainer/tests/torch/checkpoint_egine_test.py
+++ b/dlrover/trainer/tests/torch/checkpoint_egine_test.py
@@ -362,6 +362,46 @@ def test_sync_group(self):
         finally:
             dist.destroy_process_group()
 
+    def test_fast_save_memory(self):
+        engines = [
+            FullCheckpointEngine,
+            DeepSpeedCheckpointEngine,
+        ]
+        for engine in engines:
+            self._test_fast_save_memory(engine)
+
+    def _test_fast_save_memory(self, engine_class):
+        model = SimpleNet()
+        state_dict = dict(
+            model=model.state_dict(),
+            step=100,
+        )
+        storage = PosixDiskStorage()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            checkpoint_engine = engine_class(tmpdir, storage)
+            tmp = Path(tmpdir)
+            saved_file = tmp / "checkpoint-100/checkpoint.pt"
+            sd = {CheckpointConstant.MODEL_STATES_NAME: state_dict}
+            paths = {CheckpointConstant.MODEL_STATES_NAME: saved_file}
+            checkpoint_engine.save_to_storage(100, sd, paths)
+
+            # Simulate a quick save_to_memory right after save_to_storage.
+            # save_to_memory waits for the async saving to complete,
+            # so there is no need to sleep here.
+            checkpoint_engine.save_to_memory(101, sd, paths)
+
+            # Check the tracker file and the checkpoint; the step should
+            # be updated to 100, which is stored by save_to_storage.
+            tracker_file = tmp / CheckpointConstant.TRACER_FILE_NAME
+            self.assertTrue(storage.exists(tracker_file))
+            self.assertEqual(tracker_file.read_text(), "100")
+            state = torch.load(saved_file)
+            self.assertEqual(state["step"], 100)
+
+            saver: AsyncCheckpointSaver = AsyncCheckpointSaver.get_ckpt_saver()
+            saver.close()
+            checkpoint_engine.close()
+
 
 class PosixDiskStorageTest(unittest.TestCase):
     def setUp(self):
diff --git a/dlrover/trainer/tests/torch/fsdp_ckpt_test.py b/dlrover/trainer/tests/torch/fsdp_ckpt_test.py
index 878db2eb2..0ed49b6cd 100644
--- a/dlrover/trainer/tests/torch/fsdp_ckpt_test.py
+++ b/dlrover/trainer/tests/torch/fsdp_ckpt_test.py
@@ -419,3 +419,32 @@ def test_fsdp_checkpointer(self):
         self.assertListEqual(files, [".metadata", "__0_0.distcp"])
         reader = checkpointer._engine.load(path)
         self.assertTrue(isinstance(reader, SharedMemoryReader))
+
+    def test_fast_save_memory(self):
+        state_dict = {"step": 100}
+        storage = PosixDiskStorage()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmpdir = Path(tmpdir)
+            path = tmpdir / str(100)
+            paths = {CheckpointConstant.MODEL_STATES_NAME: path}
+            engine = FsdpCheckpointEngine(tmpdir, storage)
+            engine.save_to_storage(100, state_dict, paths=paths)
+            self.assertEqual(engine._cached_step, 100)
+
+            # Simulate a quick save_to_memory right after save_to_storage.
+            # save_to_memory waits for the async saving to complete,
+            # so there is no need to sleep here.
+            engine.save_to_memory(101, state_dict, paths)
+            self.assertEqual(engine._cached_step, 101)
+
+            # Check that the files are created correctly.
+            self.assertTrue(storage.exists(tmpdir / "._dlrover_ckpt_stage"))
+            self.assertTrue(storage.exists(tmpdir / "100/__0_0.distcp"))
+            # Check the tracker file; the step should be updated to 100,
+            # which is stored by save_to_storage.
+            tracker_file = tmpdir / CheckpointConstant.TRACER_FILE_NAME
+            self.assertTrue(storage.exists(tracker_file))
+            self.assertEqual(tracker_file.read_text(), "100")
+
+            reader = engine.load(path)
+            self.assertTrue(isinstance(reader, SharedMemoryReader))
diff --git a/dlrover/trainer/torch/flash_checkpoint/deepspeed_engine.py b/dlrover/trainer/torch/flash_checkpoint/deepspeed_engine.py
index 0bf96bf7d..d799c9f83 100644
--- a/dlrover/trainer/torch/flash_checkpoint/deepspeed_engine.py
+++ b/dlrover/trainer/torch/flash_checkpoint/deepspeed_engine.py
@@ -89,6 +89,10 @@ def save_to_memory(self, step, state_dict, paths):
                 ["model_states", "optim_states"] of the state dict and
                 the value is the path of storage to save.
         """
+        if self._checkpoint_event_step > 0:
+            notify_event = self._notify_queue.get()
+            assert notify_event.step == self._checkpoint_event_step
+            self._checkpoint_event_step = -1
         conf = CheckpointConfig(step=step, paths=paths)
         success = self.save_state_dict_to_memory(state_dict, conf)
         return success
@@ -120,7 +124,9 @@ def save_to_storage(self, step, state_dict, paths):
         if self._local_rank == 0 and success:
             event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
             self._event_queue.put(event)
+        # All ranks should expect a notify event to drain their local shard queue.
         if success:
+            self._checkpoint_event_step = step
             self.latest_step = step
         return success
 
diff --git a/dlrover/trainer/torch/flash_checkpoint/engine.py b/dlrover/trainer/torch/flash_checkpoint/engine.py
index 67bda319d..d47b8a9de 100644
--- a/dlrover/trainer/torch/flash_checkpoint/engine.py
+++ b/dlrover/trainer/torch/flash_checkpoint/engine.py
@@ -206,11 +206,14 @@ def __init__(
             )
         else:
             self._event_queue = None  # type: ignore
+        self._checkpoint_event_step = -1
         self._update_saver_config()
 
         # lock for shared memory
         local_shard_num = self.get_local_shard_num()
         self.local_shard_id = self._local_rank % local_shard_num
+        self._notify_queue = SharedQueue(name=CheckpointSharedObjPrefix.SAVE_STEP_QNAME
+                                         + "_notify_" + str(self.local_shard_id), create=False)
         lock_name = CheckpointSharedObjPrefix.SHM_LOCK_NAME + str(
             self.local_shard_id
         )
diff --git a/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py b/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py
index 8254af79d..158b1f521 100644
--- a/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py
+++ b/dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py
@@ -469,7 +469,7 @@ def get_saving_ranks(self):
         return None
 
     @timer
-    def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
+    def _save_to_memory(self, step, state_dict, paths: Dict[str, str]):
         """
         Synchronously Saves the state dict into the shared memory with the main
         process. If the agent in the main process is saving the shared memory
@@ -487,6 +487,11 @@
         if self._local_rank != self.local_shard_id:
             return False
 
+        if self._checkpoint_event_step > 0:
+            notify_event = self._notify_queue.get()
+            assert notify_event.step == self._checkpoint_event_step
+            self._checkpoint_event_step = -1
+
         acquired = self._shm_lock.acquire(blocking=False)
         all_rank_ready = check_all_rank_ready(self._saver_group, acquired)
         if not all_rank_ready:
@@ -549,6 +554,7 @@ def save_to_storage(self, step, state_dict, paths: Dict[str, str]):
             event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
             self._event_queue.put(event)
         if success:
+            self._checkpoint_event_step = step
             self.latest_step = step
 
     def get_saver_class(self):
diff --git a/dlrover/trainer/torch/flash_checkpoint/full_ckpt_engine.py b/dlrover/trainer/torch/flash_checkpoint/full_ckpt_engine.py
index 7ff9bafe3..e5865ec2a 100644
--- a/dlrover/trainer/torch/flash_checkpoint/full_ckpt_engine.py
+++ b/dlrover/trainer/torch/flash_checkpoint/full_ckpt_engine.py
@@ -112,6 +112,10 @@ def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
                 ["model_states", "optim_states"] of the state dict and
                 the value is the path of storage to save.
         """
+        if self._checkpoint_event_step > 0:
+            notify_event = self._notify_queue.get()
+            assert notify_event.step == self._checkpoint_event_step
+            self._checkpoint_event_step = -1
         conf = CheckpointConfig(step=step, paths=paths)
         return self.save_state_dict_to_memory(state_dict, conf)
 
@@ -140,7 +144,9 @@ def save_to_storage(self, step, state_dict, paths):
         if success and self._rank == 0:
            event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
            self._event_queue.put(event)
+        # All ranks should expect a notify event to drain their local shard queue.
         if success:
+            self._checkpoint_event_step = step
             self.latest_step = step
         return success