Commit c9da72f

[C10D] Support group ranks in P2POp and batch_isend_irecv
Changes the semantics of __repr__ of P2POp: s and d are now group ranks instead of global ranks. I think this is OK since I also updated the field names to make this obvious.

Also adds mypy annotations.

Partially addresses RFC 0042 (pytorch/rfcs#71). See more details/motivation in #140460.

ghstack-source-id: 6f61786
Pull Request resolved: #141054
1 parent 93aef68 commit c9da72f
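
With this change, a P2POp can address its peer either by global rank (peer=) or by its rank within the process group (group_peer=). A minimal sketch of the new pattern, assuming torch.distributed is already initialized and subgroup is a 2-rank group created with dist.new_group; the function name, tensor size, and device below are illustrative, not taken from the commit:

import torch
import torch.distributed as dist

def exchange_with_group_peer(subgroup: dist.ProcessGroup, device: torch.device) -> torch.Tensor:
    # Pair up with the other member of the 2-rank subgroup using *group* ranks.
    ops = []
    if dist.get_rank(subgroup) == 0:
        buf = torch.empty(10, device=device)
        # group_peer=1 is rank 1 within subgroup, not global rank 1
        ops.append(dist.P2POp(dist.irecv, buf, group=subgroup, group_peer=1))
    else:
        buf = torch.ones(10, device=device)
        ops.append(dist.P2POp(dist.isend, buf, group=subgroup, group_peer=0))
    for work in dist.batch_isend_irecv(ops):
        work.wait()
    return buf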

File tree: 2 files changed (+68, -14 lines)

test/distributed/test_c10d_nccl.py

Lines changed: 34 additions & 0 deletions
@@ -3940,6 +3940,40 @@ def test_send_recv_subgroup(self, async_op, group_rank):
                 else:
                     c10d.send(x, dst=self.rank - 1, group=subgroup)

+    @requires_nccl()
+    @skip_if_lt_x_gpu(4)
+    @parametrize("group_rank", [True, False])
+    def test_batch_send_recv_subgroup(self, group_rank):
+        world_size = 4
+        if self.rank >= world_size:
+            return
+        subgroup = self._init_two_pg2_subgroups(world_size)
+        device = torch.device("cuda:%d" % self.rank)
+        ops = []
+        if self.rank == 0 or self.rank == 2:
+            x = torch.empty((10,), device=device)
+            if group_rank:
+                ops.append(c10d.P2POp(dist.irecv, x, group=subgroup, group_peer=1))
+            else:
+                ops.append(
+                    c10d.P2POp(dist.irecv, x, peer=self.rank + 1, group=subgroup)
+                )
+
+            for work in dist.batch_isend_irecv(ops):
+                work.wait()
+            expected = torch.ones((10,), device=device) * (self.rank + 1)
+            self.assertEqual(x, expected)
+        else:
+            x = torch.ones((10,), device=device) * self.rank
+            if group_rank:
+                ops.append(c10d.P2POp(dist.isend, x, group=subgroup, group_peer=0))
+            else:
+                ops.append(
+                    c10d.P2POp(dist.isend, x, peer=self.rank - 1, group=subgroup)
+                )
+            for work in dist.batch_isend_irecv(ops):
+                work.wait()
+
     @requires_nccl()
     @skip_if_lt_x_gpu(4)
     @parametrize("group_rank", [True, False])
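
For reference, the non-group_rank branch above still passes a global rank via peer=. When the two conventions need to coexist, the existing rank-translation helpers in torch.distributed can convert between them; a small hedged aside using the subgroup from the test, not part of this commit:

# Translate between group ranks and global ranks for an existing subgroup.
global_rank_of_group0 = dist.get_global_rank(subgroup, 0)        # group rank -> global rank
my_group_rank = dist.get_group_rank(subgroup, dist.get_rank())   # global rank -> group rank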

torch/distributed/distributed_c10d.py

Lines changed: 34 additions & 14 deletions
@@ -469,57 +469,61 @@ class P2POp:
             The type of ``op`` is either ``torch.distributed.isend`` or
             ``torch.distributed.irecv``.
         tensor (Tensor): Tensor to send or receive.
-        peer (int): Destination or source rank.
+        peer (int, optional): Destination or source rank.
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         tag (int, optional): Tag to match send with recv.
+        group_peer (int, optional): Destination or source rank.
     """

     def __init__(
         self,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Init."""
         self.op = op
         self.tensor = tensor
-        self.peer = peer
-        self.group = group
+        self.group = _group_or_default_group(group)
+        self.peer = _canonicalize_group_rank(
+            self.group, peer, group_peer, return_global=True
+        )
         self.tag = tag
+        self.group_peer = _canonicalize_group_rank(self.group, peer, group_peer)

     def __new__(
         cls,
         op: Callable,
         tensor: torch.Tensor,
-        peer: int,
+        peer: Optional[int] = None,
         group: Optional[ProcessGroup] = None,
         tag: int = 0,
+        group_peer: Optional[int] = None,
     ):
         """Create and return a new instance of the class."""
         _check_op(op)
         _check_single_tensor(tensor, "tensor")
+
         return object.__new__(cls)

     def __repr__(self):
         my_group_rank = get_rank(self.group)
-        peer_group_rank = (
-            get_group_rank(self.group, self.peer) if self.group else self.peer
-        )
         op_name = self.op.__name__
         group_name = self.group.group_name if self.group else "default_pg"
         if "send" in op_name:
             s = my_group_rank
-            d = peer_group_rank
+            d = self.group_peer
         elif "recv" in op_name:
-            s = peer_group_rank
+            s = self.group_peer
             d = my_group_rank
         else:
             return super().__repr__()

-        return f"P2POp({op_name} pg={group_name}, s={s}, d={d}, {self.tensor.shape}, {self.tensor.dtype})"
+        return f"P2POp({op_name} pg={group_name}, group_src={s}, group_dst={d}, {self.tensor.shape}, {self.tensor.dtype})"


 class _CollOp:

@@ -2546,7 +2550,7 @@ def _coalescing_manager(
         work.wait()  # type: ignore[possibly-undefined]


-def batch_isend_irecv(p2p_op_list):
+def batch_isend_irecv(p2p_op_list: List[P2POp]) -> List[Work]:
     """
     Send or Receive a batch of tensors asynchronously and return a list of requests.


@@ -2589,17 +2593,33 @@ def batch_isend_irecv(p2p_op_list):
     _check_p2p_op_list(p2p_op_list)
     group = p2p_op_list[0].group
     device = p2p_op_list[0].tensor.device
+
+    def peer_kwarg(op: P2POp) -> Dict[str, int]:
+        key = "group_dst" if op.op == isend else "group_src"
+        return {key: op.group_peer}
+
     if device.type == "cuda":
         # NCCL style coalescing
         with _coalescing_manager(group, device, async_ops=True) as cm:
             for p2p_op in p2p_op_list:
-                p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+                p2p_op.op(
+                    p2p_op.tensor,
+                    group=p2p_op.group,
+                    tag=p2p_op.tag,
+                    **peer_kwarg(p2p_op),
+                )
+
         return cm.works
     else:
         # Backward support for Gloo
         reqs = []
         for p2p_op in p2p_op_list:
-            work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+            work = p2p_op.op(
+                p2p_op.tensor,
+                group=p2p_op.group,
+                tag=p2p_op.tag,
+                **peer_kwarg(p2p_op),
+            )
             if work:
                 reqs.append(work)
         return reqs
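
The __repr__ update above relabels the printed source/destination as group ranks. Roughly what this prints on rank 0 of a 2-process default group (illustrative; the pg name shown depends on how the default group was named, and an initialized process group is required because P2POp now resolves its group eagerly):

op = dist.P2POp(dist.isend, torch.ones(10), peer=1)
print(op)
# e.g. P2POp(isend pg=<default pg name>, group_src=0, group_dst=1, torch.Size([10]), torch.float32)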
