Skip to content

Commit a83d27a

Browse files
committed
fix(fabric): raise on CPU tensor passed to all_reduce in non-CPU setup (#21530)
1 parent c05cadb commit a83d27a

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

src/lightning/fabric/fabric.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,8 +701,22 @@ def all_reduce(
             A tensor of the same shape as the input with values reduced pointwise across processes. The same is
             applied to tensors in a collection if a collection is given as input.
 
+        Raises:
+            RuntimeError: If a CPU tensor is passed while running on a non-CPU device. Move the tensor to
+                ``fabric.device`` before calling ``all_reduce``.
+
         """
         self._validate_launched()
+        if self.device.type != "cpu":
+
+            def _validate_tensor_device(tensor: Tensor) -> None:
+                if tensor.device.type == "cpu":
+                    raise RuntimeError(
+                        "`Fabric.all_reduce` received a CPU tensor while running on a non-CPU device. Move the tensor"
+                        " to fabric.device before calling all_reduce."
+                    )
+
+            apply_to_collection(data, Tensor, _validate_tensor_device)
         group = group if group is not None else torch.distributed.group.WORLD
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, self._strategy.all_reduce, group=group, reduce_op=reduce_op)

tests/tests_fabric/test_fabric.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,17 @@ def test_all_reduce():
     fabric._strategy.all_reduce.assert_has_calls([call(torch.tensor(4), **defaults), call(torch.tensor(5), **defaults)])
 
 
+def test_all_reduce_raises_for_cpu_tensor_on_non_cpu_root_device():
+    """Test that `Fabric.all_reduce()` raises an error when a CPU tensor is passed while the root device is not CPU."""
+    fabric = Fabric()
+    fabric._strategy = Mock(root_device=torch.device("cuda", 0))
+    fabric._launched = True
+
+    tensor_cpu = torch.tensor(1.0, device="cpu")
+    with pytest.raises(RuntimeError, match=r"Move the tensor to fabric.device before calling all_reduce"):
+        fabric.all_reduce(tensor_cpu)
+
+
 def test_rank_zero_first(monkeypatch):
     """Test that rank 0 completes first before all other processes can execute under `.rank_zero_first()`."""

0 commit comments

Comments
 (0)