
Commit f70b66d

add nvshmem sum_reduce for mnnvl allreduce (#1152)
## 📌 Description

Add support for MNNVL all-reduce through NVSHMEM sum reduce.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 71509fa commit f70b66d
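
For a quick picture of what this commit adds, here is a minimal usage sketch of the new NVSHMEM allreduce path (modeled on `tests/test_nvshmem_allreduce.py` below; it assumes `torch.distributed` is already initialized on every rank and that NVSHMEM is available on the system):

```python
import torch
import torch.distributed as dist

from flashinfer.comm.nvshmem_allreduce import NVSHMEMAllReduce

rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
device = torch.device("cuda", rank)

# Size the NVSHMEM symmetric-heap buffers for the largest tensor to be reduced.
allreduce = NVSHMEMAllReduce(
    rank,
    world_size,
    max_buffer_elements=4096 * 8192,
    dtype=torch.bfloat16,
    device=device,
    group=dist.group.WORLD,
)

inp = torch.full([8192], rank, dtype=torch.bfloat16, device=device)
out = torch.empty_like(inp)
allreduce.all_reduce(inp, out)  # out now holds the elementwise sum across all ranks

allreduce.shutdown()
```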

File tree

3 files changed: +320 -0 lines changed

- csrc/nvshmem_binding.cu
- flashinfer/comm/nvshmem_allreduce.py
- tests/test_nvshmem_allreduce.py

csrc/nvshmem_binding.cu

Lines changed: 73 additions & 0 deletions
```diff
@@ -83,6 +83,71 @@ void alltoall(at::Tensor dest, at::Tensor source) {
 
 void fake_alltoall(at::Tensor dest, at::Tensor source) {}
 
+void sum_reduce(at::Tensor dest, at::Tensor source, int64_t nelems) {
+  TORCH_CHECK(dest.is_contiguous(), "dest must be contiguous");
+  TORCH_CHECK(source.is_contiguous(), "source must be contiguous");
+  TORCH_CHECK(dest.scalar_type() == source.scalar_type(),
+              "dest and source must have the same dtype");
+
+  // Add validation and conversion
+  TORCH_CHECK(nelems >= 0, "nelems must be non-negative, got ", nelems);
+  TORCH_CHECK(nelems <= SIZE_MAX, "nelems too large: ", nelems, " > ", SIZE_MAX);
+  size_t nelems_size_t = static_cast<size_t>(nelems);
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  switch (dest.scalar_type()) {
+    case at::kHalf:  // float16
+      NVSHMEMCHECK(nvshmemx_half_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, (__half*)dest.data_ptr(),
+                                                      (__half*)source.data_ptr(), nelems_size_t,
+                                                      stream));
+      break;
+    case at::kFloat:  // float32
+      NVSHMEMCHECK(nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, (float*)dest.data_ptr(),
+                                                       (float*)source.data_ptr(), nelems_size_t,
+                                                       stream));
+      break;
+    case at::kBFloat16:  // bfloat16
+      NVSHMEMCHECK(nvshmemx_bfloat16_sum_reduce_on_stream(
+          NVSHMEM_TEAM_WORLD, (__nv_bfloat16*)dest.data_ptr(), (__nv_bfloat16*)source.data_ptr(),
+          nelems_size_t, stream));
+      break;
+
+    default:
+      TORCH_CHECK(false, "Unsupported dtype for nvshmem_sum_reduce: ", dest.scalar_type());
+  }
+}
+
+void fake_sum_reduce(at::Tensor dest, at::Tensor source, int64_t nelems) {}
+
+void allreduce_on_stream_with_copy(at::Tensor dest_symm, at::Tensor source_symm,
+                                   at::Tensor dest_local, at::Tensor source_local, int64_t nelems) {
+  TORCH_CHECK(dest_symm.is_contiguous(), "dest_symm must be contiguous");
+  TORCH_CHECK(source_symm.is_contiguous(), "source_symm must be contiguous");
+  TORCH_CHECK(dest_local.is_contiguous(), "dest_local must be contiguous");
+  TORCH_CHECK(source_local.is_contiguous(), "source_local must be contiguous");
+  TORCH_CHECK(dest_symm.scalar_type() == source_symm.scalar_type(),
+              "dest_symm and source_symm must have the same dtype");
+  TORCH_CHECK(dest_symm.scalar_type() == source_local.scalar_type(),
+              "dest_symm and source_local must have the same dtype");
+  TORCH_CHECK(dest_local.scalar_type() == source_local.scalar_type(),
+              "dest_local and source_local must have the same dtype");
+
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  cudaMemcpyAsync(source_symm.data_ptr(), source_local.data_ptr(),
+                  nelems * source_local.element_size(), cudaMemcpyDefault, stream);
+  nvshmemx_barrier_on_stream(NVSHMEM_TEAM_WORLD, stream);
+  sum_reduce(dest_symm, source_symm, nelems);
+  cudaMemcpyAsync(dest_local.data_ptr(), dest_symm.data_ptr(), nelems * dest_local.element_size(),
+                  cudaMemcpyDefault, stream);
+  cudaStreamSynchronize(stream);
+}
+
+void fake_allreduce_on_stream_with_copy(at::Tensor dest_symm, at::Tensor source_symm,
+                                        at::Tensor dest_local, at::Tensor source_local,
+                                        int64_t nelems) {}
+
 TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("nvshmem_get_unique_id", &get_unique_id);
   m.def("nvshmem_unique_id_size", &unique_id_size);
@@ -96,6 +161,14 @@ TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
   m.def("nvshmem_alltoall(Tensor! dest, Tensor src) -> ()");
   m.impl("nvshmem_alltoall", c10::kCUDA, &alltoall);
   m.impl("nvshmem_alltoall", c10::kMeta, &fake_alltoall);
+  m.def("nvshmem_sum_reduce(Tensor! dest, Tensor src, int nelems) -> ()");
+  m.impl("nvshmem_sum_reduce", c10::kCUDA, &sum_reduce);
+  m.impl("nvshmem_sum_reduce", c10::kMeta, &fake_sum_reduce);
+  m.def(
+      "nvshmem_allreduce_on_stream_with_copy(Tensor! dest_symm, Tensor source_symm, Tensor "
+      "dest_local, Tensor source_local, int nelems) -> ()");
+  m.impl("nvshmem_allreduce_on_stream_with_copy", c10::kCUDA, &allreduce_on_stream_with_copy);
+  m.impl("nvshmem_allreduce_on_stream_with_copy", c10::kMeta, &fake_allreduce_on_stream_with_copy);
 };
 
 } // namespace
```
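
The ops registered above become reachable from Python through the JIT-built NVSHMEM module already used by `flashinfer.comm`. Below is a hedged sketch of calling `nvshmem_sum_reduce` directly; it assumes NVSHMEM has been initialized (for example via the `NVSHMEMAllReduce` wrapper below) and that both tensors were allocated on the symmetric heap with `nvshmem_malloc`:

```python
import torch

from flashinfer.comm.nvshmem import get_nvshmem_module

nvshmem = get_nvshmem_module()

nelems = 1024
device = torch.device("cuda")
# Both operands must live on the NVSHMEM symmetric heap.
src = nvshmem.nvshmem_malloc([nelems], torch.float16, device)
dst = nvshmem.nvshmem_malloc([nelems], torch.float16, device)

src.fill_(1.0)
# Sum-reduce over NVSHMEM_TEAM_WORLD on the current CUDA stream;
# the binding above supports float16, float32 and bfloat16.
nvshmem.nvshmem_sum_reduce(dst, src, nelems)
torch.cuda.synchronize()
# After synchronization, every PE's dst holds the value n_pes.
```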

flashinfer/comm/nvshmem_allreduce.py

Lines changed: 134 additions & 0 deletions
```python
"""
Copyright (c) 2023 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Optional

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from .nvshmem import get_nvshmem_module


class NVSHMEMAllReduce:
    """
    An AllReduce implementation for Single-Node and Multi-Node NVLink communication.
    This class handles NVLINK-specific allreduce operations, optimized for NVLink-enabled clusters.
    Note: Requires an active torch.distributed process group to be initialized
    prior to creating an instance of this class.

    Args:
        local_rank (int): The local rank of the current process.
        world_size (int): The total number of processes in the distributed group.
        max_buffer_elements (int): The maximum number of elements that can be stored in
            the buffer. This is used to allocate memory in nvshmem symm heap. set to the
            largest tensor size you will be reducing.
        dtype (torch.dtype): The data type of the tensors to be reduced.
        device (torch.device): The device on which the tensors are located.
        group (torch.distributed.ProcessGroup, optional): The torch.distributed process group to use.
        should_init (bool, optional): Whether to initialize nvshmem. Defaults to True.
    Raises:
        RuntimeError: If nvshmem fails to initialize.
    """

    def __init__(
        self,
        local_rank: int,
        world_size: int,
        max_buffer_elements: int,
        dtype: torch.dtype,
        device: torch.device,
        group: Optional[ProcessGroup] = None,
        should_init: bool = True,
    ):
        self.local_rank = local_rank
        self.world_size = world_size
        self.dtype = dtype
        self.device = device
        self.max_buffer_elements = max_buffer_elements
        self.group = group
        self.nvshmem_module = get_nvshmem_module()

        self.should_init = should_init
        if self.should_init:
            self.init_nvshmem()

        # assert PE and world size match
        my_pe = self.nvshmem_module.nvshmem_my_pe()
        n_pes = self.nvshmem_module.nvshmem_n_pes()
        if my_pe != local_rank:
            print(
                f"WARNING: Rank {local_rank}: PE mismatch! Expected PE {local_rank}, got PE {my_pe}",
                flush=True,
            )
        if n_pes != world_size:
            print(
                f"WARNING: Rank {local_rank}: World size mismatch! Expected {world_size}, got {n_pes}",
                flush=True,
            )

        # allocate memory in nvshmem symm heap
        self.symm_buffer_input = self.nvshmem_module.nvshmem_malloc(
            [max_buffer_elements],
            self.dtype,
            self.device,
        )
        self.symm_buffer_output = self.nvshmem_module.nvshmem_malloc(
            [max_buffer_elements],
            self.dtype,
            self.device,
        )
        torch.distributed.barrier(self.group)

    def init_nvshmem(self):
        torch.zeros(
            self.nvshmem_module.nvshmem_unique_id_size(),
            dtype=torch.uint8,
            device="cpu",
        )
        if self.local_rank == 0:
            uid = self.nvshmem_module.nvshmem_get_unique_id()
        else:
            uid = torch.zeros(
                self.nvshmem_module.nvshmem_unique_id_size(),
                dtype=torch.uint8,
                device="cpu",
            )
        torch.distributed.broadcast(uid, src=0)
        torch.distributed.barrier(self.group)
        init_status = self.nvshmem_module.nvshmem_init(
            uid, self.local_rank, self.world_size
        )
        torch.cuda.synchronize()
        if init_status != 0:
            raise RuntimeError("Failed to initialize nvshmem")

    def all_reduce(self, inp: torch.Tensor, out: torch.Tensor) -> None:
        self.nvshmem_module.nvshmem_allreduce_on_stream_with_copy(
            self.symm_buffer_output,
            self.symm_buffer_input,
            out,
            inp,
            inp.numel(),
        )

    def shutdown(self):
        del self.symm_buffer_input
        del self.symm_buffer_output
        torch.distributed.barrier(self.group)
        torch.cuda.synchronize()
        if self.should_init:
            self.nvshmem_module.nvshmem_finalize()
```
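
For reference, `all_reduce` above delegates entirely to the C++ `nvshmem_allreduce_on_stream_with_copy` op from the binding; the sketch below restates its staging sequence in illustrative Python (not the actual implementation; the inter-PE barrier and stream handling live inside the op):

```python
import torch


def _all_reduce_staged(module, symm_in, symm_out, inp, out):
    """Illustrative restatement of nvshmem_allreduce_on_stream_with_copy."""
    n = inp.numel()
    symm_in.view(-1)[:n].copy_(inp.view(-1))         # 1. local input -> symmetric heap
    # 2. nvshmemx_barrier_on_stream(NVSHMEM_TEAM_WORLD, ...) happens inside the C++ op
    module.nvshmem_sum_reduce(symm_out, symm_in, n)  # 3. sum across all PEs
    out.view(-1).copy_(symm_out.view(-1)[:n])        # 4. symmetric heap -> local output
    torch.cuda.synchronize()                         # 5. the op synchronizes the stream
```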

tests/test_nvshmem_allreduce.py

Lines changed: 113 additions & 0 deletions
```python
import logging
import multiprocessing as mp
import os
import socket
from typing import Any

import pytest
import torch
import torch.distributed as dist

from flashinfer.comm.nvshmem_allreduce import NVSHMEMAllReduce

logger = logging.getLogger(__name__)


def _run_correctness_worker(world_size, rank, distributed_init_port):
    assert rank >= 0
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        rank=rank,
        world_size=world_size,
        device_id=device,
        init_method=distributed_init_method,
    )
    group = dist.group.WORLD
    num_ranks = torch.distributed.get_world_size()
    rank_id = torch.distributed.get_rank()

    batch_sizes = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    max_batch_size = 4096
    hidden_dim = 8192
    test_loop = 10
    tensor_dtype = torch.bfloat16
    nvshmem_allreduce = NVSHMEMAllReduce(
        rank_id,
        num_ranks,
        max_batch_size * hidden_dim,
        tensor_dtype,
        device,
        group,
    )

    try:
        for batch_size in batch_sizes:
            for _ in range(test_loop):
                tensor_size = batch_size * hidden_dim
                inp1 = torch.full(
                    [tensor_size], rank_id, dtype=tensor_dtype, device=device
                )
                inp1_ref = inp1.clone()
                out1 = torch.empty_like(inp1)
                nvshmem_allreduce.all_reduce(inp1, out1)
                torch.distributed.all_reduce(inp1_ref, group=group)
                torch.cuda.synchronize()
                torch.testing.assert_close(out1, inp1_ref)
                torch.distributed.barrier(group)
    except Exception as e:
        print(f"Rank {rank_id}: Exception during test: {e}")
        raise
    finally:
        torch.distributed.barrier(group)
        nvshmem_allreduce.shutdown()
        torch.distributed.destroy_process_group(group)


def get_open_port() -> int:
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert (
            procs[i].exitcode == 0
        ), f"Process {i} failed with exit code {procs[i].exitcode}"


@pytest.mark.parametrize("world_size", [8])
def test_nvshmem_allreduce(world_size):
    available_gpus = torch.cuda.device_count()
    if world_size > available_gpus:
        raise ValueError(
            f"world_size {world_size} is greater than available_gpus {available_gpus}"
        )
    print(f"Running test for world_size={world_size}")
    multi_process_parallel(
        world_size,
        _run_correctness_worker,
        target_args=(),
    )
    print(f"NVSHMEM allreduce tp = {world_size}: OK")
```
