diff --git a/src/torchmetrics/utilities/distributed.py b/src/torchmetrics/utilities/distributed.py index a68cffec3f1..0f3f2d62412 100644 --- a/src/torchmetrics/utilities/distributed.py +++ b/src/torchmetrics/utilities/distributed.py @@ -147,7 +147,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens torch.distributed.all_gather(gathered_result, result_padded, group) for idx, item_size in enumerate(local_sizes): slice_param = [slice(dim_size) for dim_size in item_size] - gathered_result[idx] = gathered_result[idx][slice_param] + gathered_result[idx] = gathered_result[idx][tuple(slice_param)] # to propagate autograd graph from local rank gathered_result[torch.distributed.get_rank(group)] = result return gathered_result