2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -21,7 +21,7 @@ repos:
rev: v0.12.2
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
args: [--fix, --exit-non-zero-on-fix, --ignore, S603,]
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.4.1
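The ruff hook change suppresses S603 (the flake8-bandit check "subprocess call: check for execution of untrusted input"), presumably because the new test below launches `torchrun` through `subprocess.run`. A narrower alternative, sketched here but not what this PR does, would be a per-line suppression at the call site:

```python
import subprocess

# The argument list is fully hardcoded, so the "untrusted input" warning does
# not apply to this particular call.
cmd = ["torchrun", "--nproc_per_node", "2", "tests/test_error_quit.py"]
result = subprocess.run(cmd, check=False)  # noqa: S603
```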
20 changes: 15 additions & 5 deletions checkpoint_engine/ps.py
@@ -886,10 +886,6 @@ def update(
return
self.init_process_group_for_ranks(ranks)
self._update_per_bucket_p2p(checkpoint_name, req_func, ranks)
if self._auto_pg:
dist.destroy_process_group()

torch.cuda.empty_cache()

logger.info(
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
@@ -901,6 +897,11 @@ def update(
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} error {e}"
)
raise
finally:
if self._auto_pg:
dist.destroy_process_group()

torch.cuda.empty_cache()

def _bind_zmq_socket(self) -> tuple[zmq.Socket, list[tuple[str, str]]]:
def zmq_handle(device_uuid: str) -> str:
@@ -1191,7 +1192,16 @@ def _update_per_bucket(
else:
buffer_b.data.copy_(h2d_buffer[: bucket.size])
dist.broadcast(buffer_b, src=owner_rank)
socket.recv()
resp_list: list[bytes] = [b""] * dist.get_world_size()
resp = socket.recv()
dist.all_gather_object(resp_list, resp)
torch.cuda.synchronize()
if any(r != b"" for r in resp_list):
# quit early if any rank failed
failed_ranks = [i for i, r in enumerate(resp_list) if r != b""]
raise RuntimeError(
f"failed to update weights due to remote error(s) on rank(s): {failed_ranks}"
)
dist.barrier()
socket.send_pyobj(_to_named_tensor(bucket.items, gidx % 2 * bucket_size))
gidx += 1
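Taken together, the two `ps.py` hunks make failure handling symmetric across ranks: the auto-created process group is now destroyed (and the CUDA cache emptied) in a `finally:` block even when the update raises, and each bucket broadcast is followed by an all-gather of the worker replies so that every rank either proceeds or aborts together. A minimal sketch of that consensus check, pulled out into a hypothetical standalone helper (`gather_worker_status` is not a name from this PR, and an already-initialized process group is assumed):

```python
import torch.distributed as dist


def gather_worker_status(local_resp: bytes) -> None:
    """Share this rank's worker reply with every rank and fail together.

    An empty reply means the local worker applied its bucket successfully;
    anything else is a pickled exception forwarded by the worker.
    """
    resp_list: list[bytes] = [b""] * dist.get_world_size()
    dist.all_gather_object(resp_list, local_resp)
    failed_ranks = [i for i, r in enumerate(resp_list) if r != b""]
    if failed_ranks:
        # Every rank computes the same failed_ranks list, so every rank raises
        # and the update aborts consistently instead of hanging in a later
        # collective such as dist.barrier().
        raise RuntimeError(
            f"failed to update weights due to remote error(s) on rank(s): {failed_ranks}"
        )
```

Because `all_gather_object` is collective, a rank whose own worker succeeded still learns about remote failures and raises, rather than blocking in the next `dist.barrier()` or broadcast while the failing rank exits.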
10 changes: 7 additions & 3 deletions checkpoint_engine/worker.py
@@ -70,9 +70,13 @@ def update_weights_from_ipc(
socket.send(b"")
continue
assert isinstance(payload, list)
run(_extract_weights(payload, buffer))
torch.cuda.synchronize()
socket.send(b"")
try:
run(_extract_weights(payload, buffer))
torch.cuda.synchronize()
socket.send(b"")
except Exception as e:
socket.send_pyobj(e)
raise

socket.close()
del buffer
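The `worker.py` change is the sending side of the same handshake: on success the worker keeps acknowledging with an empty frame, while on failure it forwards the exception object over the ZMQ socket before re-raising locally. A minimal sketch of that reply convention, assuming a connected `zmq.Socket` in the same role as the one `update_weights_from_ipc` uses (`ack_or_forward_error` and `apply_bucket` are hypothetical names standing in for the surrounding loop and the caller-supplied `run(...)` callback):

```python
import torch
import zmq


def ack_or_forward_error(socket: zmq.Socket, apply_bucket, payload) -> None:
    try:
        apply_bucket(payload)      # apply the bucket of weights
        torch.cuda.synchronize()   # make sure the copies finished before acking
        socket.send(b"")           # empty frame: success
    except Exception as exc:
        socket.send_pyobj(exc)     # pickled exception: non-empty frame means failure
        raise                      # still surface the error in the worker process
```

Sending the exception rather than nothing is what keeps the parameter-server side from blocking forever in `socket.recv()`: every bucket still gets exactly one reply frame, it just changes meaning on failure.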
134 changes: 134 additions & 0 deletions tests/test_error_quit.py
@@ -0,0 +1,134 @@
import os
import random
import subprocess
import time

import pytest
import torch
import zmq
from torch.multiprocessing import Queue, get_context

from checkpoint_engine.ps import ParameterServer, _get_physical_gpu_id
from checkpoint_engine.worker import update_weights_from_ipc


def gen_test_tensors(rank: int) -> list[tuple[str, torch.Tensor]]:
tensors = []
for layer in range(random.randint(10, 50)):
for num in range(random.randint(50, 100)):
r = random.randint(0, 16)
if r < 4:
dtype = torch.bfloat16
elif r < 10:
dtype = torch.float16
elif r < 14:
dtype = torch.float8_e4m3fn
else:
dtype = torch.float
tensors.append(
(
f"rank{rank}.layer{layer}.num{num}",
torch.randn([random.randint(100, 500), random.randint(500, 1000)]).to(dtype),
)
)
return tensors


def receiver_proc_with_error(
rank: int, device_uuid: str, named_tensors: dict[str, torch.Tensor], queue: Queue
):
torch.cuda.set_device(rank)
named_tensors = {name: tensor.cuda() for name, tensor in named_tensors.items()}
_zmq_ctx = zmq.Context()

def trigger_error(socket_paths: list[tuple[str, str]]):
socket_paths = dict(socket_paths)
update_weights_from_ipc(
_zmq_ctx,
socket_paths[device_uuid],
device_id=rank,
run=error_run,
post_hook=lambda: torch.cuda.synchronize(),
)

def error_run(weights: list[tuple[str, torch.Tensor]]):
weights = weights # Do some fake processing
time.sleep(random.uniform(0.1, 0.5))
if random.random() < 0.6:
raise RuntimeError("Intentional Error for testing.")

while True:
socket_paths: list[tuple[str, str]] = queue.get()
if socket_paths is None:
break
try:
trigger_error(socket_paths)
except RuntimeError:
print(f"[rank{rank}] successfully triggered error.")
raise


def run():
rank = int(os.getenv("RANK"))
ctx = get_context("spawn")
queue = ctx.Queue()
_device_uuid = _get_physical_gpu_id(rank)
ps = ParameterServer(auto_pg=True)
named_tensors = dict(gen_test_tensors(rank))
checkpoint_name = "test"
proc = ctx.Process(
target=receiver_proc_with_error, args=(rank, _device_uuid, named_tensors, queue)
)
proc.daemon = True
proc.start()
try:
ps.register_checkpoint(checkpoint_name, named_tensors=named_tensors)
ps.gather_metas(checkpoint_name)
ranks = []
ps.update(checkpoint_name, queue.put, ranks=ranks)
# sleep 3s to wait for the process group to be destroyed
time.sleep(3)
except RuntimeError as e:
print(f"[rank{rank}] Caught expected RuntimeError from worker process: {e}")
assert "failed to update weights due to remote error(s)" in str(e)
except Exception as e:
print(f"[rank{rank}] Caught unexpected exception: {e}")
raise
finally:
ps.unregister_checkpoint(checkpoint_name)
queue.put(None)


@pytest.mark.gpu
def test_update():
world_size = torch.cuda.device_count()
assert world_size >= 2, "This test requires at least 2 GPUs."

master_addr = "localhost"
master_port = 25400

cmd = [
"torchrun",
"--nproc_per_node",
str(world_size),
"--master_addr",
master_addr,
"--master_port",
str(master_port),
"tests/test_error_quit.py",
]

result = subprocess.run(
cmd,
capture_output=False,
text=True,
cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
shell=False,
check=False,
)

assert result.returncode == 0


if __name__ == "__main__":
run()
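The failure in `error_run` is probabilistic (each bucket has a 0.6 chance of raising on each rank), so a given run may also exercise the happy path in which no rank raises and the update simply completes. A hedged variation, not part of the PR, that pins the failure to one known rank so the error path is hit on every run:

```python
import random
import time

import torch


def make_deterministic_error_run(rank: int, failing_rank: int = 0):
    """Build a run(...) callback that fails on exactly one rank."""

    def _run(weights: list[tuple[str, torch.Tensor]]) -> None:
        time.sleep(random.uniform(0.1, 0.5))  # keep the fake processing delay
        if rank == failing_rank:
            raise RuntimeError("Intentional Error for testing.")

    return _run
```

Wiring it in would only require passing `make_deterministic_error_run(rank)` as the `run=` argument in `trigger_error` above.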