diff --git a/test/pjrt/test_collective_ops_tpu.py b/test/pjrt/test_collective_ops_tpu.py
index 7ee9e7d8a66f..a554edb47f90 100644
--- a/test/pjrt/test_collective_ops_tpu.py
+++ b/test/pjrt/test_collective_ops_tpu.py
@@ -359,6 +359,56 @@ def test_all_to_all_single(self, use_dynamo):
             expected.sort().values), f"Got {val}, expected {expected}")
 
+  @staticmethod
+  def _send_recv_pipeline():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    cutoff = world_size // 2
+    index = xr.global_ordinal()
+    tensor = torch.tensor([index], dtype=torch.float, device=device)
+    if index < cutoff:
+      dist.send(tensor, index + cutoff)
+    else:
+      dist.recv(tensor, index - cutoff)
+    return tensor.cpu()
+
+  @staticmethod
+  def _send_recv_permute():
+    dist.init_process_group("xla", init_method='xla://')
+    device = torch_xla.device()
+    world_size = xr.world_size()
+    index = xr.global_ordinal()
+    sending_tensor = torch.tensor([index], dtype=torch.float, device=device)
+    receiving_tensor = torch.tensor([-1.0], dtype=torch.float, device=device)
+    if index % 2 == 0:
+      dist.send(sending_tensor, (index + 1) % world_size)
+      dist.recv(receiving_tensor, (index - 1) % world_size)
+    else:
+      dist.recv(receiving_tensor, (index - 1) % world_size)
+      dist.send(sending_tensor, (index + 1) % world_size)
+    return receiving_tensor.cpu()
+
+  @absltest.skipUnless(tpu.num_available_devices() % 2 == 0,
+                       "Send/Recv test requires even number of devices")
+  def test_send_recv_pipeline(self):
+    """Send tensors on first N/2 devices to second N/2 devices."""
+    results = pjrt.run_multiprocess(self._send_recv_pipeline)
+    world_size = tpu.num_expected_global_devices()
+    for ordinal, value in results.items():
+      expected = ordinal if ordinal < world_size // 2 else ordinal - world_size // 2
+      np.testing.assert_array_equal(value, [expected])
+
+  @absltest.skipUnless(tpu.num_available_devices() % 2 == 0,
+                       "Send/Recv test requires even number of devices")
+  def test_send_recv_permute(self):
+    """Send tensor on device i to i + 1 (modulo world size)."""
+    results = pjrt.run_multiprocess(self._send_recv_permute)
+    world_size = tpu.num_expected_global_devices()
+    for ordinal, value in results.items():
+      expected = (ordinal - 1) % world_size
+      np.testing.assert_array_equal(value, [expected])
+
 
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/test/test_torch_distributed_xla_backend.py b/test/test_torch_distributed_xla_backend.py
index 99b721a4fa16..a421d881aca0 100644
--- a/test/test_torch_distributed_xla_backend.py
+++ b/test/test_torch_distributed_xla_backend.py
@@ -166,47 +166,6 @@ def test_reduce_scatter_coalesced(self):
     # purge all computations attached the device.
     torch_xla.sync()
 
-  @patch_world(0, 6)
-  def test_send(self):
-    device = torch_xla.device()
-    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
-    input_list = [tensor]
-
-    with mock.patch.object(
-        torch_xla.distributed.xla_backend.ProcessGroupXla,
-        'make_send_channel_id',
-        new=lambda self, dst_rank, tag: dst_rank * 2):
-      dist.send(tensor, 1)
-
-    send_pattern = r'%send\.\d+ = .+ send\(.+\), channel_id=2'
-    senddone_pattern = r'%send\-done\.\d+ = .+ send\-done\(.+\), channel_id=2'
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
-    hlo_matches(hlo, send_pattern)
-    hlo_matches(hlo, senddone_pattern)
-
-    # Don't try to run Send on CPU because it's not implemented
-    torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))
-
-  @patch_world(0, 6)
-  def test_recv(self):
-    device = torch_xla.device()
-    tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
-
-    with mock.patch.object(
-        torch_xla.distributed.xla_backend.ProcessGroupXla,
-        'make_recv_channel_id',
-        new=lambda self, src_rank, tag: src_rank * 3):
-      dist.recv(tensor, 1)
-
-    recv_pattern = r'%recv\.\d+ = .+ recv\(.+\), channel_id=3'
-    recvdone_pattern = r'%recv\-done\.\d+ = .+ recv\-done\(.+\), channel_id=3'
-    hlo = torch_xla._XLAC._get_xla_tensors_hlo([tensor])
-    hlo_matches(hlo, recv_pattern)
-    hlo_matches(hlo, recvdone_pattern)
-
-    # Don't try to run Recv on CPU because it's not implemented
-    torch_xla._XLAC._clear_pending_irs(str(torch_xla.device()))
-
   @patch_world(rank=0, size=12)
   def test_new_group_no_ranks(self):
     with new_group_barrier_disabled():
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 6b68e656d333..3dbad1a963eb 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -748,9 +748,6 @@ def collective_permute(value: torch.Tensor,
                        pairs: List[List[int]]) -> torch.Tensor:
   """Performs a XLA `CollectivePermute()` operation on the input tensor.
 
-  WARNING: This function is not very reliable, may produce wrong results under
-  certain inputs. Use it at your own risk.
-
   See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute
 
   Args:
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index daef50c243dc..d35782713283 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -1,5 +1,6 @@
 import torch
 import torch.distributed as dist
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
 from torch_xla._internal import rendezvous
@@ -270,41 +271,34 @@ def scatter(self, output_tensor_list: list[torch.Tensor],
     rs_opts.reduceOp = dist.ReduceOp.SUM
     return self.reduce_scatter(output_tensor_list, inputs, rs_opts)
 
-  # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
-  # the maker with their specific one. See unit test in
-  # test/test_torch_distributed_xla_backend.py for an example.
-  def make_send_channel_id(self, dst_rank, tag):
-    raise NotImplementedError
-
   # Call site e.g.
   # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L877
   def send(self, tensors, dst_rank, tag=0):
+    logging.warning(
+        "Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs."
+    )
     results = []
     for t in tensors:
-      channel_id = self.make_send_channel_id(dst_rank, tag)
-      # The input will be returned as result.
-      input_as_result = xm.send(t, channel_id)
-      # Make the sent tensor depend on the token, such that the `send`
-      # op can actually be built into the computation graph.
-      with torch.no_grad():
-        t.copy_(input_as_result)
-      results.append(input_as_result)
+      result_t = xm.collective_permute(
+          t, pairs=[[xr.global_ordinal(), dst_rank]])
+      torch_xla.sync()
+      results.append(result_t)
     return _ret_work(results)
 
-  # Dummy channel id maker. Different backend (TPU, GPU, etc) should replace
-  # the maker with their specific one. See unit test in
-  # test/test_torch_distributed_xla_backend.py for an example.
-  def make_recv_channel_id(self, src_rank, tag):
-    raise NotImplementedError
-
   # Call site e.g.
   # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L913
   def recv(self, out_tensors, src_rank, tag=0):
+    logging.warning(
+        "Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs."
+    )
     results = []
     for ot in out_tensors:
-      channel_id = self.make_recv_channel_id(src_rank, tag)
-      result = xm.recv(ot, channel_id)
-      results.append(result)
+      result_t = xm.collective_permute(
+          ot, pairs=[[src_rank, xr.global_ordinal()]])
+      torch_xla.sync()
+      with torch.no_grad():
+        ot.copy_(result_t)
+      results.append(result_t)
     return _ret_work(results)
 
   def recv_anysource(self, *args):
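
Note (not part of the patch): below is a minimal sketch of the two styles contrasted by the new warning. It assumes an "xla" process group already initialized by a multiprocess launcher such as the pjrt.run_multiprocess helper used in the tests above, an even world size, and a made-up helper name _pipeline_step used only for illustration.

    import torch
    import torch.distributed as dist
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.runtime as xr


    def _pipeline_step():
      device = torch_xla.device()
      world_size = xr.world_size()
      cutoff = world_size // 2
      index = xr.global_ordinal()
      tensor = torch.tensor([index], dtype=torch.float, device=device)

      # Point-to-point style: with this patch, each dist.send()/dist.recv()
      # lowers to a one-pair collective_permute followed by torch_xla.sync(),
      # so every transfer triggers its own graph execution.
      if index < cutoff:
        dist.send(tensor, index + cutoff)
      else:
        dist.recv(tensor, index - cutoff)

      # Style recommended by the warning: list every source-target pair in a
      # single collective_permute so all transfers share one execution.
      # Replicas that are not a target in any pair receive zeros.
      pairs = [[src, src + cutoff] for src in range(cutoff)]
      permuted = xm.collective_permute(tensor, pairs)
      return permuted.cpu()

The per-call torch_xla.sync() in the patched send()/recv() is what makes the point-to-point path work eagerly but costs one compilation/execution per transfer; the single collective_permute expresses the whole exchange in one graph.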