-
Notifications
You must be signed in to change notification settings - Fork 58
feat: quit checkpoint engine when error occurs #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
specture724
wants to merge
5
commits into
MoonshotAI:main
Choose a base branch
from
specture724:fix/quit-when-error
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
5ad2d95
feat: quit checkpoint worker process when error occurs
specture724 7839fc4
misc: resolve pr issues
specture724 08018ab
misc: pytest version test
specture724 4d2156d
misc: pytest version test
specture724 6d0f1ac
Merge branch 'fix/quit-when-error' of github.com:specture724/checkpoint-engine
specture724 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| import os | ||
| import random | ||
| import subprocess | ||
| import time | ||
|
|
||
| import pytest | ||
| import torch | ||
| import zmq | ||
| from torch.multiprocessing import Queue, get_context | ||
|
|
||
| from checkpoint_engine.ps import ParameterServer, _get_physical_gpu_id | ||
| from checkpoint_engine.worker import update_weights_from_ipc | ||
|
|
||
|
|
||
| def gen_test_tensors(rank: int) -> list[tuple[str, torch.Tensor]]: | ||
| tensors = [] | ||
| for layer in range(random.randint(10, 50)): | ||
| for num in range(random.randint(50, 100)): | ||
| r = random.randint(0, 16) | ||
| if r < 4: | ||
| dtype = torch.bfloat16 | ||
| elif r < 10: | ||
| dtype = torch.float16 | ||
| elif r < 14: | ||
| dtype = torch.float8_e4m3fn | ||
| else: | ||
| dtype = torch.float | ||
| tensors.append( | ||
| ( | ||
| f"rank{rank}.layer{layer}.num{num}", | ||
| torch.randn([random.randint(100, 500), random.randint(500, 1000)]).to(dtype), | ||
| ) | ||
| ) | ||
| return tensors | ||
|
|
||
|
|
||
def receiver_proc_with_error(
    rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue
):
    """Receiver subprocess that consumes weight updates over IPC and fails on purpose.

    Loops on ``queue`` for ``(device_uuid, socket_path)`` pairs, feeds each
    update through ``update_weights_from_ipc``, and raises a RuntimeError
    ~60% of the time per batch so the parent can verify that remote errors
    propagate back to the parameter server. A ``None`` from the queue ends
    the loop cleanly.
    """
    torch.cuda.set_device(rank)
    # Move the reference tensors onto this rank's GPU up front.
    named_tensors = {key: val.cuda() for key, val in named_tensors.items()}
    zmq_context = zmq.Context()

    def flaky_run(weights: list[tuple[str, torch.Tensor]]):
        # Pretend to process the incoming weights, then fail most of the time.
        weights = weights  # Do some fake processing
        time.sleep(random.uniform(0.1, 0.5))
        if random.random() < 0.6:
            raise RuntimeError("Intentional Error for testing.")

    def consume_update(path_pairs: list[tuple[str, str]]):
        device_to_path = dict(path_pairs)
        update_weights_from_ipc(
            zmq_context,
            device_to_path[device_uuid],
            device_id=rank,
            run=flaky_run,
            post_hook=lambda: torch.cuda.synchronize(),
        )

    while True:
        path_pairs: list[tuple[str, str]] = queue.get()
        if path_pairs is None:
            break  # sentinel from the parent: shut down
        try:
            consume_update(path_pairs)
        except RuntimeError:
            print(f"[rank{rank}] successfully triggered error.")
            raise
|
|
||
|
|
||
def run():
    """Per-rank entry point (invoked under torchrun).

    Spawns a receiver subprocess that intentionally fails while applying a
    weight update, then drives a full register/gather/update cycle through
    the ParameterServer and checks that the remote failure surfaces as a
    RuntimeError here instead of hanging the rank.
    """
    # os.environ raises a clear KeyError if RANK is missing (torchrun sets it),
    # whereas int(os.getenv("RANK")) would die with an opaque TypeError.
    rank = int(os.environ["RANK"])
    ctx = get_context("spawn")
    queue = ctx.Queue()
    _device_uuid = _get_physical_gpu_id(rank)
    ps = ParameterServer(auto_pg=True)
    named_tensors = dict(gen_test_tensors(rank))
    checkpoint_name = "test"
    proc = ctx.Process(
        target=receiver_proc_with_error, args=(rank, _device_uuid, named_tensors, queue)
    )
    proc.daemon = True
    proc.start()
    try:
        ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
        ps.gather_metas(checkpoint_name)
        ranks = []
        ps.update(checkpoint_name, queue.put, ranks=ranks)
        # sleep 3s to wait process group is destroyed
        time.sleep(3)
    except RuntimeError as e:
        # Expected path: the receiver raised and the server reported it back.
        print(f"[rank{rank}] Caught expected RuntimeError from worker process: {e}")
        assert "failed to update weights due to remote error(s)" in str(e)
    except Exception as e:
        print(f"[rank{rank}] Caught unexpected exception: {e}")
        raise
    finally:
        ps.unregister_checkpoint(checkpoint_name)
        # Tell the receiver loop to exit, then give it a bounded window to
        # terminate so this rank does not leak a live daemon process.
        queue.put(None)
        proc.join(timeout=10)
|
|
||
|
|
||
@pytest.mark.gpu
def test_update():
    """Launch this file under torchrun (one process per local GPU) and assert
    that the distributed error-propagation scenario exits with status 0."""
    world_size = torch.cuda.device_count()
    assert world_size >= 2, "This test requires at least 2 GPUs."

    import socket

    master_addr = "localhost"
    # Ask the OS for a free port instead of hard-coding 25400, which can
    # collide with a concurrent test run or a leftover rendezvous process.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((master_addr, 0))
        master_port = s.getsockname()[1]

    cmd = [
        "torchrun",
        "--nproc_per_node",
        str(world_size),
        "--master_addr",
        master_addr,
        "--master_port",
        str(master_port),
        "tests/test_error_quit.py",
    ]

    # Output is inherited (not captured) so rank logs show up directly in the
    # pytest output; cwd is the repo root so the script path resolves.
    result = subprocess.run(
        cmd,
        capture_output=False,
        text=True,
        cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        shell=False,
        check=False,
    )

    assert result.returncode == 0
|
|
||
|
|
||
if __name__ == "__main__":
    # Entry point when this file is executed directly by torchrun (spawned
    # from test_update): each rank runs the distributed scenario once.
    run()
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.