diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 58b5a61..ae568c8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -21,7 +21,7 @@ repos:
     rev: v0.12.2
     hooks:
       - id: ruff
-        args: [--fix, --exit-non-zero-on-fix]
+        args: [--fix, --exit-non-zero-on-fix, --ignore, S603,]
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
     rev: v2.4.1
diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py
index 1493a69..40fd458 100644
--- a/checkpoint_engine/ps.py
+++ b/checkpoint_engine/ps.py
@@ -886,10 +886,6 @@ def update(
                 return
             self.init_process_group_for_ranks(ranks)
             self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
-            if self._auto_pg:
-                dist.destroy_process_group()
-
-            torch.cuda.empty_cache()
             logger.info(
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
@@ -901,6 +897,11 @@
                 f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
             )
             raise
+        finally:
+            if self._auto_pg:
+                dist.destroy_process_group()
+
+            torch.cuda.empty_cache()
 
     def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
         def zmq_handle(device_uuid: str) -> str:
@@ -1191,7 +1192,16 @@ def _update_per_bucket(
             else:
                 buffer_b.data.copy_(h2d_buffer[: bucket.size])
             dist.broadcast(buffer_b, src=owner_rank)
-            socket.recv()
+            resp_list: list[bytes] = [b""] * dist.get_world_size()
+            resp = socket.recv()
+            dist.all_gather_object(resp_list, resp)
+            torch.cuda.synchronize()
+            if any(r != b"" for r in resp_list):
+                # quit early if any rank failed
+                failed_ranks = [i for i, r in enumerate(resp_list) if r != b""]
+                raise RuntimeError(
+                    f"failed to update weights due to remote error(s) on rank(s): {failed_ranks}"
+                )
             dist.barrier()
             socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
             gidx += 1
diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py
index e332d73..5d3c7f1 100644
--- a/checkpoint_engine/worker.py
+++ b/checkpoint_engine/worker.py
@@ -70,9 +70,13 @@ def update_weights_from_ipc(
             socket.send(b"")
             continue
         assert isinstance(payload, list)
-        run(_extract_weights(payload, buffer))
-        torch.cuda.synchronize()
-        socket.send(b"")
+        try:
+            run(_extract_weights(payload, buffer))
+            torch.cuda.synchronize()
+            socket.send(b"")
+        except Exception as e:
+            socket.send_pyobj(e)
+            raise
     socket.close()
     del buffer
diff --git a/tests/test_error_quit.py b/tests/test_error_quit.py
new file mode 100644
index 0000000..452d5c0
--- /dev/null
+++ b/tests/test_error_quit.py
@@ -0,0 +1,134 @@
+import os
+import random
+import subprocess
+import time
+
+import pytest
+import torch
+import zmq
+from torch.multiprocessing import Queue, get_context
+
+from checkpoint_engine.ps import ParameterServer, _get_physical_gpu_id
+from checkpoint_engine.worker import update_weights_from_ipc
+
+
+def gen_test_tensors(rank: int) -> list[tuple[str, torch.Tensor]]:
+    tensors = []
+    for layer in range(random.randint(10, 50)):
+        for num in range(random.randint(50, 100)):
+            r = random.randint(0, 16)
+            if r < 4:
+                dtype = torch.bfloat16
+            elif r < 10:
+                dtype = torch.float16
+            elif r < 14:
+                dtype = torch.float8_e4m3fn
+            else:
+                dtype = torch.float
+            tensors.append(
+                (
+                    f"rank{rank}.layer{layer}.num{num}",
+                    torch.randn([random.randint(100, 500), random.randint(500, 1000)]).to(dtype),
+                )
+            )
+    return tensors
+
+
+def receiver_proc_with_error(
+    rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue
+):
+    torch.cuda.set_device(rank)
+    named_tensors = {name: tensor.cuda() for name, tensor in named_tensors.items()}
+    _zmq_ctx = zmq.Context()
+
+    def trigger_error(socket_paths: list[tuple[str, str]]):
+        socket_paths = dict(socket_paths)
+        update_weights_from_ipc(
+            _zmq_ctx,
+            socket_paths[device_uuid],
+            device_id=rank,
+            run=error_run,
+            post_hook=lambda: torch.cuda.synchronize(),
+        )
+
+    def error_run(weights: list[tuple[str, torch.Tensor]]):
+        weights = weights  # Do some fake processing
+        time.sleep(random.uniform(0.1, 0.5))
+        if random.random() < 0.6:
+            raise RuntimeError("Intentional Error for testing.")
+
+    while True:
+        socket_paths: list[tuple[str, str]] = queue.get()
+        if socket_paths is None:
+            break
+        try:
+            trigger_error(socket_paths)
+        except RuntimeError:
+            print(f"[rank{rank}] successfully triggered error.")
+            raise
+
+
+def run():
+    rank = int(os.getenv("RANK"))
+    ctx = get_context("spawn")
+    queue = ctx.Queue()
+    _device_uuid = _get_physical_gpu_id(rank)
+    ps = ParameterServer(auto_pg=True)
+    named_tensors = dict(gen_test_tensors(rank))
+    checkpoint_name = "test"
+    proc = ctx.Process(
+        target=receiver_proc_with_error, args=(rank, _device_uuid, named_tensors, queue)
+    )
+    proc.daemon = True
+    proc.start()
+    try:
+        ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
+        ps.gather_metas(checkpoint_name)
+        ranks = []
+        ps.update(checkpoint_name, queue.put, ranks=ranks)
+        # sleep 3s to wait for the process group to be destroyed
+        time.sleep(3)
+    except RuntimeError as e:
+        print(f"[rank{rank}] Caught expected RuntimeError from worker process: {e}")
+        assert "failed to update weights due to remote error(s)" in str(e)
+    except Exception as e:
+        print(f"[rank{rank}] Caught unexpected exception: {e}")
+        raise
+    finally:
+        ps.unregister_checkpoint(checkpoint_name)
+        queue.put(None)
+
+
+@pytest.mark.gpu
+def test_update():
+    world_size = torch.cuda.device_count()
+    assert world_size >= 2, "This test requires at least 2 GPUs."
+
+    master_addr = "localhost"
+    master_port = 25400
+
+    cmd = [
+        "torchrun",
+        "--nproc_per_node",
+        str(world_size),
+        "--master_addr",
+        master_addr,
+        "--master_port",
+        str(master_port),
+        "tests/test_error_quit.py",
+    ]
+
+    result = subprocess.run(
+        cmd,
+        capture_output=False,
+        text=True,
+        cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+        shell=False,
+        check=False,
+    )
+
+    assert result.returncode == 0
+
+
+if __name__ == "__main__":
+    run()
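Note on the handshake the diff introduces (a minimal sketch, not part of checkpoint-engine): each worker now replies over its ZMQ socket with b"" on success or a pickled exception (send_pyobj) on failure, and the parameter server all-gathers one reply per rank before dist.barrier(), so every rank raises together instead of waiting on a rank that already failed. The helper names below (worker_step, check_responses) are hypothetical, and the sketch needs no torch or zmq; only the success marker and the error message match the diff.

import pickle


def worker_step(rank: int, should_fail: bool) -> bytes:
    # Stand-in for worker.update_weights_from_ipc: reply b"" on success,
    # or a pickled exception (the send_pyobj path) on failure.
    try:
        if should_fail:
            raise RuntimeError(f"intentional failure on rank {rank}")
        return b""
    except Exception as e:
        return pickle.dumps(e)


def check_responses(resp_list: list[bytes]) -> None:
    # Stand-in for the all_gather_object + early-quit check in _update_per_bucket.
    if any(r != b"" for r in resp_list):
        failed_ranks = [i for i, r in enumerate(resp_list) if r != b""]
        raise RuntimeError(
            f"failed to update weights due to remote error(s) on rank(s): {failed_ranks}"
        )


if __name__ == "__main__":
    # Rank 1 fails; every rank sees the same aggregated view and raises together
    # instead of entering dist.barrier() while another rank has already died.
    responses = [worker_step(rank, should_fail=(rank == 1)) for rank in range(4)]
    try:
        check_responses(responses)
    except RuntimeError as e:
        print(e)  # failed to update weights due to remote error(s) on rank(s): [1]

In the real path, resp_list is filled by dist.all_gather_object, so the RuntimeError carries the same failed-rank list on every rank, and the new finally block in ParameterServer.update still destroys the auto-created process group and empties the CUDA cache on that error path.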