
Conversation

qijiale76 (Contributor)

What changes were proposed in this pull request?

This PR is based on #1585.

Previously, a save_to_memory call could overwrite a shared memory buffer while the background saver process was still reading from it for a save_to_storage operation. To prevent this, a notification queue was introduced in #1585, where the saver process notifies the training process once the buffer is free.

However, the implementation in #1585 only had rank 0 consume this notification. This caused all non-rank-0 processes to hang indefinitely on subsequent checkpoints, as they never consumed the event from their corresponding notification queue.

This patch fixes the issue by ensuring that every rank waits for and consumes the notification from its dedicated queue before proceeding with the next save_to_memory operation. This guarantees that the shared memory buffer is not written to prematurely.
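The mechanism can be sketched as follows. This is a minimal single-machine illustration using Python's `multiprocessing.Queue` as a stand-in for DLRover's `SharedQueue`; the class and method names (`RankCheckpointEngine`, `saver_done_reading`) are assumptions for illustration, not the actual implementation:

```python
import multiprocessing as mp


class RankCheckpointEngine:
    """Sketch: each rank owns a dedicated notify queue (stand-in for
    DLRover's SharedQueue) on which the background saver signals that
    the shared-memory buffer is free to be overwritten again."""

    def __init__(self, local_rank: int):
        self.local_rank = local_rank
        self.notify_queue = mp.Queue()
        self.buffer = None
        # The buffer starts out free, so the first save does not block.
        self.notify_queue.put(True)

    def save_to_memory(self, step, state_dict):
        # Every rank (not only rank 0) must consume the notification
        # before overwriting its buffer -- this is the essence of the fix.
        self.notify_queue.get()
        self.buffer = (step, state_dict)

    def saver_done_reading(self):
        # Called from the background saver once save_to_storage has
        # finished reading the buffer for this rank.
        self.notify_queue.put(True)
```

With the bug, only rank 0 ever called the equivalent of `notify_queue.get()`, so the other ranks' queues stayed empty forever once their first notification was never consumed, and their second `save_to_memory` blocked indefinitely.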

Why are the changes needed?

Without proper synchronization across all ranks, tasks would hang during checkpointing, especially when checkpoints were triggered in rapid succession. This fix ensures the reliability of the asynchronous saving process by correctly coordinating the main training processes and the background saver.

Does this PR introduce any user-facing change?

No

How was this patch tested?

Verified in live training jobs. The race condition is resolved and the task no longer hangs.

@codecov

codecov bot commented Sep 26, 2025

Codecov Report

❌ Patch coverage is 0% with 27 lines in your changes missing coverage. Please review.
✅ Project coverage is 15.09%. Comparing base (917b4c3) to head (4238f32).

Files with missing lines | Patch % | Lines
dlrover/python/elastic_agent/torch/ckpt_saver.py | 0.00% | 9 Missing ⚠️
...over/trainer/torch/flash_checkpoint/fsdp_engine.py | 0.00% | 6 Missing ⚠️
...trainer/torch/flash_checkpoint/deepspeed_engine.py | 0.00% | 5 Missing ⚠️
...trainer/torch/flash_checkpoint/full_ckpt_engine.py | 0.00% | 5 Missing ⚠️
dlrover/trainer/torch/flash_checkpoint/engine.py | 0.00% | 2 Missing ⚠️

❌ Your patch check has failed because the patch coverage (0.00%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.
❌ Your project check has failed because the head coverage (15.09%) is below the target coverage (81.00%). You can increase the head coverage or adjust the target coverage.
❌ Your project check has failed because you have indirect coverage changes. Learn more about Unexpected Coverage Changes and reasons for indirect coverage changes.

❗ There is a different number of reports uploaded between BASE (917b4c3) and HEAD (4238f32).

HEAD has 1 upload less than BASE
Flag | BASE (917b4c3) | HEAD (4238f32)
     | 2 | 1
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #1649       +/-   ##
===========================================
- Coverage   80.13%   15.09%   -65.04%     
===========================================
  Files         228      228               
  Lines       22213    22228       +15     
===========================================
- Hits        17801     3356    -14445     
- Misses       4412    18872    +14460     


@BalaBalaYi added the "enhancement" (New feature or request) label on Oct 21, 2025.
self._event_queue = SharedQueue(name=qname, create=True)
self._notify_queues = []
for i in range(self.local_shard_num):
self._notify_queues.append(
Collaborator:

use a function to generate queue name


@timer
def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
def _save_to_memory(self, step, state_dict, paths: Dict[str, str]):
Collaborator:

This should not be modified, as it is inherently a public abstract method.
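The reviewer's point can be illustrated with a standard pattern: keep the public abstract method's name stable and put the new synchronization logic inside the subclass override, instead of renaming it to a private `_save_to_memory`. The class names below are simplified stand-ins, not the actual dlrover class hierarchy:

```python
from abc import ABC, abstractmethod
from typing import Dict


class CheckpointEngine(ABC):
    @abstractmethod
    def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
        """Public API: callers and subclasses depend on this name,
        so it should be overridden, not renamed."""


class FullCheckpointEngine(CheckpointEngine):
    def save_to_memory(self, step, state_dict, paths: Dict[str, str]):
        # New behavior (e.g. waiting on the per-rank notify queue)
        # belongs inside the override; the public signature is unchanged.
        self.last_saved = (step, state_dict, paths)
```

Renaming a public abstract method is a breaking change for every caller and subclass, whereas overriding it leaves the contract intact.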

["model_states", "optim_states"] of the state dict and
the value is the path of storage to save.
"""
if self._checkpoint_event_step > 0:
Collaborator:

This should be the base implementation in the parent class.



@dataclass
class CheckpointNotifyEvent:
Collaborator:

Is this unused?

["model_states", "optim_states"] of the state dict and
the value is the path of storage to save.
"""
if self._checkpoint_event_step > 0:
Collaborator:

Also, I don't see why this uses '_checkpoint_event_step' instead of 'last_step'?

