From adb4075d30b21d5f55082930a960297bd2e84ffd Mon Sep 17 00:00:00 2001 From: cuixiaojin Date: Tue, 21 Oct 2025 19:28:58 +0800 Subject: [PATCH 1/8] [Hardware] broadcast support for Huawei Ascend NPU --- checkpoint_engine/ps.py | 76 ++++++++++++++++++++++++++----------- checkpoint_engine/worker.py | 18 +++++++-- 2 files changed, 67 insertions(+), 27 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 1493a69..41718dc 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple import httpx +import importlib import numpy as np import torch import torch.distributed as dist @@ -24,6 +25,31 @@ from torch.multiprocessing.reductions import reduce_tensor +class DeviceManager: + def __init__(self): + self.device_type = self._detect_device_type() + self._setup_device_module() + + def _detect_device_type(self): + if importlib.util.find_spec("torch_npu") is not None: + return "npu" + elif torch.cuda.is_available(): + return "cuda" + + def _setup_device_module(self): + if self.device_type == "npu": + import torch_npu + self.device_module = torch_npu.npu + elif self.device_type == "cuda": + self.device_module = torch.cuda + + def get_backend(self): + if self.device_type == "npu": + return "hccl" + elif self.device_type == "cuda": + return "nccl" + + if TYPE_CHECKING: from typing_extensions import TypedDict @@ -249,9 +275,12 @@ def _concat_tp_weights( return torch.cat([w for w in tp_weights], dim=tp_concat_dim) -def _get_physical_gpu_id(device_index: int | None = None) -> str: +def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str: try: - return f"GPU-{torch.cuda.get_device_properties(device_index).uuid!s}" + if importlib.util.find_spec("torch_npu") is not None: + return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{device_index}" + else: + return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}" except AssertionError as e: raise ValueError(f"fail to get physical gpu id {device_index}") from e @@ -588,11 +617,11 @@ def _get_master_port(master_port: int | None = None) -> int: class P2PStore: - def __init__(self): + def __init__(self, device_manager): from mooncake.engine import TransferEngine self.rank = int(os.getenv("RANK")) - gpu_count = torch.cuda.device_count() + gpu_count = device_manager.device_module.device_count() local_rank = self.rank % gpu_count device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices()) self.ip = _get_ip() @@ -672,7 +701,8 @@ def __init__( """ self._rank = rank or int(os.environ.get("RANK", None)) self._world_size = world_size or int(os.environ.get("WORLD_SIZE", None)) - self._gpu_count = gpu_count or torch.cuda.device_count() + self.device_manager = DeviceManager() + self._gpu_count = gpu_count or self.device_manager.device_module.device_count() self._local_rank = self._rank % self._gpu_count self._auto_pg = auto_pg self._all_hosts = [] @@ -684,7 +714,7 @@ def __init__( assert ( self._gpu_count is not None and self._gpu_count > 0 - and self._gpu_count <= torch.cuda.device_count() + and self._gpu_count <= self.device_manager.device_module.device_count() ), self._gpu_count assert ( self._mem_fraction is not None and self._mem_fraction > 0 and self._mem_fraction <= 1 @@ -697,14 +727,14 @@ def __init__( # dict key is owner_rank, value is a bucket metas list in owner_rank self._current_global_parameter_metas: dict[int, 
MemoryBufferMetaList] = {} try: - self._p2p_store = P2PStore() + self._p2p_store = P2PStore(self.device_manager) except ImportError as e: logger.warning(f"[rank{self._rank}] fail to initialize p2p store due to {e}") self._p2p_store = None device_index = self._local_rank - torch.cuda.set_device(device_index) - self._device_uuid = _get_physical_gpu_id(device_index) + self.device_manager.device_module.set_device(device_index) + self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index) def _logger_rank0(self, msg: str): if self._local_rank == 0: @@ -842,7 +872,7 @@ def init_process_group( is_master=self._rank == 0, ) dist.init_process_group( - backend="nccl", + backend=self.device_manager.get_backend(), world_size=self._world_size, rank=self._rank, timeout=timeout, @@ -889,12 +919,12 @@ def update( if self._auto_pg: dist.destroy_process_group() - torch.cuda.empty_cache() + self.device_manager.device_module.empty_cache() logger.info( f"[rank{self._rank}] update checkpoint {checkpoint_name} with ranks {ranks} done. " - f"Current CUDA allocated {torch.cuda.memory_allocated() / 1024 / 1024} MB, " - f"reserved {torch.cuda.memory_reserved() / 1024 / 1024} MB." + f"Current CUDA allocated {self.device_manager.device_module.memory_allocated() / 1024 / 1024} MB, " + f"reserved {self.device_manager.device_module.memory_reserved() / 1024 / 1024} MB." ) except Exception as e: logger.exception( @@ -918,13 +948,13 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, tensor = torch.tensor( [ # proportion of current cuda free memory bytes - int(float(torch.cuda.mem_get_info()[0]) * self._mem_fraction), + int(float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction), # we use negative value to reuse allreduce min operation # for getting the max value of zmq_addr_counter in all ranks -self._zmq_addr_counter, ], dtype=torch.int64, - device="cuda", + device=self.device_manager.device_type, ) dist.all_reduce(tensor, op=dist.ReduceOp.MIN) tensor = tensor.cpu() @@ -987,7 +1017,7 @@ def _copy_to_buffer( assert offset == bucket.size, f"offset {offset} != bucket_size {bucket.size}" if owner_rank is not None: self._p2p_store.batch_transfer_sync_read(target_addr, buf_ptrs, remote_ptrs, lens) - torch.cuda.synchronize() + self.device_manager.device_module.synchronize() def init_process_group_for_ranks( self, @@ -1057,7 +1087,7 @@ def _update_per_bucket_p2p( dist.barrier() bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True) - buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda") + buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type) ipc_buffer_name = "__ipc_buffer___" self._p2p_store.register_named_tensors({ipc_buffer_name: buffer}) logger.info( @@ -1093,7 +1123,7 @@ def _update_per_bucket_p2p( dist.barrier() socket.close() self._p2p_store.unregister_named_tensors([ipc_buffer_name]) - torch.cuda.empty_cache() + self.device_manager.device_module.empty_cache() def _get_addr_ptrs(self, owner_rank: int) -> tuple[str, list[tuple[int, int]]]: addr = self._current_global_parameter_metas[owner_rank].p2p_store_addr @@ -1138,7 +1168,7 @@ def _update_per_bucket( h2d_buffer: torch.Tensor | None = ( None if disable_h2d_buffer - else torch.empty(bucket_size, dtype=torch.uint8, device="cuda") + else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type) ) owner_rank_buckets: list[H2DBucket] = [] @@ -1147,7 +1177,7 @@ def _update_per_bucket( continue 
owner_rank_buckets.append(bucket) - buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device="cuda") + buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type) handle = reduce_tensor(buffer) buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list) @@ -1175,8 +1205,8 @@ def _update_per_bucket( continue bucket = _buckets[i] alloc, reserved = ( - torch.cuda.memory_allocated() / 1024 / 1024, - torch.cuda.memory_reserved() / 1024 / 1024, + self.device_manager.device_module.memory_allocated() / 1024 / 1024, + self.device_manager.device_module.memory_reserved() / 1024 / 1024, ) self._logger_rank0( f"[rank{self._rank}] begin to update bucket {gidx + 1}/{len(buckets)} owner_rank {owner_rank} in checkpoint {checkpoint_name}, bucket_size: {bucket.size / 1024 / 1024:.2f}MiB, length: {len(bucket.items)}. " @@ -1202,7 +1232,7 @@ def _update_per_bucket( req_thread.join() dist.barrier() socket.close() - torch.cuda.empty_cache() + self.device_manager.device_module.empty_cache() def _init_api(ps: ParameterServer) -> Any: diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index e332d73..93cb244 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -4,6 +4,7 @@ import torch import zmq +from checkpoint_engine.ps import DeviceManager def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor: @@ -53,13 +54,14 @@ def update_weights_from_ipc( socket = zmq_ctx.socket(zmq.REP) socket.connect(zmq_handle) buffer: torch.Tensor | None = None + device_mananger = DeviceManager() while True: payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj() if payload is None: # means the update is done if post_hook is not None: post_hook() - torch.cuda.synchronize() + device_mananger.device_module.synchronize() socket.send(b"") break if isinstance(payload, tuple): @@ -71,13 +73,13 @@ def update_weights_from_ipc( continue assert isinstance(payload, list) run(_extract_weights(payload, buffer)) - torch.cuda.synchronize() + device_mananger.device_module.synchronize() socket.send(b"") socket.close() del buffer gc.collect() - torch.cuda.empty_cache() + device_mananger.device_module.empty_cache() class VllmColocateWorkerExtension: @@ -94,10 +96,18 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]): from vllm.model_executor.model_loader.utils import process_weights_after_loading from vllm.platforms import current_platform + # vllm-ascend not init device + if current_platform.device_type == "npu" and self.device is None: + self.device = torch.device(f"npu:{self.local_rank}") assert self.device is not None if not hasattr(self, "_zmq_ctx") or self._zmq_ctx is None: self._zmq_ctx = zmq.Context() - device_uuid = current_platform.get_device_uuid(self.device.index) + if current_platform.device_type == "gpu": + device_uuid = current_platform.get_device_uuid(self.device.index) + elif current_platform.device_type == "npu": + device_uuid = ( + f"NPU-{current_platform.get_device_name(self.device.index)!s}-{self.device.index}" + ) update_weights_from_ipc( self._zmq_ctx, zmq_handles[device_uuid], From f083dbd1cc21f12dc7bf2bb007c6b3a9c0da08bc Mon Sep 17 00:00:00 2001 From: cuixiaojin Date: Tue, 21 Oct 2025 21:09:49 +0800 Subject: [PATCH 2/8] [modify] check npu is availble --- checkpoint_engine/ps.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 41718dc..59288f1 100644 --- 
a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -14,7 +14,6 @@ from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple import httpx -import importlib import numpy as np import torch import torch.distributed as dist @@ -25,13 +24,21 @@ from torch.multiprocessing.reductions import reduce_tensor +def is_torch_npu_available() -> bool: + try: + if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)): + return torch.npu.is_available() + return False + except ImportError: + return False + class DeviceManager: def __init__(self): self.device_type = self._detect_device_type() self._setup_device_module() def _detect_device_type(self): - if importlib.util.find_spec("torch_npu") is not None: + if is_torch_npu_available(): return "npu" elif torch.cuda.is_available(): return "cuda" @@ -277,7 +284,7 @@ def _concat_tp_weights( def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str: try: - if importlib.util.find_spec("torch_npu") is not None: + if device_manager.device_type == "npu": return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{device_index}" else: return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}" From bed9862a5c6430f781290c272bd101a683d99625 Mon Sep 17 00:00:00 2001 From: cuixiaojin Date: Wed, 22 Oct 2025 14:32:40 +0800 Subject: [PATCH 3/8] [Fix] fix the pre-commit lint error --- checkpoint_engine/ps.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 59288f1..ad853c3 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -28,7 +28,8 @@ def is_torch_npu_available() -> bool: try: if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)): return torch.npu.is_available() - return False + else: + return False except ImportError: return False @@ -37,7 +38,7 @@ def __init__(self): self.device_type = self._detect_device_type() self._setup_device_module() - def _detect_device_type(self): + def _detect_device_type(self) -> str: if is_torch_npu_available(): return "npu" elif torch.cuda.is_available(): @@ -50,7 +51,7 @@ def _setup_device_module(self): elif self.device_type == "cuda": self.device_module = torch.cuda - def get_backend(self): + def get_backend(self) -> str: if self.device_type == "npu": return "hccl" elif self.device_type == "cuda": @@ -624,7 +625,7 @@ def _get_master_port(master_port: int | None = None) -> int: class P2PStore: - def __init__(self, device_manager): + def __init__(self, device_manager: DeviceManager): from mooncake.engine import TransferEngine self.rank = int(os.getenv("RANK")) From c9d3d421b72d62684f1de2506099f4e32ce99c20 Mon Sep 17 00:00:00 2001 From: cuixiaojin Date: Wed, 22 Oct 2025 16:36:05 +0800 Subject: [PATCH 4/8] [modify] address code view feedback --- checkpoint_engine/ps.py | 37 +++++++++++++++++++++---------------- checkpoint_engine/worker.py | 4 ++-- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index ad853c3..fe8f57f 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -24,25 +24,27 @@ from torch.multiprocessing.reductions import reduce_tensor -def is_torch_npu_available() -> bool: - try: - if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)): - return torch.npu.is_available() - else: - return False - except ImportError: - return False - class DeviceManager: def __init__(self): 
self.device_type = self._detect_device_type() self._setup_device_module() + def _is_torch_npu_available(self) -> bool: + try: + if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)): + return torch.npu.is_available() + else: + return False + except ImportError: + return False + def _detect_device_type(self) -> str: - if is_torch_npu_available(): + if self._is_torch_npu_available(): return "npu" elif torch.cuda.is_available(): return "cuda" + else: + raise TypeError("The current device type is not supported") def _setup_device_module(self): if self.device_type == "npu": @@ -50,8 +52,11 @@ def _setup_device_module(self): self.device_module = torch_npu.npu elif self.device_type == "cuda": self.device_module = torch.cuda + else: + raise TypeError("The current device type is not supported") - def get_backend(self) -> str: + @property + def backend(self) -> str: if self.device_type == "npu": return "hccl" elif self.device_type == "cuda": @@ -283,10 +288,10 @@ def _concat_tp_weights( return torch.cat([w for w in tp_weights], dim=tp_concat_dim) -def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str: +def _get_physical_gpu_id(device_manager: DeviceManager, rank_id: int, device_index: int | None = None) -> str: try: if device_manager.device_type == "npu": - return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{device_index}" + return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{rank_id}" else: return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}" except AssertionError as e: @@ -625,7 +630,7 @@ def _get_master_port(master_port: int | None = None) -> int: class P2PStore: - def __init__(self, device_manager: DeviceManager): + def __init__(self, device_manager : DeviceManager): from mooncake.engine import TransferEngine self.rank = int(os.getenv("RANK")) @@ -742,7 +747,7 @@ def __init__( device_index = self._local_rank self.device_manager.device_module.set_device(device_index) - self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index) + self._device_uuid = _get_physical_gpu_id(self.device_manager, self._rank, device_index) def _logger_rank0(self, msg: str): if self._local_rank == 0: @@ -880,7 +885,7 @@ def init_process_group( is_master=self._rank == 0, ) dist.init_process_group( - backend=self.device_manager.get_backend(), + backend=self.device_manager.backend, world_size=self._world_size, rank=self._rank, timeout=timeout, diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index 93cb244..344a6f0 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -4,7 +4,7 @@ import torch import zmq -from checkpoint_engine.ps import DeviceManager +from .ps import DeviceManager def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor: @@ -106,7 +106,7 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]): device_uuid = current_platform.get_device_uuid(self.device.index) elif current_platform.device_type == "npu": device_uuid = ( - f"NPU-{current_platform.get_device_name(self.device.index)!s}-{self.device.index}" + f"NPU-{current_platform.get_device_name(self.device.index)!s}-{self.rank}" ) update_weights_from_ipc( self._zmq_ctx, From c7c2d8dfa49b33d7c1bd6a5171aa06a75e8032ac Mon Sep 17 00:00:00 2001 From: cuixiaojin Date: Mon, 27 Oct 2025 20:55:56 +0800 Subject: [PATCH 5/8] [modify] generate uuid by npu smi info --- 
checkpoint_engine/device_utils.py | 68 +++++++++++++++++++++++++++++++ checkpoint_engine/ps.py | 60 +++++++-------------------- checkpoint_engine/worker.py | 8 ++-- 3 files changed, 87 insertions(+), 49 deletions(-) create mode 100644 checkpoint_engine/device_utils.py diff --git a/checkpoint_engine/device_utils.py b/checkpoint_engine/device_utils.py new file mode 100644 index 0000000..7d40c48 --- /dev/null +++ b/checkpoint_engine/device_utils.py @@ -0,0 +1,68 @@ +import os +import re +import socket +import subprocess + +import torch + + +def npu_generate_uuid() -> str: + str_pid = str(os.getpid()) + npu_num = 8 + try: + for npu_id in range(npu_num): + cmd = ["npu-smi", "info", "-t", "proc-mem", "-i", str(npu_id)] + result = subprocess.run(cmd, check=True, capture_output=True, text=True) # noqa: S603 + str_result = str(result.stdout) + if str_pid in str_result: + # In A3 server, one NPU has two chips. + match_chip_count = re.search(r"Chip Count[^\d]*(\d+)", str_result) + chip_count = int(match_chip_count.group(1)) + search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :] + match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid) + chip_id = int(match_chip_id.group(1)) + server_ip = socket.gethostbyname(socket.gethostname()) + return f"{server_ip}-{npu_id * chip_count + chip_id}" + ValueError("The current process is not running on the npu device") + except subprocess.CalledProcessError: + ValueError("The current process is not running on the npu device") + + +class DeviceManager: + def __init__(self): + self.device_type = self._detect_device_type() + self._setup_device_module() + + def _is_torch_npu_available(self) -> bool: + try: + if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)): + return torch.npu.is_available() + else: + return False + except ImportError: + return False + + def _detect_device_type(self) -> str: + if self._is_torch_npu_available(): + return "npu" + elif torch.cuda.is_available(): + return "cuda" + else: + raise TypeError("The current device type is not supported") + + def _setup_device_module(self): + if self.device_type == "npu": + import torch_npu + + self.device_module = torch_npu.npu + elif self.device_type == "cuda": + self.device_module = torch.cuda + else: + raise TypeError("The current device type is not supported") + + @property + def backend(self) -> str: + if self.device_type == "npu": + return "hccl" + elif self.device_type == "cuda": + return "nccl" diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index fe8f57f..24987ef 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -23,44 +23,7 @@ from safetensors.torch import safe_open from torch.multiprocessing.reductions import reduce_tensor - -class DeviceManager: - def __init__(self): - self.device_type = self._detect_device_type() - self._setup_device_module() - - def _is_torch_npu_available(self) -> bool: - try: - if hasattr(torch, "npu") and callable(getattr(torch.npu, "is_available", None)): - return torch.npu.is_available() - else: - return False - except ImportError: - return False - - def _detect_device_type(self) -> str: - if self._is_torch_npu_available(): - return "npu" - elif torch.cuda.is_available(): - return "cuda" - else: - raise TypeError("The current device type is not supported") - - def _setup_device_module(self): - if self.device_type == "npu": - import torch_npu - self.device_module = torch_npu.npu - elif self.device_type == "cuda": - self.device_module = torch.cuda - else: - raise TypeError("The current 
device type is not supported") - - @property - def backend(self) -> str: - if self.device_type == "npu": - return "hccl" - elif self.device_type == "cuda": - return "nccl" +from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid if TYPE_CHECKING: @@ -288,10 +251,11 @@ def _concat_tp_weights( return torch.cat([w for w in tp_weights], dim=tp_concat_dim) -def _get_physical_gpu_id(device_manager: DeviceManager, rank_id: int, device_index: int | None = None) -> str: +def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str: try: if device_manager.device_type == "npu": - return f"NPU-{device_manager.device_module.get_device_properties(device_index).name!s}-{rank_id}" + serial_number = npu_generate_uuid() + return f"NPU-{serial_number}" else: return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}" except AssertionError as e: @@ -630,7 +594,7 @@ def _get_master_port(master_port: int | None = None) -> int: class P2PStore: - def __init__(self, device_manager : DeviceManager): + def __init__(self, device_manager: DeviceManager): from mooncake.engine import TransferEngine self.rank = int(os.getenv("RANK")) @@ -747,7 +711,7 @@ def __init__( device_index = self._local_rank self.device_manager.device_module.set_device(device_index) - self._device_uuid = _get_physical_gpu_id(self.device_manager, self._rank, device_index) + self._device_uuid = _get_physical_gpu_id(self.device_manager, device_index) def _logger_rank0(self, msg: str): if self._local_rank == 0: @@ -961,7 +925,9 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int, tensor = torch.tensor( [ # proportion of current cuda free memory bytes - int(float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction), + int( + float(self.device_manager.device_module.mem_get_info()[0]) * self._mem_fraction + ), # we use negative value to reuse allreduce min operation # for getting the max value of zmq_addr_counter in all ranks -self._zmq_addr_counter, @@ -1100,7 +1066,9 @@ def _update_per_bucket_p2p( dist.barrier() bucket_size, _ = self._detect_bucket_size(disable_h2d_buffer=True) - buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type) + buffer = torch.empty( + bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type + ) ipc_buffer_name = "__ipc_buffer___" self._p2p_store.register_named_tensors({ipc_buffer_name: buffer}) logger.info( @@ -1190,7 +1158,9 @@ def _update_per_bucket( continue owner_rank_buckets.append(bucket) - buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type) + buffer = torch.empty( + bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type + ) handle = reduce_tensor(buffer) buckets_by_owner_rank: dict[int, list[H2DBucket]] = defaultdict(list) diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index 344a6f0..0959300 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -4,7 +4,8 @@ import torch import zmq -from .ps import DeviceManager + +from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid def _rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor: @@ -105,9 +106,8 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]): if current_platform.device_type == "gpu": device_uuid = current_platform.get_device_uuid(self.device.index) elif current_platform.device_type == "npu": - 
device_uuid = ( - f"NPU-{current_platform.get_device_name(self.device.index)!s}-{self.rank}" - ) + serial_number = npu_generate_uuid() + device_uuid = f"NPU-{serial_number}" update_weights_from_ipc( self._zmq_ctx, zmq_handles[device_uuid], From 56f0ec283a24544283393568c284617277246da5 Mon Sep 17 00:00:00 2001 From: cuixiaojin Date: Tue, 28 Oct 2025 15:27:48 +0800 Subject: [PATCH 6/8] [modify] fix the variable name --- checkpoint_engine/ps.py | 3 +-- checkpoint_engine/worker.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 24987ef..45d0621 100644 --- a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -254,8 +254,7 @@ def _concat_tp_weights( def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None = None) -> str: try: if device_manager.device_type == "npu": - serial_number = npu_generate_uuid() - return f"NPU-{serial_number}" + return f"NPU-{npu_generate_uuid()}" else: return f"GPU-{device_manager.device_module.get_device_properties(device_index).uuid!s}" except AssertionError as e: diff --git a/checkpoint_engine/worker.py b/checkpoint_engine/worker.py index 0959300..5a45dff 100644 --- a/checkpoint_engine/worker.py +++ b/checkpoint_engine/worker.py @@ -106,8 +106,7 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]): if current_platform.device_type == "gpu": device_uuid = current_platform.get_device_uuid(self.device.index) elif current_platform.device_type == "npu": - serial_number = npu_generate_uuid() - device_uuid = f"NPU-{serial_number}" + device_uuid = f"NPU-{npu_generate_uuid()}" update_weights_from_ipc( self._zmq_ctx, zmq_handles[device_uuid], From b2c8251c7d268f6fb13f9887c5f213c8defdd964 Mon Sep 17 00:00:00 2001 From: cuixiaojin Date: Wed, 29 Oct 2025 17:45:11 +0800 Subject: [PATCH 7/8] [modify] get ip --- checkpoint_engine/device_utils.py | 22 +++++++++++++++++++--- checkpoint_engine/ps.py | 23 +++-------------------- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/checkpoint_engine/device_utils.py b/checkpoint_engine/device_utils.py index 7d40c48..0d90162 100644 --- a/checkpoint_engine/device_utils.py +++ b/checkpoint_engine/device_utils.py @@ -2,9 +2,26 @@ import re import socket import subprocess - import torch +from functools import lru_cache +from loguru import logger + + +@lru_cache(maxsize=1) +def get_ip() -> str: + try: + # try to get ip from network interface + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + except Exception as e: # noqa: BLE001 + # fallback to get ip from hostname + logger.warning( + f"fail to get ip from network interface, fallback to get ip from hostname: {e}" + ) + return socket.gethostbyname(socket.gethostname()) + def npu_generate_uuid() -> str: str_pid = str(os.getpid()) @@ -21,8 +38,7 @@ def npu_generate_uuid() -> str: search_after_pid = str_result[str_result.find(str_pid) + len(str_pid) :] match_chip_id = re.search(r"Chip ID[^\d]*(\d+)", search_after_pid) chip_id = int(match_chip_id.group(1)) - server_ip = socket.gethostbyname(socket.gethostname()) - return f"{server_ip}-{npu_id * chip_count + chip_id}" + return f"{get_ip()}-{npu_id * chip_count + chip_id}" ValueError("The current process is not running on the npu device") except subprocess.CalledProcessError: ValueError("The current process is not running on the npu device") diff --git a/checkpoint_engine/ps.py b/checkpoint_engine/ps.py index 45d0621..a51934e 100644 --- 
a/checkpoint_engine/ps.py +++ b/checkpoint_engine/ps.py @@ -4,13 +4,11 @@ import os import pickle import random -import socket import threading import time from collections import defaultdict from collections.abc import Callable from datetime import timedelta -from functools import lru_cache from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple import httpx @@ -23,7 +21,7 @@ from safetensors.torch import safe_open from torch.multiprocessing.reductions import reduce_tensor -from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid +from checkpoint_engine.device_utils import DeviceManager, get_ip, npu_generate_uuid if TYPE_CHECKING: @@ -261,21 +259,6 @@ def _get_physical_gpu_id(device_manager: DeviceManager, device_index: int | None raise ValueError(f"fail to get physical gpu id {device_index}") from e -@lru_cache(maxsize=1) -def _get_ip() -> str: - try: - # try to get ip from network interface - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - s.connect(("8.8.8.8", 80)) - return s.getsockname()[0] - except Exception as e: # noqa: BLE001 - # fallback to get ip from hostname - logger.warning( - f"fail to get ip from network interface, fallback to get ip from hostname: {e}" - ) - return socket.gethostbyname(socket.gethostname()) - - def _ibv_get_device_list() -> list[str]: lib = ctypes.CDLL("libibverbs.so.1") lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices @@ -600,7 +583,7 @@ def __init__(self, device_manager: DeviceManager): gpu_count = device_manager.device_module.device_count() local_rank = self.rank % gpu_count device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices()) - self.ip = _get_ip() + self.ip = get_ip() # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases retry_count = 8 @@ -792,7 +775,7 @@ def gather_metas(self, checkpoint_name: str): for x in self._memory_pool.get(checkpoint_name, []) ], p2p_store_addr=None if self._p2p_store is None else self._p2p_store.addr, - host_ip=_get_ip(), + host_ip=get_ip(), device_uuid=self._device_uuid, ) From b225b624ba8036ce6a6b1781fb4b54095c88def8 Mon Sep 17 00:00:00 2001 From: kip-cxj <939544916@qq.com> Date: Thu, 30 Oct 2025 10:31:11 +0800 Subject: [PATCH 8/8] [fix] pre-commit fix Signed-off-by: kip-cxj <939544916@qq.com> --- checkpoint_engine/device_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/checkpoint_engine/device_utils.py b/checkpoint_engine/device_utils.py index 0d90162..013d871 100644 --- a/checkpoint_engine/device_utils.py +++ b/checkpoint_engine/device_utils.py @@ -2,9 +2,9 @@ import re import socket import subprocess -import torch - from functools import lru_cache + +import torch from loguru import logger
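
Taken together, the series replaces every hard-coded torch.cuda / "cuda" / "nccl" reference in ps.py with lookups through the DeviceManager that ends up in checkpoint_engine/device_utils.py. The snippet below is a condensed sketch of that usage pattern, not code from the patches themselves; the RANK/WORLD_SIZE handling and the bucket size are illustrative only.

import os

import torch
import torch.distributed as dist

from checkpoint_engine.device_utils import DeviceManager

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

dm = DeviceManager()                      # resolves to "npu" (torch_npu) or "cuda"
local_rank = rank % dm.device_module.device_count()
dm.device_module.set_device(local_rank)   # dm.device_module is torch_npu.npu or torch.cuda

# Collective backend follows the device: "hccl" on Ascend, "nccl" on CUDA.
# Assumes MASTER_ADDR/MASTER_PORT are set for the default env:// init.
dist.init_process_group(backend=dm.backend, world_size=world_size, rank=rank)

# Staging buffers are allocated against the detected device type instead of a
# literal "cuda" string, so the same broadcast path runs on both platforms.
bucket_size = 64 * 1024 * 1024            # illustrative value, not from the patches
buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=dm.device_type)

dm.device_module.synchronize()
dm.device_module.empty_cache()
dist.destroy_process_group()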
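
The reason patches 5 and 6 switch the Ascend identifier from the device-properties name to an npu-smi derived string is that the parameter server and the vLLM worker extension resolve their ZMQ endpoints through a dict keyed by device UUID, so both sides must compute exactly the same key. A minimal sketch of that contract follows; the endpoint string is purely illustrative.

from checkpoint_engine.device_utils import DeviceManager, npu_generate_uuid


def device_uuid(dm: DeviceManager, device_index: int) -> str:
    # Mirrors _get_physical_gpu_id() in ps.py after patch 6.
    if dm.device_type == "npu":
        # "NPU-<host ip>-<logical chip id>", parsed from `npu-smi info -t proc-mem`
        return f"NPU-{npu_generate_uuid()}"
    return f"GPU-{dm.device_module.get_device_properties(device_index).uuid!s}"


dm = DeviceManager()

# Parameter-server side: publish one ZMQ handle per device, keyed by UUID.
zmq_handles = {device_uuid(dm, 0): "ipc:///tmp/ckpt-engine-0.sock"}  # handle value is hypothetical

# Worker side (VllmColocateWorkerExtension.update_weights_from_ipc): recompute
# the same UUID and look up the handle before receiving weights over IPC.
handle = zmq_handles[device_uuid(dm, 0)]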