84 changes: 84 additions & 0 deletions checkpoint_engine/device_utils.py
@@ -0,0 +1,84 @@
import os
import re
import socket
import subprocess
from functools import lru_cache

import torch
from loguru import logger


@lru_cache(maxsize=1)
def get_ip() -> str:
try:
# try to get ip from network interface
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
except Exception as e: # noqa: BLE001
# fallback to get ip from hostname
logger.warning(
f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
)
return socket.gethostbyname(socket.gethostname())


def npu_generate_uuid() -> str:
str_pid = str(os.getpid())
npu_num = 8
try:
for npu_id in range(npu_num):
cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)]
result = subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603
str_result = str(result.stdout)
if str_pid in str_result:
# In A3 server, one NPU has two chips.
match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result)
chip_count = int(match_chip_count.group(1))
search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :]
match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid)
chip_id = int(match_chip_id.group(1))
return f"{get_ip()}-{npu_id * chip_count + chip_id}"
raise ValueError("The current process is not running on the npu device")
except subprocess.CalledProcessError as e:
raise ValueError("The current process is not running on the npu device") from e


class DeviceManager:
def __init__(self):
self.device_type = self._detect_device_type()
self._setup_device_module()

def _is_torch_npu_available(self) -> bool:
try:
if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)):
return torch.npu.is_available()
else:
return False
except ImportError:
return False

def _detect_device_type(self) -> str:
if self._is_torch_npu_available():
return "npu"
elif torch.cuda.is_available():
return "cuda"
else:
raise TypeError("The current device type is not supported")

def _setup_device_module(self):
if self.device_type == "npu":
import torch_npu

self.device_module = torch_npu.npu
elif self.device_type == "cuda":
self.device_module = torch.cuda
else:
raise TypeError("The current device type is not supported")

@property
def backend(self) -> str:
if self.device_type == "npu":
return "hccl"
elif self.device_type == "cuda":
return "nccl"
else:
raise TypeError("The current device type is not supported")
73 changes: 33 additions & 40 deletions checkpoint_engine/ps.py
@@ -4,13 +4,11 @@
import os
import pickle
import random
import socket
import threading
import time
from collections import defaultdict
from collections.abc import Callable
from datetime import timedelta
from functools import lru_cache
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple

import httpx
@@ -23,6 +21,8 @@
from safetensors.torch import safe_open
from torch.multiprocessing.reductions import reduce_tensor

from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid


if TYPE_CHECKING:
from typing import TypeVar
@@ -254,28 +254,16 @@ def _concat_tp_weights(
return torch.cat([w for w in tp_weights], dim=tp_concat_dim)


def _get_physical_gpu_id(device_index: int | None = None) -> str:
def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str:
try:
return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}"
if device_manager.device_type == "npu":
return f"NPU-{npu_generate_uuid()}"
else:
return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}"
except AssertionError as e:
raise ValueError(f"fail to get physical gpu id {device_index}") from e


@lru_cache(maxsize=1)
def _get_ip() -> str:
try:
# try to get ip from network interface
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
s.connect(("8.8.8.8", 80))
return s.getsockname()[0]
except Exception as e: # noqa: BLE001
# fallback to get ip from hostname
logger.warning(
f"fail to get ip from network interface, fallback to get ip from hostname: {e}"
)
return socket.gethostbyname(socket.gethostname())


def _ibv_get_device_list() -> list[str]:
lib = ctypes.CDLL("libibverbs.so.1")
lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
@@ -677,14 +665,14 @@ def _get_bcast_rank_map(world_size: int, ranks: list[int] | None) -> dict[int, int]:


class P2PStore:
def __init__(self):
def __init__(self, device_manager: DeviceManager):
from mooncake.engine import TransferEngine

self.rank = int(os.getenv("RANK"))
gpu_count = torch.cuda.device_count()
gpu_count = device_manager.device_module.device_count()
local_rank = self.rank % gpu_count
self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
self.ip = _get_ip()
self.ip = get_ip()

# we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
retry_count = 8
@@ -761,7 +749,8 @@ def __init__(
"""
self._rank = rank or int(os.environ.get("RANK", None))
self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None))
self._gpu_count = gpu_count or torch.cuda.device_count()
self.device_manager = DeviceManager()
self._gpu_count = gpu_count or self.device_manager.device_module.device_count()
self._local_rank = self._rank % self._gpu_count
self._auto_pg = auto_pg
self._all_hosts = []
@@ -775,7 +764,7 @@ def __init__(
assert (
self._gpu_count is not None
and self._gpu_count > 0
and self._gpu_count <= torch.cuda.device_count()
and self._gpu_count <= self.device_manager.device_module.device_count()
), self._gpu_count
assert (
self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1
@@ -788,14 +777,14 @@ def __init__(
# dict key is owner_rank, value is a bucket metas list in owner_rank
self._current_global_parameter_metas: dict[int, MemoryBufferMetaList] = {}
try:
self._p2p_store = P2PStore()
self._p2p_store = P2PStore(self.device_manager)
except ImportError as e:
logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}")
self._p2p_store = None

device_index = self._local_rank
torch.cuda.set_device(device_index)
self._device_uuid = _get_physical_gpu_id(device_index)
self.device_manager.device_module.set_device(device_index)
self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index)
self._rdma_device = None if self._p2p_store is None else self._p2p_store.device

def _logger_rank0(self, msg: str):
@@ -885,7 +874,7 @@ def gather_metas(self, checkpoint_name: str):
for x in self._memory_pool.get(checkpoint_name, [])
],
p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr,
host_ip=_get_ip(),
host_ip=get_ip(),
device_uuid=self._device_uuid,
rdma_device=self._rdma_device or "",
)
@@ -948,7 +937,7 @@ def init_process_group(
is_master=self._rank == 0,
)
dist.init_process_group(
backend="nccl",
backend=self.device_manager.backend,
world_size=self._world_size,
rank=self._rank,
timeout=timeout,
@@ -994,12 +983,12 @@ def update(
if self._auto_pg:
dist.destroy_process_group()

torch.cuda.empty_cache()
self.device_manager.device_module.empty_cache()

logger.info(
f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. "
f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, "
f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB."
f"Current CUDA allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, "
f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB."
)
except Exception as e:
logger.exception(
@@ -1023,13 +1012,15 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
tensor = torch.tensor(
[
# proportion of current cuda free memory bytes
int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction),
int(
float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction
),
# we use negative value to reuse allreduce min operation
# for getting the max value of zmq_addr_counter in all ranks
-self._zmq_addr_counter,
],
dtype=torch.int64,
device="cuda",
device=self.device_manager.device_type,
)
dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
tensor = tensor.cpu()
@@ -1092,7 +1083,7 @@ def _copy_to_buffer(
assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}"
if owner_rank is not None:
self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens)
torch.cuda.synchronize()
self.device_manager.device_module.synchronize()

def init_process_group_for_ranks(
self,
@@ -1199,7 +1190,7 @@ def _update_per_bucket(
h2d_buffer: torch.Tensor | None = (
None
if disable_h2d_buffer
else torch.empty(bucket_size, dtype=torch.uint8, device="cuda")
else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
)
# p2p store need to register h2d_buffer to let other ranks read
if ranks:
@@ -1212,7 +1203,9 @@ def _update_per_bucket(
continue
receiver_rank_buckets.append((owner_rank, bucket))

buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda")
buffer = torch.empty(
bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type
)
handle = reduce_tensor(buffer)

buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
@@ -1245,8 +1238,8 @@ def _update_per_bucket(
continue
bucket = _buckets[i]
alloc, reserved = (
torch.cuda.memory_allocated() / 1024 / 1024,
torch.cuda.memory_reserved() / 1024 / 1024,
self.device_manager.device_module.memory_allocated() / 1024 / 1024,
self.device_manager.device_module.memory_reserved() / 1024 / 1024,
)
self._logger_rank0(
f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} receiver_rank {receiver_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. "
@@ -1276,7 +1269,7 @@ def _update_per_bucket(
if ranks and h2d_buffer is not None:
self._p2p_store.unregister_named_tensors([h2d_buffer_name])

torch.cuda.empty_cache()
self.device_manager.device_module.empty_cache()


def _init_api(ps: ParameterServer) -> Any:
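One hunk worth a second look is _detect_bucket_size above: it packs two values into a single all_reduce with ReduceOp.MIN, negating _zmq_addr_counter so that the minimum of the negatives is the maximum counter across ranks. The standalone illustration below restates that trick; the helper name and values are hypothetical, and it assumes a process group is already initialized as in the earlier sketch.

import torch
import torch.distributed as dist

def agree_on_limits(free_bytes: int, counter: int, device: str) -> tuple[int, int]:
    # One MIN all_reduce yields min(free_bytes) and, via negation, max(counter).
    t = torch.tensor([free_bytes, -counter], dtype=torch.int64, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.MIN)
    t = t.cpu()
    return int(t[0]), int(-t[1])  # (smallest free memory, largest counter)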
17 changes: 13 additions & 4 deletions checkpoint_engine/worker.py
@@ -5,6 +5,8 @@
import torch
import zmq

from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid


def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
func, args = handle
@@ -53,13 +55,14 @@ def update_weights_from_ipc(
socket = zmq_ctx.socket(zmq.REP)
socket.connect(zmq_handle)
buffer: torch.Tensor | None = None
device_manager = DeviceManager()
while True:
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
if payload is None:
# means the update is done
if post_hook is not None:
post_hook()
torch.cuda.synchronize()
device_manager.device_module.synchronize()
socket.send(b"")
break
if isinstance(payload, tuple):
@@ -71,13 +74,13 @@
continue
assert isinstance(payload, list)
run(_extract_weights(payload, buffer))
torch.cuda.synchronize()
device_manager.device_module.synchronize()
socket.send(b"")

socket.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
device_manager.device_module.empty_cache()


class VllmColocateWorkerExtension:
@@ -94,10 +97,16 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
from vllm.model_executor.model_loader.utils import process_weights_after_loading
from vllm.platforms import current_platform

# vllm-ascend does not initialize self.device, so set it here
if current_platform.device_type == "npu" and self.device is None:
self.device = torch.device(f"npu:{self.local_rank}")
assert self.device is not None
if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None:
self._zmq_ctx = zmq.Context()
device_uuid = current_platform.get_device_uuid(self.device.index)
if current_platform.device_type == "gpu":
device_uuid = current_platform.get_device_uuid(self.device.index)
elif current_platform.device_type == "npu":
device_uuid = f"NPU-{npu_generate_uuid()}"
update_weights_from_ipc(
self._zmq_ctx,
zmq_handles[device_uuid],
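The worker picks its ZMQ socket with zmq_handles[device_uuid], so the UUID it builds must match what ParameterServer._get_physical_gpu_id publishes: "GPU-<device uuid>" on CUDA and "NPU-<ip>-<chip index>" on Ascend. The sketch below restates that pairing in one place; worker_device_uuid is a hypothetical helper for illustration, not a function in this PR, and it mirrors the parameter-server side rather than vLLM's platform lookup.

from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid

def worker_device_uuid(device_index: int) -> str:
    dm = DeviceManager()
    if dm.device_type == "npu":
        # Same "NPU-..." string that _get_physical_gpu_id builds on the parameter-server side.
        return f"NPU-{npu_generate_uuid()}"
    # Mirrors the CUDA branch of _get_physical_gpu_id.
    return f"GPU-{dm.device_module.get_device_properties(device_index).uuid!s}"

# The parameter server keys zmq_handles by these UUIDs, so a worker can look up
# the socket that belongs to its own device:
# zmq_handle = zmq_handles[worker_device_uuid(local_rank)]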