[PERF] PyTorch Symmetric Memory All-Reduce #20759
Changes from all commits: aa3efaa, ff33a5a, e2e8e0c, f3a267c, 0bf3002, f5b5f42, 04d3f48, aa5e7d2, eacc031
@@ -0,0 +1,108 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import random
import typing

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
    CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                             get_tp_group,
                                             init_distributed_environment,
                                             initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

torch.manual_seed(42)
random.seed(44)

test_size_elements = 4 * 1024 * 1024


def symm_mem_allreduce_worker(local_rank: int, world_size: int):
    monkeypatch = pytest.MonkeyPatch()
    with monkeypatch.context() as m:
        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
        dtype = torch.bfloat16
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
        torch.set_default_device(device)
        torch.set_default_dtype(dtype)
        update_environment_variables({
            'RANK': str(local_rank),
            'LOCAL_RANK': str(local_rank),
            'WORLD_SIZE': str(world_size),
            'MASTER_ADDR': 'localhost',
            'MASTER_PORT': '12345',
        })

        init_distributed_environment()
        initialize_model_parallel(tensor_model_parallel_size=world_size)

        cuda_communicator = typing.cast(CudaCommunicator,
                                        get_tp_group().device_communicator)
        symm_mem_comm = cuda_communicator.symm_mem_comm
        if symm_mem_comm is None or symm_mem_comm.disabled:
            pytest.skip("SymmMemCommunicator is not available or disabled.")

        inp_direct_symm_mem = torch.randint(1,
                                            23, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
            pytest.skip(
                "SymmMemCommunicator isn't used for this world and input size."
            )

        original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
        out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
        assert out_direct_symm_mem is not None

        group = get_tensor_model_parallel_group().device_group
        dist.all_reduce(original_inp_direct_symm_mem, group=group)
        torch.testing.assert_close(out_direct_symm_mem,
                                   original_inp_direct_symm_mem,
                                   atol=2.5,
                                   rtol=0.1)

        # Test tensor_model_parallel_all_reduce, which should use symm_mem
        inp_tensor_parallel = torch.randint(-23,
                                            1, (test_size_elements, ),
                                            dtype=dtype,
                                            device=device)
        original_inp_tensor_parallel = inp_tensor_parallel.clone()
        out_tensor_parallel = tensor_model_parallel_all_reduce(
            inp_tensor_parallel)
        dist.all_reduce(original_inp_tensor_parallel, group=group)
        torch.testing.assert_close(out_tensor_parallel,
                                   original_inp_tensor_parallel,
                                   atol=2.5,
                                   rtol=0.1)


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
                    reason="Only test on CUDA")
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
                            pipeline_parallel_size):
    world_size = tp_size * pipeline_parallel_size
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs to run the test.")

    # Enable SymmMemCommunicator
    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

    mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
    cleanup_dist_env_and_memory()
```
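For context, a minimal sketch (not part of the PR) of how a user would opt in to this path when serving a tensor-parallel model. The only switch the diff itself shows is the `VLLM_ALLREDUCE_USE_SYMM_MEM` environment variable; the model name below is a placeholder, and any bfloat16 TP deployment on supported GPUs would do.

```python
# Sketch only: opting in to the symmetric-memory all-reduce path.
# VLLM_ALLREDUCE_USE_SYMM_MEM comes from the test above; the model name is
# a placeholder, not something this PR prescribes.
import os

os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "1"  # enable SymmMemCommunicator

from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
          tensor_parallel_size=2,
          dtype="bfloat16")
print(llm.generate("Hello, symmetric memory!")[0].outputs[0].text)
```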
@@ -0,0 +1,111 @@

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.all_reduce_utils import (
    SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform

try:
    import torch.distributed._symmetric_memory as torch_symm_mem

    symm_mem_available = True
except ImportError:
    symm_mem_available = False

logger = init_logger(__name__)


class SymmMemCommunicator:
    _WORLD_SIZES_MULTIMEM = {
        "9.0": [4, 6, 8],
        "10.0": [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str,
                                                          torch.device]):
        self.disabled = True

        if not symm_mem_available:
            return

        if not current_platform.is_cuda():
            logger.warning("SymmMemCommunicator: symmetric "
                           "memory is not available.")
            return
        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)
        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = current_platform.get_device_capability(
        ).as_version_str()
        if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
                self.device_capability]:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return
        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
            self.world_size]
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning("SymmMemCommunicator: symmetric memory "
                           "multicast operations are not supported.")
            return
        self.disabled = False

    def should_use_symm_mem(self, inp: torch.Tensor):
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def all_reduce(
            self,
            inp: torch.Tensor,
            *,
            out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
        if not self.should_use_symm_mem(inp):
            return None
        if out is None:
            out = torch.empty_like(inp)
        self.buffer[:inp.numel()].copy_(inp.view(-1))
        if self.world_size in self._WORLD_SIZES_MULTIMEM[
                self.device_capability]:
            torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
                                                    "sum",
                                                    self.group.group_name)
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
                                                    "sum",
                                                    self.group.group_name)
        out.copy_(self.buffer[:inp.numel()].view(out.shape))
        return out
```

Review comment on the `multimem_all_reduce_` call: Did you guys check that this is deterministic? NCCL claims multimem allreduce is deterministic on newer drivers (NVIDIA/nccl#1497 (comment)), but it's not in the docs.

Review comment on the `two_shot_all_reduce_` call: any reason you are not using …

Review comment on lines +100 to +110 (the copy into `self.buffer` and back out): nit: this is okay. We can talk more on how to optimize away the copy-in and copy-out :)
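To illustrate the trade-off behind that copy-in/copy-out comment, here is a hypothetical helper (not in the PR) that skips the final `out.copy_()` and hands the caller a view aliasing the symmetric staging buffer. The alias is clobbered by the next all-reduce on the same communicator, which is presumably why the PR copies into a fresh tensor; the function name is invented for the sketch.

```python
# Hypothetical variant that avoids the copy-out by returning an alias into
# the shared staging buffer. Only valid until the next all-reduce reuses it.
from typing import Optional

import torch


def all_reduce_no_copy_out(comm, inp: torch.Tensor) -> Optional[torch.Tensor]:
    """comm is a SymmMemCommunicator; the result aliases comm.buffer."""
    if not comm.should_use_symm_mem(inp):
        return None
    comm.buffer[:inp.numel()].copy_(inp.view(-1))  # copy-in still happens
    if comm.world_size in comm._WORLD_SIZES_MULTIMEM[comm.device_capability]:
        torch.ops.symm_mem.multimem_all_reduce_(comm.buffer[:inp.numel()],
                                                "sum", comm.group.group_name)
    else:
        torch.ops.symm_mem.two_shot_all_reduce_(comm.buffer[:inp.numel()],
                                                "sum", comm.group.group_name)
    # No copy-out: the caller receives a view into the symmetric buffer.
    return comm.buffer[:inp.numel()].view(inp.shape)
```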
Review comment on the class: The `SymmMemCommunicator` is hardcoded to use `torch.bfloat16`, limiting its use with models using other dtypes. Consider initializing buffers based on the input tensor's dtype during the first `all_reduce` call to increase flexibility.
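One possible direction for that dtype concern, purely as a sketch: allocate and rendezvous a staging buffer lazily, keyed by dtype, on first use. The `_get_buffer` method and `_buffers` cache below are invented for illustration and are not part of the PR; one caveat is that `rendezvous` is a collective, so every rank would have to request the same dtype in the same order (e.g. during a warmup pass).

```python
# Sketch of a dtype-keyed lazy buffer for SymmMemCommunicator.
# Names (_buffers, _get_buffer) are made up; the merged PR keeps a single
# bfloat16 buffer allocated in __init__.
import torch
import torch.distributed._symmetric_memory as torch_symm_mem


def _get_buffer(self, dtype: torch.dtype) -> torch.Tensor:
    if not hasattr(self, "_buffers"):
        self._buffers: dict[torch.dtype, torch.Tensor] = {}
    if dtype not in self._buffers:
        buf = torch_symm_mem.empty(self.max_size // dtype.itemsize,
                                   device=self.device,
                                   dtype=dtype)
        # Collective call: all ranks must reach this for the same dtype.
        torch_symm_mem.rendezvous(buf, self.group.group_name)
        self._buffers[dtype] = buf
    return self._buffers[dtype]
```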