Commit 9695016

Yifu Wang authored and pobin6 committed
[SymmetricMemory] introduce user-facing APIs empty() and rendezvous() (pytorch#139677)
Previously, `SymmetricMemory` only had private pybind APIs:

```python
from torch.distributed._symmetric_memory import _SymmetricMemory

t = _SymmetricMemory.empty_strided_p2p(
    size=(64,),
    stride=(1,),
    dtype=torch.float32,
    device=device,
)
symm_mem_hdl = _SymmetricMemory.rendezvous(t, group_name=group.group_name)
```

This PR introduces the user-facing APIs `empty()` and `rendezvous()`:

```python
import torch.distributed._symmetric_memory as symm_mem

t = symm_mem.empty(64, device="cuda")
symm_mem_hdl = symm_mem.rendezvous(t, group=group.group_name)
```

Notable differences compared to the pybind APIs:

- `empty()` now resembles `torch.empty()`:
  - the shape can be either an integer sequence or a pack of integers
  - the stride no longer needs to be (and cannot be) specified
  - the device can be either a `torch.device` or a string
- `group_name` needs to be specified at rendezvous time as opposed to allocation time. See pytorch#139529 for the rationale. I feel the new semantics are superior, hence they are enforced in the public API.
  - Currently, the pybind API still supports specifying `group_name` at rendezvous time. This PR does not change the behavior of the pybind APIs.

Pull Request resolved: pytorch#139677
Approved by: https://github.com/lw
ghstack dependencies: pytorch#139529
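For additional context (not part of the commit message), here is a minimal sketch of how the new user-facing API could be driven end to end. It assumes two CUDA devices with P2P (e.g. NVLink) connectivity, the NCCL backend, and a launch via `torchrun`; the file name and the fill/read pattern are illustrative only and mirror what the updated tests exercise.

```python
# symm_mem_demo.py -- illustrative sketch, not part of this commit.
# Launch with: torchrun --nproc-per-node=2 symm_mem_demo.py
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")

rank = dist.get_rank()
world_size = dist.get_world_size()

# Allocate a symmetric-memory-capable tensor and rendezvous over WORLD.
t = symm_mem.empty(64, device="cuda")
hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)

# Each rank fills its local buffer, then reads its peer's buffer.
t.fill_(float(rank))
hdl.barrier()
peer = (rank + 1) % world_size
peer_buf = hdl.get_buffer(peer, (64,), torch.float32)
print(f"rank {rank} sees peer value {peer_buf[0].item()}")

hdl.barrier()
dist.destroy_process_group()
```

Note that the new `rendezvous()` wrapper calls `enable_symm_mem_for_group()` itself (see the `__init__.py` diff below), so the sketch does not enable symmetric memory explicitly.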
1 parent 9d459cf commit 9695016

File tree: 2 files changed, +174 -76 lines changed


test/distributed/test_symmetric_memory.py

Lines changed: 66 additions & 76 deletions
```diff
@@ -5,6 +5,7 @@
 
 import torch
 import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
 from torch._C._autograd import DeviceType
 from torch._C._distributed_c10d import _SymmetricMemory
 from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
@@ -81,9 +82,25 @@ def _init_process(self):
             rank=self.rank,
             store=store,
         )
-        enable_symm_mem_for_group(dist.group.WORLD.group_name)
         torch.manual_seed(42 + self.rank)
 
+    @skipIfRocm
+    @skip_if_lt_x_gpu(2)
+    def test_cuda_nvlink_connectivity_detection(self) -> None:
+        from torch._C._distributed_c10d import _detect_dma_connectivity
+
+        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
+        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
+        self.assertEqual(connectivity.connection_type, "nvlink")
+        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
+        for row in connectivity.matrix:
+            self.assertEqual(len(row), torch.cuda.device_count())
+
+    @skipIfRocm
+    def test_large_alloc(self) -> None:
+        t = symm_mem.empty(2 * 1024**3, dtype=torch.uint8, device="cuda")
+        self.assertEqual(t.numel() * t.element_size(), 2 * 1024**3)
+
     def _get_test_alloc_args(self):
         shape = (64, 64)
         stride = (64, 1)
@@ -92,64 +109,56 @@ def _get_test_alloc_args(self):
         group_name = "0"
         return (shape, stride, dtype, device, group_name)
 
-    def _verify_symmetric_memory(self, symm_mem):
-        self.assertEqual(symm_mem.world_size, 2)
+    def _verify_symmetric_memory(self, symm_mem_hdl):
+        self.assertEqual(symm_mem_hdl.world_size, 2)
 
-        buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32)
+        buf = symm_mem_hdl.get_buffer(
+            0, (symm_mem_hdl.buffer_size // 4,), torch.float32
+        )
         self.assertEqual(buf.storage_offset(), 0)
-        self.assertEqual(buf.untyped_storage().size(), symm_mem.buffer_size)
+        self.assertEqual(buf.untyped_storage().size(), symm_mem_hdl.buffer_size)
 
-        if symm_mem.rank == 0:
-            symm_mem.wait_signal(src_rank=1)
+        if symm_mem_hdl.rank == 0:
+            symm_mem_hdl.wait_signal(src_rank=1)
             self.assertTrue(buf.eq(42).all())
         else:
             buf.fill_(42)
-            symm_mem.put_signal(dst_rank=0)
+            symm_mem_hdl.put_signal(dst_rank=0)
 
-        symm_mem.barrier()
+        symm_mem_hdl.barrier()
 
-        if symm_mem.rank == 0:
-            symm_mem.barrier()
+        if symm_mem_hdl.rank == 0:
+            symm_mem_hdl.barrier()
             self.assertTrue(buf.eq(43).all())
         else:
             buf.fill_(43)
-            symm_mem.barrier()
+            symm_mem_hdl.barrier()
 
-        symm_mem.barrier()
-
-    @skipIfRocm
-    @skip_if_lt_x_gpu(2)
-    def test_cuda_nvlink_connectivity_detection(self) -> None:
-        from torch._C._distributed_c10d import _detect_dma_connectivity
-
-        connectivity = _detect_dma_connectivity(DeviceType.CUDA, "nvlink")
-        self.assertEqual(connectivity.device_type, DeviceType.CUDA)
-        self.assertEqual(connectivity.connection_type, "nvlink")
-        self.assertEqual(len(connectivity.matrix), torch.cuda.device_count())
-        for row in connectivity.matrix:
-            self.assertEqual(len(row), torch.cuda.device_count())
+        symm_mem_hdl.barrier()
 
     @skipIfRocm
     @skip_if_lt_x_gpu(2)
     def test_empty_strided_p2p(self) -> None:
         self._init_process()
+        enable_symm_mem_for_group(dist.group.WORLD.group_name)
 
         alloc_args = self._get_test_alloc_args()
 
         t = torch.empty((64, 64), device=self.device)
         self.assertIsNone(_SymmetricMemory.rendezvous(t))
 
         t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t)
 
         del t
-        self._verify_symmetric_memory(symm_mem)
+        self._verify_symmetric_memory(symm_mem_hdl)
         dist.destroy_process_group()
 
     @skipIfRocm
     @skip_if_lt_x_gpu(2)
     def test_empty_strided_p2p_persistent(self) -> None:
         self._init_process()
+        enable_symm_mem_for_group(dist.group.WORLD.group_name)
 
         alloc_args = self._get_test_alloc_args()
 
@@ -168,51 +177,47 @@ def test_empty_strided_p2p_persistent(self) -> None:
         t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42)
         self.assertEqual(t.data_ptr(), data_ptr)
 
-        symm_mem = _SymmetricMemory.rendezvous(t)
-        self._verify_symmetric_memory(symm_mem)
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t)
+        self._verify_symmetric_memory(symm_mem_hdl)
         dist.destroy_process_group()
 
     @skipIfRocm
     @skip_if_lt_x_gpu(2)
     def test_get_signal_pad(self) -> None:
         self._init_process()
 
-        t = _SymmetricMemory.empty_strided_p2p(*self._get_test_alloc_args())
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
         peer_rank = (self.rank + 1) % self.world_size
 
-        signal_pad = symm_mem.get_signal_pad(self.rank)
-        self.assertEqual(signal_pad.data_ptr(), symm_mem.signal_pad_ptrs[symm_mem.rank])
+        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
+        self.assertEqual(
+            signal_pad.data_ptr(), symm_mem_hdl.signal_pad_ptrs[symm_mem_hdl.rank]
+        )
 
-        signal_pad = symm_mem.get_signal_pad(peer_rank)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank)
         self.assertEqual(signal_pad.dtype, torch.uint32)
-        self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 4)
+        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 4)
 
         # Only specify sizes
-        signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8))
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8))
         self.assertEqual(signal_pad.dtype, torch.uint32)
         self.assertEqual(signal_pad.numel(), 64)
 
         # Only specify dtype
-        signal_pad = symm_mem.get_signal_pad(peer_rank, dtype=torch.uint64)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, dtype=torch.uint64)
         self.assertEqual(signal_pad.dtype, torch.uint64)
-        self.assertEqual(signal_pad.numel(), symm_mem.signal_pad_size // 8)
+        self.assertEqual(signal_pad.numel(), symm_mem_hdl.signal_pad_size // 8)
 
         # Specify both sizes and dtype
-        signal_pad = symm_mem.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
+        signal_pad = symm_mem_hdl.get_signal_pad(peer_rank, (8, 8), dtype=torch.uint64)
         self.assertEqual(signal_pad.dtype, torch.uint64)
         self.assertEqual(signal_pad.numel(), 64)
 
         # Sanity check that writes to buffer doesn't corrupt signal_pad
-        t = _SymmetricMemory.empty_strided_p2p(
-            (0,),
-            (0,),
-            torch.float32,
-            self.device,
-            dist.group.WORLD.group_name,
-        )
-        symm_mem = _SymmetricMemory.rendezvous(t)
-        signal_pad = symm_mem.get_signal_pad(self.rank)
+        t = symm_mem.empty(0, device="cuda")
+        symm_mem_hdl = symm_mem.rendezvous(t)
+        signal_pad = symm_mem_hdl.get_signal_pad(self.rank)
         signal_pad.fill_(42)
         t.fill_(0)
         self.assertTrue(signal_pad.eq(42).all())
@@ -224,14 +229,12 @@ def test_get_signal_pad(self) -> None:
     def test_barrier_timeout(self) -> None:
         self._init_process()
 
-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
 
         if self.rank == 0:
             with self.assertRaises(RuntimeError):
-                symm_mem.barrier(timeout_ms=1000)
+                symm_mem_hdl.barrier(timeout_ms=1000)
             torch.cuda.synchronize()
         else:
             torch.cuda.synchronize()
@@ -247,17 +250,15 @@ def test_barrier_timeout(self) -> None:
     def test_put_signal_timeout(self) -> None:
         self._init_process()
 
-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
 
         if self.rank == 0:
             with self.assertRaises(RuntimeError):
                 # First, put a signal into rank 1's signal pad. Since rank 1
                 # doesn't wait on this signal, the subsequent put will timeout.
-                symm_mem.put_signal(dst_rank=1)
-                symm_mem.put_signal(dst_rank=1, timeout_ms=1000)
+                symm_mem_hdl.put_signal(dst_rank=1)
+                symm_mem_hdl.put_signal(dst_rank=1, timeout_ms=1000)
             torch.cuda.synchronize()
         else:
             torch.cuda.synchronize()
@@ -273,14 +274,12 @@ def test_put_signal_timeout(self) -> None:
     def test_wait_signal_timeout(self) -> None:
         self._init_process()
 
-        alloc_args = self._get_test_alloc_args()
-
-        t = _SymmetricMemory.empty_strided_p2p(*alloc_args)
-        symm_mem = _SymmetricMemory.rendezvous(t)
+        t = symm_mem.empty(1, device="cuda")
+        symm_mem_hdl = _SymmetricMemory.rendezvous(t, group=dist.group.WORLD)
 
         if self.rank == 0:
             with self.assertRaises(RuntimeError):
-                symm_mem.wait_signal(src_rank=1, timeout_ms=1000)
+                symm_mem_hdl.wait_signal(src_rank=1, timeout_ms=1000)
             torch.cuda.synchronize()
         else:
             torch.cuda.synchronize()
@@ -685,7 +684,6 @@ def _init_process(self):
             rank=self.rank,
             store=store,
         )
-        enable_symm_mem_for_group(dist.group.WORLD.group_name)
         torch.manual_seed(42 + self.rank)
 
     @skipIfRocm
@@ -699,18 +697,10 @@ def test_subgroup(self) -> None:
 
         world = dist.group.WORLD
         subgroup = subgroup_0 if world.rank() < world.size() // 2 else subgroup_1
-        enable_symm_mem_for_group(subgroup.group_name)
 
-        t = _SymmetricMemory.empty_strided_p2p(
-            size=(64,),
-            stride=(1,),
-            dtype=torch.float32,
-            device=self.device,
-        )
-        symm_mem_world = _SymmetricMemory.rendezvous(t, group_name=world.group_name)
-        symm_mem_subgroup = _SymmetricMemory.rendezvous(
-            t, group_name=subgroup.group_name
-        )
+        t = symm_mem.empty(64, device="cuda")
+        symm_mem_world = symm_mem.rendezvous(t, group=world)
+        symm_mem_subgroup = symm_mem.rendezvous(t, group=subgroup)
 
         self.assertEqual(symm_mem_world.world_size, world.size())
         self.assertEqual(symm_mem_world.rank, world.rank())
```

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 108 additions & 0 deletions
```diff
@@ -110,6 +110,8 @@ def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory:
         _SymmetricMemory: the symmetric memory workspace associated with the
         group.
     """
+    enable_symm_mem_for_group(group_name)
+
     tensor = _group_name_to_workspace_tensor.get(group_name)
     size = tensor.numel() * tensor.element_size() if tensor is not None else 0
     if tensor is None or size < min_size:
@@ -1386,3 +1388,109 @@ def _low_contention_reduce_scatter(
     return _low_contention_reduce_scatter_with_workspace(
         tensor, reduce_op, workspace
     )
+
+
+# =============================================================================
+# User-facing APIs
+# =============================================================================
+
+
+from typing import Any, overload, Sequence, TYPE_CHECKING, Union
+
+from torch.types import _device, _dtype, _int
+
+
+if TYPE_CHECKING:
+    from torch._C._distributed_c10d import ProcessGroup
+
+
+@overload
+def empty(
+    *size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None
+) -> torch.Tensor:
+    ...
+
+
+@overload
+def empty(
+    size: Sequence[_int],
+    *,
+    dtype: Optional[_dtype] = None,
+    device: Optional[_device] = None,
+) -> torch.Tensor:
+    ...
+
+
+def empty(  # type: ignore[misc]
+    *size: Any,
+    dtype: Optional[_dtype] = None,
+    device: Optional[_device] = None,
+) -> torch.Tensor:
+    r"""
+    empty(*size, *, dtype=None, device=None) -> Tensor
+
+    Similar to :func:`torch.empty()`. The returned tensor can be used by
+    :func:`torch._distributed._symmetric_memory.rendezvous()` to establish a
+    symmetric memory tensor among participating processes.
+
+    Args:
+        size (int...): a sequence of integers defining the shape of the output tensor.
+            Can be a variable number of arguments or a collection like a list or tuple.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).
+        device (:class:`torch.device`, optional): the desired device of returned tensor.
+            Default: if ``None``, uses the current device for the default tensor type
+            (see :func:`torch.set_default_device`). :attr:`device` will be the CPU
+            for CPU tensor types and the current CUDA device for CUDA tensor types.
+    """
+    if len(size) == 1 and isinstance(size[0], Sequence):
+        size = tuple(size[0])
+    else:
+        size = tuple(size)
+
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+
+    if device is None:
+        device = torch.get_default_device()
+
+    return _SymmetricMemory.empty_strided_p2p(
+        size=size,
+        stride=torch._prims_common.make_contiguous_strides_for(size),
+        dtype=dtype,
+        device=torch.device(device),
+    )
+
+
+def rendezvous(
+    tensor: torch.Tensor, group: Union[str, "ProcessGroup"]
+) -> _SymmetricMemory:
+    r"""
+    rendezvous(tensor, group) -> _SymmetricMemory
+
+    Establish a symmetric memory tensor among participating processes. This is
+    a collective operation.

+    Args:
+        tensor (:class:`torch.Tensor`): the local tensor used to establish the symmetric memory tensor.
+            It must be allocated via :func:`torch._distributed._symmetric_memory.empty()`. The shape,
+            dtype, and device type must be identical across all participating processes.
+        group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the
+            participating processes. This can be either a group name or a process group object.
+    """
+    from torch._C._distributed_c10d import ProcessGroup
+
+    if isinstance(group, str):
+        group_name = group
+    elif isinstance(group, ProcessGroup):
+        group_name = group.group_name
+    else:
+        raise TypeError(f"rendezvous: unsupported group type: {type(group)}")
+
+    enable_symm_mem_for_group(group_name)
+    return _SymmetricMemory.rendezvous(tensor, group_name)
+
+
+__all__ = ["empty", "rendezvous"]
```
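To make the two call forms introduced above concrete, here is a small illustrative snippet; it assumes a process group has already been initialized on every participating process and is a sketch rather than part of the change:

```python
# Illustrative only: exercises the two shape forms of empty() and the two
# group forms of rendezvous() shown in the diff above. Assumes
# dist.init_process_group() has already been called on every rank.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Shape can be given as a pack of ints or as a sequence.
a = symm_mem.empty(16, 16, dtype=torch.float32, device="cuda")
b = symm_mem.empty((64,), device=torch.device("cuda"))

# The group may be passed as a ProcessGroup object or as a group name string.
hdl_a = symm_mem.rendezvous(a, group=dist.group.WORLD)
hdl_b = symm_mem.rendezvous(b, group=dist.group.WORLD.group_name)

assert hdl_a.world_size == dist.get_world_size()
assert hdl_b.rank == dist.get_rank()
```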
