Commit 2d1895e

ilmarkov authored and mgoin committed
[PERF] PyTorch Symmetric Memory All-Reduce (vllm-project#20759)
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: ilmarkov <[email protected]>
Signed-off-by: Michael Goin <[email protected]>
Co-authored-by: ilmarkov <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
1 parent d0f3baf commit 2d1895e
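
The new path is opt-in: it is gated by the `VLLM_ALLREDUCE_USE_SYMM_MEM` environment variable added in `vllm/envs.py` below and only activates on CUDA devices with a supported compute capability and world size. A minimal sketch of enabling it for a tensor-parallel run; the model name and parallel size are placeholders, not part of this commit:

# Sketch only: enables the symmetric-memory all-reduce path. Assumes a CUDA
# host with at least two GPUs; the model and TP size are illustrative.
import os

os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "1"  # flag added by this commit

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
print(llm.generate("Hello, world!")[0].outputs[0].text)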

File tree

8 files changed: +279 -5 lines changed


docs/design/multiprocessing.md

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.
 
 There are other miscellaneous places hard-coding the use of `spawn`:
 
-- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
+- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
 - <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>
 
 Related PRs:
New file: distributed symmetric-memory all-reduce test

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
import typing

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
    CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             get_tp_group,
                                             init_distributed_environment,
                                             initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

torch.manual_seed(42)
random.seed(44)

test_size_elements = 4 * 1024 * 1024


def symm_mem_allreduce_worker(local_rank: int, world_size: int):
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)
        update_environment_variables({
            'RANK': str(local_rank),
            'LOCAL_RANK': str(local_rank),
            'WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        cuda_communicator = typing.cast(CudaCommunicator,
                                        get_tp_group().device_communicator)
        symm_mem_comm = cuda_communicator.symm_mem_comm
        if symm_mem_comm is None or symm_mem_comm.disabled:
            pytest.skip("SymmMemCommunicator is not available or disabled.")

        inp_direct_symm_mem = torch.randint(1,
                                            23, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
            pytest.skip(
                "SymmMemCommunicator isn't used for this world and input size."
            )

        original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
        out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
        assert out_direct_symm_mem is not None

        group = get_tensor_model_parallel_group().device_group
        dist.all_reduce(original_inp_direct_symm_mem, group=group)
        torch.testing.assert_close(out_direct_symm_mem,
                                   original_inp_direct_symm_mem,
                                   atol=2.5,
                                   rtol=0.1)

        # Test tensor_model_parallel_all_reduce which should use symm_mem
        inp_tensor_parallel = torch.randint(-23,
                                            1, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        original_inp_tensor_parallel = inp_tensor_parallel.clone()
        out_tensor_parallel = tensor_model_parallel_all_reduce(
            inp_tensor_parallel)
        dist.all_reduce(original_inp_tensor_parallel, group=group)
        torch.testing.assert_close(out_tensor_parallel,
                                   original_inp_tensor_parallel,
                                   atol=2.5,
                                   rtol=0.1)


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
                            pipeline_parallel_size):
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Enable SymmMemCommunicator
    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

    mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
    cleanup_dist_env_and_memory()

tools/check_pickle_imports.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
     'vllm/distributed/utils.py',
     'vllm/distributed/parallel_state.py',
     'vllm/engine/multiprocessing/client.py',
-    'vllm/distributed/device_communicators/custom_all_reduce_utils.py',
+    'vllm/distributed/device_communicators/all_reduce_utils.py',
     'vllm/distributed/device_communicators/shm_broadcast.py',
     'vllm/engine/multiprocessing/engine.py',
     'benchmarks/kernels/graph_machete_bench.py',

vllm/distributed/device_communicators/custom_all_reduce_utils.py renamed to vllm/distributed/device_communicators/all_reduce_utils.py

Lines changed: 33 additions & 0 deletions
@@ -23,6 +23,39 @@
 
 logger = init_logger(__name__)
 
+MiB = 1024 * 1024
+# Max size for each world size in case symmetric memory is available
+# For different SM architectures
+CUSTOM_ALL_REDUCE_MAX_SIZES = {
+    "9.0": {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: MiB // 2,  # 512 KB
+        8: MiB // 4,  # 256 KB
+    },
+    "10.0": {
+        2: 2 * MiB,  # 2 MB
+        4: 2 * MiB,  # 2 MB
+        6: 2 * MiB,  # 2 MB
+        8: 2 * MiB,  # 2 MB
+    }
+}
+
+SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+    "9.0": {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: 64 * MiB,  # 64 MB
+        8: 64 * MiB,  # 64 MB
+    },
+    "10.0": {
+        2: 8 * MiB,  # 8 MB
+        4: 32 * MiB,  # 32 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
+    }
+}
+
 
 def producer(batch_src: Sequence[int],
              producer_queue,
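
The two tables added above cap the message size for each all-reduce backend, keyed by the device capability string (as returned by `current_platform.get_device_capability().as_version_str()`) and the world size. When `VLLM_ALLREDUCE_USE_SYMM_MEM` is set, `CustomAllreduce` clamps its own limit to the `CUSTOM_ALL_REDUCE_MAX_SIZES` entry (see the `custom_all_reduce.py` hunk below), so larger tensors fall through to the symmetric-memory path, which is allowed up to the `SYMM_MEM_ALL_REDUCE_MAX_SIZES` entry. A minimal sketch of that lookup, with a trimmed copy of the table and a hypothetical helper name:

# Sketch only: trimmed copy of the table above; the helper name is
# illustrative, not part of the commit.
MiB = 1024 * 1024
CUSTOM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {2: 64 * MiB, 4: 32 * MiB, 6: MiB // 2, 8: MiB // 4},
}


def clamped_custom_ar_max_size(capability: str, world_size: int,
                               max_size: int) -> int:
    # Mirrors the min(...) applied in CustomAllreduce.__init__ when
    # VLLM_ALLREDUCE_USE_SYMM_MEM is enabled.
    table = CUSTOM_ALL_REDUCE_MAX_SIZES.get(capability, {})
    return min(table.get(world_size, max_size), max_size)


# On capability "9.0" with 8 ranks the custom all-reduce ceiling drops to
# 256 KiB; anything larger is handled by SymmMemCommunicator (or NCCL).
assert clamped_custom_ar_max_size("9.0", 8, 8 * MiB) == 256 * 1024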

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 15 additions & 0 deletions
@@ -44,6 +44,8 @@ def __init__(self,
             PyNcclCommunicator)
         from vllm.distributed.device_communicators.quick_all_reduce import (
             QuickAllReduce)
+        from vllm.distributed.device_communicators.symm_mem import (
+            SymmMemCommunicator)
 
         self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
@@ -54,6 +56,7 @@ def __init__(self,
 
         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
+        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
@@ -69,6 +72,12 @@ def __init__(self,
             # currently be an MI300 series.
             self.qr_comm = QuickAllReduce(group=self.cpu_group,
                                           device=self.device)
+        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
             if all2all_backend == "naive":
@@ -105,6 +114,12 @@ def all_reduce(self, input_):
             out = ca_comm.custom_all_reduce(input_)
             assert out is not None
             return out
+        symm_mem_comm = self.symm_mem_comm
+        if symm_mem_comm is not None and \
+                symm_mem_comm.should_use_symm_mem(input_):
+            out = symm_mem_comm.all_reduce(input_)
+            assert out is not None
+            return out
         pynccl_comm = self.pynccl_comm
         assert pynccl_comm is not None
         out = pynccl_comm.all_reduce(input_)

vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 9 additions & 3 deletions
@@ -10,8 +10,8 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
-from vllm.distributed.device_communicators.custom_all_reduce_utils import (
-    gpu_p2p_access_check)
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -109,7 +109,13 @@ def __init__(self,
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
-
+        device_capability = current_platform.get_device_capability(
+        ).as_version_str()
+        if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
+                and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
+            max_size = min(
+                CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
+                max_size)
         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
             device_ids = list(map(int, cuda_visible_devices.split(",")))
vllm/distributed/device_communicators/symm_mem.py (new file)

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.all_reduce_utils import (
    SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform

try:
    import torch.distributed._symmetric_memory as torch_symm_mem

    symm_mem_available = True
except ImportError:
    symm_mem_available = False

logger = init_logger(__name__)


class SymmMemCommunicator:
    _WORLD_SIZES_MULTIMEM = {
        "9.0": [4, 6, 8],
        "10.0": [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str,
                                                          torch.device]):
        self.disabled = True

        if not symm_mem_available:
            return

        if not current_platform.is_cuda():
            logger.warning("SymmMemCommunicator: symmetric "
                           "memory is not available.")
            return
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)
        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = current_platform.get_device_capability(
        ).as_version_str()
        if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
                self.device_capability]:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return
        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
            self.world_size]
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning("SymmMemCommunicator: symmetric memory "
                           "multicast operations are not supported.")
            return
        self.disabled = False

    def should_use_symm_mem(self, inp: torch.Tensor):
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def all_reduce(
            self,
            inp: torch.Tensor,
            *,
            out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        if not self.should_use_symm_mem(inp):
            return None
        if out is None:
            out = torch.empty_like(inp)
        self.buffer[:inp.numel()].copy_(inp.view(-1))
        if self.world_size in self._WORLD_SIZES_MULTIMEM[
                self.device_capability]:
            torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
                                                    "sum",
                                                    self.group.group_name)
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
                                                    "sum",
                                                    self.group.group_name)
        out.copy_(self.buffer[:inp.numel()].view(out.shape))
        return out
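
`SymmMemCommunicator` is a thin wrapper over PyTorch's (still private) `torch.distributed._symmetric_memory` API: it allocates a persistent symmetric buffer once, performs a rendezvous over the process group, and then dispatches to `multimem_all_reduce_` (NVLink multicast, used for the world sizes in `_WORLD_SIZES_MULTIMEM`) or `two_shot_all_reduce_`. A minimal standalone sketch of those primitives, assuming a recent CUDA build of PyTorch and a `torchrun` launch; vLLM's dtype, size, and capability gating is omitted:

# Sketch only. Launch with: torchrun --nproc-per-node=<num_gpus> symm_mem_demo.py
# Uses the same private PyTorch symmetric-memory calls as SymmMemCommunicator,
# without vLLM's capability/world-size checks.
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)
dist.init_process_group("nccl")
group_name = dist.group.WORLD.group_name

# Allocate and register a symmetric buffer on every rank.
buf = symm_mem.empty(4096, device=f"cuda:{rank}", dtype=torch.bfloat16)
handle = symm_mem.rendezvous(buf, group_name)

buf.fill_(rank + 1.0)
if handle.multicast_ptr != 0:
    # Multicast is available: one-shot multimem kernel (vLLM picks this for
    # the world sizes listed in _WORLD_SIZES_MULTIMEM).
    torch.ops.symm_mem.multimem_all_reduce_(buf, "sum", group_name)
else:
    # Two-shot kernel that only needs peer-to-peer access.
    torch.ops.symm_mem.two_shot_all_reduce_(buf, "sum", group_name)

expected = float(sum(range(1, dist.get_world_size() + 1)))
assert torch.allclose(buf, torch.full_like(buf, expected))
dist.destroy_process_group()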

vllm/envs.py

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@
     VLLM_HAS_FLASHINFER_CUBIN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
+    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
 
 