-
Notifications
You must be signed in to change notification settings - Fork 558
implement send and recv using collective_permute #9373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
ae8052e
e3f75e0
1259414
81624bb
b65e2cf
37f2e63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -359,6 +359,56 @@ def test_all_to_all_single(self, use_dynamo): | |
expected.sort().values), | ||
f"Got {val}, expected {expected}") | ||
|
||
@staticmethod | ||
def _send_recv_pipeline(): | ||
dist.init_process_group("xla", init_method='xla://') | ||
device = torch_xla.device() | ||
world_size = xr.world_size() | ||
cutoff = world_size // 2 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think if the world size is not even, this test will hang. For example, if world size is 3, then index 0 will send to 1 and 1 will recv from 0, but index 2 will try to recv from 1 without an associated send. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. I'll update the test so that it is more defensive |
||
index = xr.global_ordinal() | ||
tensor = torch.tensor([index], dtype=torch.float, device=device) | ||
if index < cutoff: | ||
dist.send(tensor, index + cutoff) | ||
else: | ||
dist.recv(tensor, index - cutoff) | ||
return tensor.cpu() | ||
|
||
@staticmethod | ||
def _send_recv_permute(): | ||
dist.init_process_group("xla", init_method='xla://') | ||
device = torch_xla.device() | ||
world_size = xr.world_size() | ||
index = xr.global_ordinal() | ||
sending_tensor = torch.tensor([index], dtype=torch.float, device=device) | ||
receiving_tensor = torch.tensor([-1.0], dtype=torch.float, device=device) | ||
if index % 2 == 0: | ||
dist.send(sending_tensor, (index + 1) % world_size) | ||
dist.recv(receiving_tensor, (index - 1) % world_size) | ||
else: | ||
dist.recv(receiving_tensor, (index - 1) % world_size) | ||
dist.send(sending_tensor, (index + 1) % world_size) | ||
return receiving_tensor.cpu() | ||
|
||
@absltest.skipUnless(tpu.num_available_devices() % 2 == 0, | ||
"Send/Recv test requires even number of devices") | ||
def test_send_recv_pipeline(self): | ||
"""Send tensors on first N/2 devices to second N/2 devices.""" | ||
results = pjrt.run_multiprocess(self._send_recv_pipeline) | ||
world_size = tpu.num_expected_global_devices() | ||
for ordinal, value in results.items(): | ||
expected = ordinal if ordinal < world_size // 2 else ordinal - world_size // 2 | ||
np.testing.assert_array_equal(value, [expected]) | ||
|
||
@absltest.skipUnless(tpu.num_available_devices() % 2 == 0, | ||
"Send/Recv test requires even number of devices") | ||
def test_send_recv_permute(self): | ||
"""Send tensor on device i to i + 1 (module world size).""" | ||
results = pjrt.run_multiprocess(self._send_recv_permute) | ||
world_size = tpu.num_expected_global_devices() | ||
for ordinal, value in results.items(): | ||
expected = (ordinal - 1) % world_size | ||
np.testing.assert_array_equal(value, [expected]) | ||
|
||
|
||
if __name__ == '__main__': | ||
absltest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -748,9 +748,6 @@ def collective_permute(value: torch.Tensor, | |
pairs: List[List[int]]) -> torch.Tensor: | ||
"""Performs a XLA `CollectivePermute()` operation on the input tensor. | ||
|
||
WARNING: This function is not very reliable, may produce wrong results under | ||
certain inputs. Use it at your own risk. | ||
|
||
Comment on lines
-751
to
-753
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed in #8815 there's no context for this ancient warning. Given the age, lack of details, and lack of any other reported bugs I think it's best to remove it. If we get a specific bug report then we can act on that. |
||
See: https://www.tensorflow.org/xla/operation_semantics#collectivepermute | ||
|
||
Args: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import torch | ||
import torch.distributed as dist | ||
import torch_xla | ||
import torch_xla.core.xla_model as xm | ||
import torch_xla.runtime as xr | ||
from torch_xla._internal import rendezvous | ||
|
@@ -270,41 +271,34 @@ def scatter(self, output_tensor_list: list[torch.Tensor], | |
rs_opts.reduceOp = dist.ReduceOp.SUM | ||
return self.reduce_scatter(output_tensor_list, inputs, rs_opts) | ||
|
||
# Dummy channel id maker. Different backend (TPU, GPU, etc) should replace | ||
# the maker with their specific one. See unit test in | ||
# test/test_torch_distributed_xla_backend.py for an example. | ||
def make_send_channel_id(self, dst_rank, tag): | ||
raise NotImplementedError | ||
|
||
# Call site e.g. | ||
# https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L877 | ||
def send(self, tensors, dst_rank, tag=0): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we're warning to use collective_permute, but it still ends up using a collective permute, should the warning itself be clearer that this is happening under the hood? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could word this better. The real advice is to restructure your code so that each process calls collective_permute with all of the send-recv pairs |
||
logging.warning( | ||
"Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs." | ||
) | ||
results = [] | ||
for t in tensors: | ||
channel_id = self.make_send_channel_id(dst_rank, tag) | ||
# The input will be returned as result. | ||
input_as_result = xm.send(t, channel_id) | ||
# Make the sent tensor depend on the token, such that the `send` | ||
# op can actually be built into the computation graph. | ||
with torch.no_grad(): | ||
t.copy_(input_as_result) | ||
results.append(input_as_result) | ||
result_t = xm.collective_permute( | ||
t, pairs=[[xr.global_ordinal(), dst_rank]]) | ||
torch_xla.sync() | ||
results.append(result_t) | ||
return _ret_work(results) | ||
|
||
# Dummy channel id maker. Different backend (TPU, GPU, etc) should replace | ||
# the maker with their specific one. See unit test in | ||
# test/test_torch_distributed_xla_backend.py for an example. | ||
def make_recv_channel_id(self, src_rank, tag): | ||
raise NotImplementedError | ||
|
||
# Call site e.g. | ||
# https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L913 | ||
def recv(self, out_tensors, src_rank, tag=0): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need the warning on the recv end too, so each host has it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not assume someone reading "recv" will have read the documentation for "send". I think we should add documentation here. I would then add a note specific about what the IR expectation will be for "send" and "recv" on each of their comments. |
||
logging.warning( | ||
"Individual send/recv ops are inefficient on an XLA device. Consider using xla_model.collective_permute() and specifying all source-target pairs." | ||
) | ||
results = [] | ||
for ot in out_tensors: | ||
channel_id = self.make_recv_channel_id(src_rank, tag) | ||
result = xm.recv(ot, channel_id) | ||
results.append(result) | ||
result_t = xm.collective_permute( | ||
ot, pairs=[[src_rank, xr.global_ordinal()]]) | ||
torch_xla.sync() | ||
with torch.no_grad(): | ||
ot.copy_(result_t) | ||
results.append(result_t) | ||
return _ret_work(results) | ||
|
||
def recv_anysource(self, *args): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Last time we checked, we also noticed that https://github.com/pytorch/xla/blob/master/test/test_mp_collective_permute.py didn't work on the CPU, but send/recv did. We might want to double check it.
Is
test/test_torch_distributed_xla_backend.py
tested for CPU and Neuron? Would it be possible to test it and see if the change is compatible?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is, but it just checks that the expected IR is emitted. It doesn't run anything. And in this case it wasn't a reliable test because, at least for TPU, that IR does not actually run.
test_mp_collective_permute is run for both TPU and Neuron. I don't think it works for CPU but neither do send/recv. The success of test_mp_collective_permute indicates this change should work for Neuron, but to be more certain I could add a test that covers a pipeline-like transfer in addition to the existing test of a permutation-like transfer.
The most direct test would be something like what's in test_collective_ops_tpu.py, which runs the ops to completion, for Neuron.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be great. Any chance we can move it outside of this file and make it general? I can help test it out if so. Otherwise, I'll need to follow up if we can port this entire file to Neuron. I see
tpu.num_expected_global_devices
, andpjrt.run_multiprocess
, but haven't seen/used these before.