From 4c0d0b8443666889379c953266798cc1055b62d8 Mon Sep 17 00:00:00 2001 From: Phillip Liu Date: Mon, 7 Jul 2025 16:25:51 -0700 Subject: [PATCH] Updated in broadcast_str to use correct tensor size Summary: Tensor size from source rank was [] before the fix while on other ranks tensor size was [1]. Broadcasting from [] to [1] should be illegal usage. The bug happened to not cause any failures. Differential Revision: D77901586 --- tests/utils/test_distributed_gpu.py | 7 ++++++- torchtnt/utils/distributed.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/utils/test_distributed_gpu.py b/tests/utils/test_distributed_gpu.py index abae65a80a..942ca2b6c5 100644 --- a/tests/utils/test_distributed_gpu.py +++ b/tests/utils/test_distributed_gpu.py @@ -86,12 +86,17 @@ def test_spawn_multi_process(self) -> None: def test_broadcast_str(self) -> None: spawn_multi_process(2, "gloo", self._test_broadcast_str) + @skip_if_not_gpu + @skip_if_not_distributed + def test_broadcast_str_gpu(self) -> None: + spawn_multi_process(2, "nccl", self._test_broadcast_str) + @staticmethod def _test_broadcast_str() -> None: """ Tests that test_broadcast_strworks as expected """ - + init_from_env() val = None if dist.get_rank() == 0: val = "foo" diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py index f5524e7d23..385b75da10 100644 --- a/torchtnt/utils/distributed.py +++ b/torchtnt/utils/distributed.py @@ -738,7 +738,7 @@ def broadcast_str( # convert string to tensor buffer = torch.frombuffer(val.encode("utf-8"), dtype=torch.uint8) buffer = buffer.to(device=device) - buffer_length = torch.tensor((len(buffer)), dtype=torch.int, device=device) + buffer_length = torch.tensor([len(buffer)], dtype=torch.int, device=device) if fixed_buffer_size is not None: if len(buffer) > fixed_buffer_size: