implement send and recv using collective_permute #9373

Open · wants to merge 6 commits into base: master

50 changes: 50 additions & 0 deletions test/pjrt/test_collective_ops_tpu.py
@@ -359,6 +359,56 @@ def test_all_to_all_single(self, use_dynamo):
expected.sort().values),
f"Got {val}, expected {expected}")

  @staticmethod

Collaborator:
Last time we checked, we also noticed that https://github.com/pytorch/xla/blob/master/test/test_mp_collective_permute.py didn't work on the CPU, but send/recv did. We might want to double check it.

Is test/test_torch_distributed_xla_backend.py tested for CPU and Neuron? Would it be possible to test it and see if the change is compatible?

Collaborator (Author):
> Is test/test_torch_distributed_xla_backend.py tested for CPU and Neuron? Would it be possible to test it and see if the change is compatible?

It is, but it only checks that the expected IR is emitted; it doesn't run anything. And in this case it wasn't a reliable test because, at least on TPU, that IR does not actually run.

test_mp_collective_permute is run for both TPU and Neuron. I don't think it works for CPU, but neither does send/recv. The success of test_mp_collective_permute indicates that this change should work for Neuron, but to be more certain I could add a test that covers a pipeline-like transfer in addition to the existing test of a permutation-like transfer.

The most direct test for Neuron would be something like what's in test_collective_ops_tpu.py, which runs the ops to completion.

Collaborator:
> The most direct test for Neuron would be something like what's in test_collective_ops_tpu.py, which runs the ops to completion.

This would be great. Any chance we can move it outside of this file and make it general? I can help test it out if so. Otherwise, I'll need to follow up on whether we can port this entire file to Neuron. I see tpu.num_expected_global_devices and pjrt.run_multiprocess, but I haven't seen/used these before.
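
For context on those two helpers, based purely on how they're used below (so treat this as my reading of the harness, not authoritative documentation): pjrt.run_multiprocess spawns the given function in one process per local device and returns a dict mapping each global ordinal to that process's return value, while tpu.num_expected_global_devices() and tpu.num_available_devices() are the TPU-specific pieces a Neuron port would need to swap out. A minimal sketch of the harness:

import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr
from torch_xla._internal import pjrt, tpu


def _return_ordinal():
  # Runs once in each spawned process: initialize the XLA backend and report
  # this process's global ordinal.
  dist.init_process_group("xla", init_method='xla://')
  torch_xla.device()
  return xr.global_ordinal()


if __name__ == '__main__':
  # One entry per device: {global ordinal: return value of _return_ordinal}.
  results = pjrt.run_multiprocess(_return_ordinal)
  assert sorted(results.values()) == list(
      range(tpu.num_expected_global_devices()))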

  def _send_recv_pipeline():
    dist.init_process_group("xla", init_method='xla://')
    device = torch_xla.device()
    world_size = xr.world_size()
    cutoff = world_size // 2

Collaborator:
I think if the world size is not even, this test will hang. For example, if world size is 3, then index 0 will send to 1 and 1 will recv from 0, but index 2 will try to recv from 1 without an associated send.

Collaborator (Author):
Good point. I'll update the test so that it is more defensive.
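
(My sketch only, not part of this PR: the PR ultimately handles the odd-world-size case by gating the tests with the absltest.skipUnless decorators shown further down. An alternative would be a fail-fast guard inside the helper so a misconfigured run raises instead of hanging.)

import torch_xla.runtime as xr


def _require_even_world_size(op_name):
  # Hypothetical helper: raise a clear error instead of deadlocking on an
  # unmatched send/recv when the world size is odd.
  world_size = xr.world_size()
  if world_size % 2 != 0:
    raise RuntimeError(
        f"{op_name} requires an even world size, got {world_size}")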

    index = xr.global_ordinal()
    tensor = torch.tensor([index], dtype=torch.float, device=device)
    if index < cutoff:
      dist.send(tensor, index + cutoff)
    else:
      dist.recv(tensor, index - cutoff)
    return tensor.cpu()

  @staticmethod
  def _send_recv_permute():
    dist.init_process_group("xla", init_method='xla://')
    device = torch_xla.device()
    world_size = xr.world_size()
    index = xr.global_ordinal()
    sending_tensor = torch.tensor([index], dtype=torch.float, device=device)
    receiving_tensor = torch.tensor([-1.0], dtype=torch.float, device=device)
    if index % 2 == 0:
      dist.send(sending_tensor, (index + 1) % world_size)
      dist.recv(receiving_tensor, (index - 1) % world_size)
    else:
      dist.recv(receiving_tensor, (index - 1) % world_size)
      dist.send(sending_tensor, (index + 1) % world_size)
    return receiving_tensor.cpu()

  @absltest.skipUnless(tpu.num_available_devices() % 2 == 0,
                       "Send/Recv test requires even number of devices")
  def test_send_recv_pipeline(self):
    """Send tensors on first N/2 devices to second N/2 devices."""
    results = pjrt.run_multiprocess(self._send_recv_pipeline)
    world_size = tpu.num_expected_global_devices()
    for ordinal, value in results.items():
      expected = ordinal if ordinal < world_size // 2 else ordinal - world_size // 2
      np.testing.assert_array_equal(value, [expected])

  @absltest.skipUnless(tpu.num_available_devices() % 2 == 0,
                       "Send/Recv test requires even number of devices")
  def test_send_recv_permute(self):
    """Send tensor on device i to device i + 1 (modulo world size)."""
    results = pjrt.run_multiprocess(self._send_recv_permute)
    world_size = tpu.num_expected_global_devices()
    for ordinal, value in results.items():
      expected = (ordinal - 1) % world_size
      np.testing.assert_array_equal(value, [expected])


if __name__ == '__main__':
  absltest.main()
41 changes: 0 additions & 41 deletions test/test_torch_distributed_xla_backend.py
@@ -166,47 +166,6 @@ def test_reduce_scatter_coalesced(self):
# purge all computations attached the device.
torch_xla.sync()

  @patch_world(0, 6)
  def test_send(self):
    device = torch_xla.device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
    input_list = [tensor]

    with mock.patch.object(
        torch_xla.distributed.xla_backend.ProcessGroupXla,
        'make_send_channel_id',
        new=lambda self, dst_rank, tag: dst_rank * 2):
      dist.send(tensor, 1)

    send_pattern = r'%send\.\d+ = .+ send\(.+\), channel_id=2'
    senddone_pattern = r'%send\-done\.\d+ = .+ send\-done\(.+\), channel_id=2'
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, send_pattern)
    hlo_matches(hlo, senddone_pattern)

    # Don't try to run Send on CPU because it's not implemented
    torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))

  @patch_world(0, 6)
  def test_recv(self):
    device = torch_xla.device()
    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()

    with mock.patch.object(
        torch_xla.distributed.xla_backend.ProcessGroupXla,
        'make_recv_channel_id',
        new=lambda self, src_rank, tag: src_rank * 3):
      dist.recv(tensor, 1)

    recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3'
    recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3'
    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
    hlo_matches(hlo, recv_pattern)
    hlo_matches(hlo, recvdone_pattern)

    # Don't try to run Recv on CPU because it's not implemented
    torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))

  @patch_world(rank=0, size=12)
  def test_new_group_no_ranks(self):
    with new_group_barrier_disabled():
3 changes: 0 additions & 3 deletions torch_xla/core/xla_model.py
@@ -748,9 +748,6 @@ def collective_permute(value: torch.Tensor,
pairs: List[List[int]]) -> torch.Tensor:
"""Performs a XLA `CollectivePermute()` operation on the input tensor.

WARNING: This function is not very reliable, may produce wrong results under
certain inputs. Use it at your own risk.

Collaborator (Author), on lines -751 to -753:

As discussed in #8815, there's no context for this ancient warning. Given its age, the lack of details, and the lack of any other reported bugs, I think it's best to remove it. If we get a specific bug report, we can act on that.

See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute

Args:
40 changes: 17 additions & 23 deletions torch_xla/distributed/xla_backend.py
@@ -1,5 +1,6 @@
import torch
import torch.distributed as dist
+import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import rendezvous
@@ -270,41 +271,34 @@ def scatter(self, output_tensor_list: list[torch.Tensor],
    rs_opts.reduceOp = dist.ReduceOp.SUM
    return self.reduce_scatter(output_tensor_list, inputs, rs_opts)

-  # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
-  # the maker with their specific one. See unit test in
-  # test/test_torch_distributed_xla_backend.py for an example.
-  def make_send_channel_id(self, dst_rank, tag):
-    raise NotImplementedError

  # Call site e.g.
  # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L877
  def send(self, tensors, dst_rank, tag=0):

Collaborator:
If we're warning users to use collective_permute, but this already ends up using a collective_permute under the hood, should the warning itself make that clearer?

Collaborator (Author):
I could word this better. The real advice is to restructure your code so that each process calls collective_permute with all of the send-recv pairs.
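
To make that concrete, here is a minimal sketch (mine, not part of this PR) of the recommended pattern: every process calls xm.collective_permute once with the full list of source-target pairs instead of issuing per-rank dist.send/dist.recv calls. The ring-shift pattern and tensor contents are made up for illustration, and the function assumes it runs under a multi-replica launcher such as pjrt.run_multiprocess or torchrun.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


def _ring_shift():
  # Every replica runs the same program: a single collective_permute listing
  # every source->target pair, rather than branching on rank to send or recv.
  device = torch_xla.device()
  world_size = xr.world_size()
  index = xr.global_ordinal()
  value = torch.tensor([float(index)], device=device)
  # Rank i sends to rank (i + 1) % world_size; every rank receives from i - 1.
  pairs = [[i, (i + 1) % world_size] for i in range(world_size)]
  received = xm.collective_permute(value, pairs)
  torch_xla.sync()
  return received.cpu()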

+    logging.warning(
+        "Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs."
+    )
    results = []
    for t in tensors:
-      channel_id = self.make_send_channel_id(dst_rank, tag)
-      # The input will be returned as result.
-      input_as_result = xm.send(t, channel_id)
-      # Make the sent tensor depend on the token, such that the `send`
-      # op can actually be built into the computation graph.
-      with torch.no_grad():
-        t.copy_(input_as_result)
-      results.append(input_as_result)
+      result_t = xm.collective_permute(
+          t, pairs=[[xr.global_ordinal(), dst_rank]])
+      torch_xla.sync()
+      results.append(result_t)
    return _ret_work(results)

-  # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
-  # the maker with their specific one. See unit test in
-  # test/test_torch_distributed_xla_backend.py for an example.
-  def make_recv_channel_id(self, src_rank, tag):
-    raise NotImplementedError

  # Call site e.g.
  # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L913
  def recv(self, out_tensors, src_rank, tag=0):

Collaborator:
Do we need the warning on the recv end too, so each host has it?

Collaborator:
We should not assume someone reading "recv" will have read the documentation for "send". I think we should add documentation here. I would then add a specific note about what the IR expectation will be for "send" and "recv" in each of their comments.
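
One possible shape for that documentation (my sketch only; neither the wording nor the exact IR note comes from this PR) is a short docstring on each method saying that it lowers to an XLA collective-permute, so readers of either method know what IR to expect. Signatures are the ones from this diff; bodies are elided.

def send(self, tensors, dst_rank, tag=0):
  """Sends each tensor in `tensors` to `dst_rank`.

  Implemented via xla_model.collective_permute with a single source-target
  pair, so the emitted IR contains a collective-permute op rather than the
  send/send-done ops the old channel-id path produced.
  """
  ...


def recv(self, out_tensors, src_rank, tag=0):
  """Receives into each tensor in `out_tensors` from `src_rank`.

  Implemented via xla_model.collective_permute, so the emitted IR contains a
  collective-permute op rather than recv/recv-done ops.
  """
  ...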

+    logging.warning(
+        "Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs."
+    )
    results = []
    for ot in out_tensors:
-      channel_id = self.make_recv_channel_id(src_rank, tag)
-      result = xm.recv(ot, channel_id)
-      results.append(result)
+      result_t = xm.collective_permute(
+          ot, pairs=[[src_rank, xr.global_ordinal()]])
+      torch_xla.sync()
+      with torch.no_grad():
+        ot.copy_(result_t)
+      results.append(result_t)
    return _ret_work(results)

  def recv_anysource(self, *args):