
Commit aad0d31

Add hopper support
Signed-off-by: ilmarkov <[email protected]>
1 parent 1541ee7 commit aad0d31

File tree

7 files changed: +70 -32 lines changed


docs/design/v1/multiprocessing.md

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.
 
 There are other miscellaneous places hard-coding the use of `spawn`:
 
-- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
+- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
 - <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>
 
 Related PRs:

tools/check_pickle_imports.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@
     'vllm/distributed/utils.py',
     'vllm/distributed/parallel_state.py',
     'vllm/engine/multiprocessing/client.py',
-    'vllm/distributed/device_communicators/custom_all_reduce_utils.py',
+    'vllm/distributed/device_communicators/all_reduce_utils.py',
     'vllm/distributed/device_communicators/shm_broadcast.py',
     'vllm/engine/multiprocessing/engine.py',
     'benchmarks/kernels/graph_machete_bench.py',

vllm/distributed/device_communicators/custom_all_reduce_utils.py renamed to vllm/distributed/device_communicators/all_reduce_utils.py

Lines changed: 33 additions & 0 deletions
@@ -23,6 +23,39 @@
 
 logger = init_logger(__name__)
 
+MiB = 1024 * 1024
+# Max size for each world size in case symmetric memory is available
+# For different SM architectures
+CUSTOM_ALL_REDUCE_MAX_SIZES = {
+    "9.0": {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: MiB // 2,  # 512 KB
+        8: MiB // 4,  # 256 KB
+    },
+    "10.0": {
+        2: 2 * MiB,  # 2 MB
+        4: 2 * MiB,  # 2 MB
+        6: 2 * MiB,  # 2 MB
+        8: 2 * MiB,  # 2 MB
+    }
+}
+
+SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+    "9.0": {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: 64 * MiB,  # 64 MB
+        8: 64 * MiB,  # 64 MB
+    },
+    "10.0": {
+        2: 8 * MiB,  # 8 MB
+        4: 32 * MiB,  # 32 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
+    }
+}
+
 
 def producer(batch_src: Sequence[int],
              producer_queue,
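
For reference, both tables are keyed first by the SM version string returned by current_platform.get_device_capability().as_version_str() ("9.0" is Hopper, "10.0" is Blackwell) and then by world size; callers fall back to other paths when either key is missing. A minimal sketch of that lookup, using a hypothetical lookup_max_size helper that is not part of the commit:

# Illustrative sketch of consulting the capability/world-size keyed tables
# above; `lookup_max_size` is a hypothetical helper, not part of the commit.
from typing import Optional

MiB = 1024 * 1024
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {2: 64 * MiB, 4: 32 * MiB, 6: 64 * MiB, 8: 64 * MiB},
    "10.0": {2: 8 * MiB, 4: 32 * MiB, 6: 128 * MiB, 8: 128 * MiB},
}

def lookup_max_size(device_capability: str, world_size: int) -> Optional[int]:
    """Return the symmetric-memory all-reduce limit in bytes, or None when
    the architecture or world size is not covered by the table."""
    per_arch = SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(device_capability)
    if per_arch is None:
        return None
    return per_arch.get(world_size)

assert lookup_max_size("9.0", 8) == 64 * MiB   # Hopper, 8 ranks
assert lookup_max_size("10.0", 2) == 8 * MiB   # Blackwell, 2 ranks
assert lookup_max_size("8.0", 8) is None       # Ampere: not in the table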

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def all_reduce(self, input_):
             assert out is not None
             return out
         symm_mem_comm = self.symm_mem_comm
-        if symm_mem_comm is not None and not symm_mem_comm.disabled and \
+        if symm_mem_comm is not None and \
                 symm_mem_comm.should_use_symm_mem(input_):
             out = symm_mem_comm.all_reduce(input_)
             assert out is not None
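
For context, all_reduce here tries each backend in turn and the first one whose predicate accepts the tensor handles the reduction; with this change the symmetric-memory branch is gated by should_use_symm_mem alone. A rough sketch of that dispatch with simplified stand-in objects (none of this is the actual vLLM implementation):

# Rough sketch of the backend dispatch in CudaCommunicator.all_reduce.
# The communicator below is a simplified stand-in, not the real class.
import torch

class FakeSymmMemComm:
    def __init__(self, max_size: int):
        self.max_size = max_size

    def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
        # Mirrors the predicate shape: 4-byte aligned and under the size cap.
        inp_size = inp.numel() * inp.element_size()
        return inp_size % 4 == 0 and inp_size < self.max_size

    def all_reduce(self, inp: torch.Tensor) -> torch.Tensor:
        return inp.clone()  # placeholder for the symmetric-memory kernel

def dispatch_all_reduce(symm_mem_comm, input_: torch.Tensor) -> torch.Tensor:
    # The explicit `disabled` check is gone; the predicate alone decides.
    if symm_mem_comm is not None and \
            symm_mem_comm.should_use_symm_mem(input_):
        out = symm_mem_comm.all_reduce(input_)
        assert out is not None
        return out
    return input_  # placeholder for the default torch.distributed path

small = torch.ones(1024, dtype=torch.bfloat16)
print(dispatch_all_reduce(FakeSymmMemComm(8 * 1024 * 1024), small).shape)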

vllm/distributed/device_communicators/custom_all_reduce.py

Lines changed: 9 additions & 13 deletions
@@ -10,8 +10,8 @@
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
-from vllm.distributed.device_communicators.custom_all_reduce_utils import (
-    gpu_p2p_access_check)
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
 from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.logger import init_logger
 from vllm.platforms import current_platform

@@ -49,14 +49,6 @@ def is_weak_contiguous(inp: torch.Tensor):
 class CustomAllreduce:
 
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
-    MiB = 1024 * 1024
-    # Max sizes for each world size in case symmetric memory is available
-    _MAX_SIZES = {
-        2: 2 * MiB,  # 1 MB
-        4: 2 * MiB,  # 1 MB
-        6: MiB,  # 512 KB
-        8: MiB // 2,  # 512 KB
-    }
 
     # max_size: max supported allreduce size
     def __init__(self,

@@ -117,9 +109,13 @@ def __init__(self,
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
-        if current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM:
-            max_size = CustomAllreduce._MAX_SIZES[world_size]
-
+        device_capability = current_platform.get_device_capability(
+        ).as_version_str()
+        if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
+                and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
+            max_size = min(
+                CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
+                max_size)
         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
             device_ids = list(map(int, cuda_visible_devices.split(",")))
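
The effect of the new constructor logic is that the requested max_size is clamped by the per-architecture table, and only when CUDA symmetric memory is enabled and the capability is listed. A minimal standalone sketch of that clamp (the function name and the use_symm_mem flag are illustrative, not part of the commit):

# Minimal sketch of the max_size clamp added to CustomAllreduce.__init__.
# `clamp_max_size` and `use_symm_mem` are illustrative names; the table
# mirrors CUSTOM_ALL_REDUCE_MAX_SIZES from all_reduce_utils.py.
MiB = 1024 * 1024
CUSTOM_ALL_REDUCE_MAX_SIZES = {
    "9.0": {2: 64 * MiB, 4: 32 * MiB, 6: MiB // 2, 8: MiB // 4},
    "10.0": {2: 2 * MiB, 4: 2 * MiB, 6: 2 * MiB, 8: 2 * MiB},
}

def clamp_max_size(max_size: int, device_capability: str, world_size: int,
                   use_symm_mem: bool) -> int:
    """Shrink max_size to the per-architecture limit when it applies."""
    if use_symm_mem and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES:
        max_size = min(
            CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
            max_size)
    return max_size

# Hopper (9.0), 8 ranks: a requested 8 MiB cap drops to 256 KiB.
print(clamp_max_size(8 * MiB, "9.0", 8, use_symm_mem=True))    # 262144
# Unknown capability: the requested cap is left untouched.
print(clamp_max_size(8 * MiB, "8.0", 8, use_symm_mem=True))    # 8388608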

vllm/distributed/device_communicators/symm_mem.py

Lines changed: 24 additions & 15 deletions
@@ -6,6 +6,8 @@
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    SYMM_MEM_ALL_REDUCE_MAX_SIZES)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 

@@ -20,13 +22,9 @@
 
 
 class SymmMemCommunicator:
-    MiB = 1024 * 1024
-    # Max sizes for each world size
-    _MAX_SIZES = {
-        2: 8 * MiB,
-        4: 32 * MiB,
-        6: 128 * MiB,
-        8: 128 * MiB,
+    _WORLD_SIZES_MULTIMEM = {
+        "9.0": [4, 6, 8],
+        "10.0": [6, 8],
     }
 
     def __init__(self, group: ProcessGroup, device: Union[int, str,

@@ -49,15 +47,27 @@ def __init__(self, group: ProcessGroup, device: Union[int, str,
         self.device = device
         self.group = group
         self.world_size = dist.get_world_size(self.group)
-        if self.world_size not in self._MAX_SIZES:
+        self.device_capability = current_platform.get_device_capability(
+        ).as_version_str()
+        if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
+            logger.warning(
+                "SymmMemCommunicator: Device capability %s not supported, "
+                "communicator is not available.",
+                self.device_capability,
+            )
+            return
+        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
+                self.device_capability]:
             logger.warning(
                 "SymmMemCommunicator: World size %d not supported, "
                 "communicator is not available.",
                 self.world_size,
             )
             return
+        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
+            self.world_size]
         self.buffer = torch_symm_mem.empty(
-            self._MAX_SIZES[self.world_size] // self.dtype.itemsize,
+            self.max_size // self.dtype.itemsize,
             device=self.device,
             dtype=self.dtype,
         )

@@ -76,7 +86,7 @@ def should_use_symm_mem(self, inp: torch.Tensor):
         inp_size = inp.numel() * inp.element_size()
         if inp_size % 4 != 0:
             return False
-        return inp_size <= self._MAX_SIZES[self.world_size]
+        return inp_size < self.max_size
 
     def all_reduce(
         self,

@@ -88,14 +98,13 @@ def all_reduce(
         if out is None:
             out = torch.empty_like(inp)
         self.buffer[:inp.numel()].copy_(inp.view(-1))
-        if self.world_size in [2, 4]:
-            # Use two-shot all-reduce for 2 and 4 GPUs
-            torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
+        if self.world_size in self._WORLD_SIZES_MULTIMEM[
+                self.device_capability]:
+            torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
                                                     "sum",
                                                     self.group.group_name)
         else:
-            # Use multi-mem all-reduce for 6 and 8 GPUs
-            torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
+            torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
                                                     "sum",
                                                     self.group.group_name)
         out.copy_(self.buffer[:inp.numel()].view(out.shape))
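
With _WORLD_SIZES_MULTIMEM, the choice between the multimem and two-shot symmetric-memory kernels is now made per architecture rather than hard-coded on world size alone. A small sketch of that selection, returning the op name instead of dispatching the real torch.ops.symm_mem kernel:

# Sketch of the kernel selection above; returns the op name only instead of
# calling the real torch.ops.symm_mem kernels.
_WORLD_SIZES_MULTIMEM = {
    "9.0": [4, 6, 8],   # Hopper
    "10.0": [6, 8],     # Blackwell
}

def pick_all_reduce_kernel(device_capability: str, world_size: int) -> str:
    if world_size in _WORLD_SIZES_MULTIMEM[device_capability]:
        return "multimem_all_reduce_"
    return "two_shot_all_reduce_"

print(pick_all_reduce_kernel("9.0", 2))    # two_shot_all_reduce_
print(pick_all_reduce_kernel("9.0", 4))    # multimem_all_reduce_
print(pick_all_reduce_kernel("10.0", 4))   # two_shot_all_reduce_
print(pick_all_reduce_kernel("10.0", 8))   # multimem_all_reduce_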

vllm/envs.py

Lines changed: 1 addition & 1 deletion
@@ -625,7 +625,7 @@ def get_vllm_port() -> Optional[int]:
                 ("1", "true")),
 
     # By default, vLLM will check the peer-to-peer capability itself,
-    # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa
+    # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/all_reduce_utils.py#L101-L108 for details. # noqa
     # If this env var is set to 1, vLLM will skip the peer-to-peer check,
     # and trust the driver's peer-to-peer capability report.
     "VLLM_SKIP_P2P_CHECK":
