17 changes: 17 additions & 0 deletions dlrover/python/elastic_agent/torch/ckpt_saver.py
@@ -85,6 +85,11 @@ class CheckpointEvent:
global_shard_num: int = 0


@dataclass
class CheckpointNotifyEvent:
Collaborator: This is not used?

step: int = 0


@dataclass
class TensorMeta:
shape: Tuple[int] = None # type: ignore
@@ -451,7 +456,11 @@ def __init__(
self._latest_step = 0
qname = CheckpointSharedObjPrefix.SAVE_STEP_QNAME + str(0)
self._event_queue = SharedQueue(name=qname, create=True)
self._notify_queues = []
for i in range(self.local_shard_num):
self._notify_queues.append(
Collaborator: Use a function to generate the queue name. (A sketch of such a helper follows this hunk.)

SharedQueue(
    name=CheckpointSharedObjPrefix.SAVE_STEP_QNAME + "_notify_" + str(i),
    create=True,
)
)
self._shm_handlers.append(SharedMemoryHandler(i))
lock_name = CheckpointSharedObjPrefix.SHM_LOCK_NAME + str(i)
self._shm_locks.append(SharedLock(name=lock_name, create=True))
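
A minimal sketch of the helper the reviewer asks for above. The name `notify_queue_name` is hypothetical and not part of this PR; the sketch assumes it lives in ckpt_saver.py next to CheckpointSharedObjPrefix so both the saver and the engines can share it:

```python
def notify_queue_name(local_shard_id: int) -> str:
    """Build the name of the per-shard notify queue.

    Keeping the naming in one place prevents the saver (create=True)
    and the engines (create=False) from drifting apart.
    """
    return (
        CheckpointSharedObjPrefix.SAVE_STEP_QNAME
        + "_notify_"
        + str(local_shard_id)
    )
```

The loop above would then call SharedQueue(name=notify_queue_name(i), create=True), and engine.py would attach with notify_queue_name(self.local_shard_id) and create=False.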
@@ -707,6 +716,14 @@ def _save_shard(
return False
finally:
shm_lock.release()
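# Best-effort notification: the put is non-blocking and a failure only logs a warning.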
try:
self._notify_queues[local_shard_id].put(
CheckpointNotifyEvent(step=step), block=False
)
except Exception as e:
logger.warning(
f"Skip notify for shard {local_shard_id}, step {step} due to: {e}"
)

def _dist_make_dir(self, path, timeout=30):
if self._node_rank == 0:
40 changes: 40 additions & 0 deletions dlrover/trainer/tests/torch/checkpoint_egine_test.py
@@ -362,6 +362,46 @@ def test_sync_group(self):
finally:
dist.destroy_process_group()

def test_fast_save_memory(self):
engines = [
FullCheckpointEngine,
DeepSpeedCheckpointEngine,
]
for engine in engines:
self._test_fast_save_memory(engine)

def _test_fast_save_memory(self, engine_class):
model = SimpleNet()
state_dict = dict(
model=model.state_dict(),
step=100,
)
storage = PosixDiskStorage()
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint_engine = engine_class(tmpdir, storage)
tmp = Path(tmpdir)
saved_file = tmp / "checkpoint-100/checkpoint.pt"
sd = {CheckpointConstant.MODEL_STATES_NAME: state_dict}
paths = {CheckpointConstant.MODEL_STATES_NAME: saved_file}
checkpoint_engine.save_to_storage(100, sd, paths)

# Simulate quick save_to_memory after save_to_storage.
# save_to_memory will wait for the async saving to complete,
# so no need to sleep here.
checkpoint_engine.save_to_memory(101, sd, paths)

# Check the tracker file and the checkpoint; the step should
# still be 100, which was stored by save_to_storage.
tracker_file = tmp / CheckpointConstant.TRACER_FILE_NAME
self.assertTrue(storage.exists(tracker_file))
self.assertEqual(tracker_file.read_text(), "100")
state = torch.load(saved_file)
self.assertEqual(state["step"], 100)

saver: AsyncCheckpointSaver = AsyncCheckpointSaver.get_ckpt_saver()
saver.close()
checkpoint_engine.close()


class PosixDiskStorageTest(unittest.TestCase):
def setUp(self):
29 changes: 29 additions & 0 deletions dlrover/trainer/tests/torch/fsdp_ckpt_test.py
@@ -419,3 +419,32 @@ def test_fsdp_checkpointer(self):
self.assertListEqual(files, [".metadata", "__0_0.distcp"])
reader = checkpointer._engine.load(path)
self.assertTrue(isinstance(reader, SharedMemoryReader))

def test_fast_save_memory(self):
state_dict = {"step": 100}
storage = PosixDiskStorage()
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
path = tmpdir / str(100)
paths = {CheckpointConstant.MODEL_STATES_NAME: path}
engine = FsdpCheckpointEngine(tmpdir, storage)
engine.save_to_storage(100, state_dict, paths=paths)
self.assertEqual(engine._cached_step, 100)

# Simulate quick save_to_memory after save_to_storage.
# save_to_memory will wait for the async saving to complete,
# so no need to sleep here.
engine.save_to_memory(101, state_dict, paths)
self.assertEqual(engine._cached_step, 101)

# Check if the files are created correctly.
self.assertTrue(storage.exists(tmpdir / "._dlrover_ckpt_stage"))
self.assertTrue(storage.exists(tmpdir / "100/__0_0.distcp"))
# Check the tracker file; the step should still be 100,
# which was stored by save_to_storage.
tracker_file = tmpdir / CheckpointConstant.TRACER_FILE_NAME
self.assertTrue(storage.exists(tracker_file))
self.assertEqual(tracker_file.read_text(), "100")
reader = engine.load(path)
self.assertTrue(isinstance(reader, SharedMemoryReader))
6 changes: 6 additions & 0 deletions dlrover/trainer/torch/flash_checkpoint/deepspeed_engine.py
@@ -89,6 +89,10 @@ def save_to_memory(self, step, state_dict, paths):
["model_states", "optim_states"] of the state dict and
the value is the path of storage to save.
"""
if self._checkpoint_event_step > 0:
Collaborator: This should be a basic implementation in the parent class.

Collaborator: Also, I don't get why this uses '_checkpoint_event_step' rather than the existing 'last_step'. (A sketch of the suggested base-class helper follows this hunk.)

notify_event = self._notify_queue.get()
assert notify_event.step == self._checkpoint_event_step
self._checkpoint_event_step = -1
conf = CheckpointConfig(step=step, paths=paths)
success = self.save_state_dict_to_memory(state_dict, conf)
return success
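
A hedged sketch of the refactor both comments point at: a single drain helper on the base engine class in engine.py (the name `_wait_saver_notify` is hypothetical), so the DeepSpeed, full, and FSDP engines do not each repeat the get-and-assert block:

```python
def _wait_saver_notify(self):
    """Block until the saver confirms the last pending save_to_storage step.

    No-op when nothing is pending, so plain save_to_memory calls stay cheap.
    """
    if self._checkpoint_event_step <= 0:
        return
    notify_event = self._notify_queue.get()
    assert notify_event.step == self._checkpoint_event_step
    self._checkpoint_event_step = -1
```

Each subclass's save_to_memory would then start with self._wait_saver_notify(); whether the pending step could reuse the existing latest_step bookkeeping instead of a new _checkpoint_event_step attribute is the question raised in the second comment.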
@@ -120,7 +124,9 @@ def save_to_storage(self, step, state_dict, paths):
if self._local_rank == 0 and success:
event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
self._event_queue.put(event)
# All ranks should expect a notify event so they can drain their local shard queue.
if success:
self._checkpoint_event_step = step
self.latest_step = step
return success

3 changes: 3 additions & 0 deletions dlrover/trainer/torch/flash_checkpoint/engine.py
@@ -206,11 +206,14 @@ def __init__(
)
else:
self._event_queue = None # type: ignore
self._checkpoint_event_step = -1
self._update_saver_config()

# lock for shared memory
local_shard_num = self.get_local_shard_num()
self.local_shard_id = self._local_rank % local_shard_num
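# Attach to the per-shard notify queue created by the agent's saver
# (create=False here, create=True on the saver side).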
self._notify_queue = SharedQueue(
    name=CheckpointSharedObjPrefix.SAVE_STEP_QNAME
    + "_notify_"
    + str(self.local_shard_id),
    create=False,
)
lock_name = CheckpointSharedObjPrefix.SHM_LOCK_NAME + str(
self.local_shard_id
)
8 changes: 7 additions & 1 deletion dlrover/trainer/torch/flash_checkpoint/fsdp_engine.py
@@ -469,7 +469,7 @@ def get_saving_ranks(self):
return None

@timer
def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
def _save_to_memory(self, step, state_dict, paths: Dict[str, str]):
Collaborator: This should not be modified, as it is inherently a public abstract method.

"""
Synchronously saves the state dict into the shared memory with the main
process. If the agent in the main process is saving the shared memory
@@ -487,6 +487,11 @@ def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
if self._local_rank != self.local_shard_id:
return False

if self._checkpoint_event_step > 0:
notify_event = self._notify_queue.get()
assert notify_event.step == self._checkpoint_event_step
self._checkpoint_event_step = -1

acquired = self._shm_lock.acquire(blocking=False)
all_rank_ready = check_all_rank_ready(self._saver_group, acquired)
if not all_rank_ready:
@@ -549,6 +554,7 @@ def save_to_storage(self, step, state_dict, paths: Dict[str, str]):
event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
self._event_queue.put(event)
if success:
self._checkpoint_event_step = step
self.latest_step = step

def get_saver_class(self):
6 changes: 6 additions & 0 deletions dlrover/trainer/torch/flash_checkpoint/full_ckpt_engine.py
@@ -112,6 +112,10 @@ def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
["model_states", "optim_states"] of the state dict and
the value is the path of storage to save.
"""
if self._checkpoint_event_step > 0:
notify_event = self._notify_queue.get()
assert notify_event.step == self._checkpoint_event_step
self._checkpoint_event_step = -1
conf = CheckpointConfig(step=step, paths=paths)
return self.save_state_dict_to_memory(state_dict, conf)

@@ -140,7 +144,9 @@ def save_to_storage(self, step, state_dict, paths):
if success and self._rank == 0:
event = CheckpointEvent(type=CheckpointEventType.SAVE, step=step)
self._event_queue.put(event)
# All ranks should expect a notify event so they can drain their local shard queue.
if success:
self._checkpoint_event_step = step
self.latest_step = step
return success
