 # pyright: reportCallIssue=false
 
-from collections.abc import Sequence
+from typing import Any, Optional
 
-import torch
+import nvshmem.core as nvshmem  # type: ignore[import]
+import torch.distributed as dist
 
-from .ops import _ops
 
 ###### NVSHMEM ######
-
-
-def nvshmem_get_unique_id() -> torch.Tensor:
-    return _ops.nvshmem_get_unique_id()
-
-
-def nvshmem_unique_id_size() -> int:
-    return _ops.nvshmem_unique_id_size()
-
-
-def nvshmem_alloc_empty_unique_id() -> torch.Tensor:
-    return torch.zeros(nvshmem_unique_id_size(), dtype=torch.uint8, device="cpu")
-
-
-def nvshmem_init(uid: torch.Tensor, rank: int, world_size: int) -> int:
-    status = _ops.nvshmem_init(uid, rank, world_size)
-    torch.cuda.synchronize()
-    return status
-
-
-def nvshmem_alltoall(dest: torch.Tensor, source: torch.Tensor) -> None:
-    return _ops.nvshmem_alltoall(dest, source)
-
-
-def nvshmem_finalize() -> None:
-    torch.cuda.synchronize()
-    _ops.nvshmem_finalize()
-
-
-def nvshmem_my_pe() -> int:
-    return _ops.nvshmem_my_pe()
-
-
-def nvshmem_n_pes() -> int:
-    return _ops.nvshmem_n_pes()
-
-
-def nvshmem_malloc(
-    shape: Sequence[int],
-    dtype: torch.dtype,
-    device: torch.device,
-) -> torch.Tensor:
-    return _ops.nvshmem_malloc(shape, dtype, device)
-
-
-def nvshmem_barrier_all() -> None:
-    _ops.nvshmem_barrier_all()
-
-
-def nvshmem_barrier_all_on_current_stream() -> None:
-    _ops.nvshmem_barrier_all_on_current_stream()
+def nvshmem_init(
+    global_rank: int,
+    local_rank: int,
+    world_size: int,
+    device: Any,
+    uid: Optional[Any] = None,
+) -> None:
+    # If the caller did not supply a unique ID, local rank 0 creates one and
+    # the copy held by global rank 0 (the broadcast source below) is
+    # distributed to every rank over the already-initialized
+    # torch.distributed process group.
+    if uid is None:
+        if local_rank == 0:
+            broadcast_objects = [nvshmem.get_unique_id()]
+        else:
+            broadcast_objects = [None]
+
+        dist.broadcast_object_list(broadcast_objects, src=0)
+        dist.barrier()
+        uid = broadcast_objects[0]
+
+    nvshmem.init(
+        device=device,
+        uid=uid,
+        rank=global_rank,
+        nranks=world_size,
+        initializer_method="uid",
+    )
+
+
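A minimal caller sketch, assuming a torchrun-style launch (LOCAL_RANK in the
environment), an NCCL process group, and that `device` is a cuda.core Device
as in nvshmem4py's UID-based initialization examples:

    import os
    import torch.distributed as dist
    from cuda.core.experimental import Device

    # The process group must exist before nvshmem_init, since the unique ID
    # is broadcast over it.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = Device(local_rank)
    device.set_current()
    nvshmem_init(
        global_rank=dist.get_rank(),
        local_rank=local_rank,
        world_size=dist.get_world_size(),
        device=device,
    )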
+# Wraps a torch.cuda.Stream in the __cuda_stream__ protocol expected by
+# CUDA Python. This workaround can be removed once nvshmem4py supports
+# Torch stream interoperability directly. For more information see:
+# https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
+class PyTorchStreamWrapper:
+    def __init__(self, pt_stream: Any) -> None:
+        self.pt_stream = pt_stream
+        self.handle = pt_stream.cuda_stream
+
+    def __cuda_stream__(self) -> tuple[int, int]:
+        # The protocol expects a (protocol_version, stream_handle) pair.
+        stream_id = self.pt_stream.cuda_stream
+        return (0, stream_id)
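A usage sketch for the wrapper, assuming cuda.core's Device.create_stream,
which per the interoperability page linked above adopts any object
implementing __cuda_stream__:

    import torch
    from cuda.core.experimental import Device

    pt_stream = torch.cuda.current_stream()
    wrapped = PyTorchStreamWrapper(pt_stream)

    dev = Device()
    dev.set_current()
    # cuda.core (and, through it, nvshmem4py) now runs work on the same
    # underlying CUDA stream that Torch is using.
    cc_stream = dev.create_stream(obj=wrapped)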