
Commit 6819be1

feat(models): multibackend all_to_all wrapper (#95)
Small addition providing a fallback to support all_to_all when using the Gloo backend for torch.distributed. This PR is needed to be able to run the transformer model on CPU; for the 99.9% of users running on GPUs with the NCCL backend, this change should have no effect.

Gloo does not offer an all_to_all primitive, as shown [here](https://pytorch.org/docs/stable/distributed.html#backends). This commit implements an all_to_all fallback for Gloo, using the 'Linear Shift' algorithm from [Hofmann and Rünger, 2013](https://www.tu-chemnitz.de/informatik/PI/forschung/publikationen/download/HR_eurompi13.pdf). Because the `torch.distributed` send/recv syntax changed in torch 2.6, older versions of torch are not supported.

---------

Co-authored-by: Harrison Cook <[email protected]>
1 parent 9fc5923 commit 6819be1
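
For intuition, here is a minimal, dependency-free sketch (not part of the commit; names are illustrative) of the pairing order implied by the linear shift algorithm: at step i, rank r exchanges with peer j = (i - r) mod comm_size, and that peer's own schedule points back at r, so the two sides of every exchange line up.

def linear_shift_schedule(rank: int, comm_size: int) -> list[int]:
    """Peer contacted by `rank` at each step; the step where the peer equals `rank` is a local copy."""
    return [(i - rank + comm_size) % comm_size for i in range(comm_size)]

# Example schedule for 4 ranks: at every step i, if rank r targets j, then rank j targets r.
for r in range(4):
    print(r, linear_shift_schedule(r, 4))
# 0 [0, 1, 2, 3]
# 1 [3, 0, 1, 2]
# 2 [2, 3, 0, 1]
# 3 [1, 2, 3, 0]

In the commit itself the per-peer sends and receives are posted asynchronously and waited on together, so this schedule only fixes the order in which requests are posted, not a strict step-by-step synchronisation.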


models/src/anemoi/models/distributed/transformer.py

Lines changed: 38 additions & 2 deletions
@@ -18,6 +18,42 @@
 from anemoi.models.distributed.utils import get_memory_format


+def _alltoallwrapper(output_list: list, input_list: list, group: ProcessGroup):
+    """Wrapper function for all_to_all across the NCCL, MPI and Gloo backends.
+    There is no all_to_all primitive for the Gloo backend; in that case each
+    process exchanges its tensors via asynchronous point-to-point sends and receives.
+
+    Returns nothing but modifies output_list in-place.
+
+    """
+    comm_size = dist.get_world_size(group=group)
+
+    if dist.get_backend(group) == "gloo":
+
+        # Need to check the torch version here because the dist.send/recv syntax changed in torch v2.6
+        torch_version = torch.__version__.split(".")
+        torch_major_version = int(torch_version[0])
+        torch_minor_version = int(torch_version[1])
+        if (torch_major_version, torch_minor_version) < (2, 6):
+            raise NotImplementedError("Gloo all_to_all not implemented for torch < v2.6")
+
+        reqs = []
+        rank = dist.get_rank(group=group)
+        # Here we implement the linear shift algorithm from Hofmann and Ruenger, 2013
+        for i in range(0, comm_size):
+            j = (i - rank + comm_size) % comm_size
+            if j != rank:
+                # exchange data with rank j
+                reqs.append(dist.isend(input_list[j], group_dst=j, group=group))
+                reqs.append(dist.irecv(output_list[j], group_src=j, group=group))
+            else:
+                output_list[rank] = input_list[rank]
+        for req in reqs:
+            req.wait()
+    else:
+        dist.all_to_all(output_list, input_list, group=group)
+
+
 def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = None) -> Tensor:
     """Apply all_to_all along the head dimension.

@@ -52,7 +88,7 @@ def _headsalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] =
         for rank in range(comm_size)
     ]

-    dist.all_to_all(output_list, input_list, group=group)
+    _alltoallwrapper(output_list, input_list, group=group)

     # Note: torch.cat already creates a contiguous tensor.
     return torch.cat(output_list, dim=-2).contiguous(memory_format=input_format)
@@ -79,7 +115,7 @@ def _seqalltoall(input_: Tensor, shapes: list, group: Optional[ProcessGroup] = N

     output_list = [torch.empty_like(input_list[comm_rank]) for _ in range(comm_size)]

-    dist.all_to_all(output_list, input_list, group=group)
+    _alltoallwrapper(output_list, input_list, group=group)

     # Note: torch.cat already creates a contiguous tensor.
     return torch.cat(output_list, dim=-3).contiguous(memory_format=input_format)
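
For reference, a minimal usage sketch (not part of the commit) of how the Gloo fallback could be exercised on CPU with two processes. It assumes torch >= 2.6 and an installed anemoi.models package; the _demo_worker name, address, and port are hypothetical values chosen for the sketch.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from anemoi.models.distributed.transformer import _alltoallwrapper


def _demo_worker(rank: int, world_size: int) -> None:
    # Hypothetical single-node rendezvous settings for this sketch.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # One chunk per peer; the wrapper fills output_list in-place.
    input_list = [torch.full((2,), float(rank * world_size + j)) for j in range(world_size)]
    output_list = [torch.empty(2) for _ in range(world_size)]
    _alltoallwrapper(output_list, input_list, group=dist.group.WORLD)

    print(f"rank {rank}: {[t.tolist() for t in output_list]}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_demo_worker, args=(2,), nprocs=2)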
