
Commit e4d99b5

d4l3k authored and meta-codesync[bot] committed
manager: fix allreduce reduction scaling (meta-pytorch#286)
Summary:
We should only rescale tensors manually for `ReduceOp.AVG`.

Pull Request resolved: meta-pytorch#286

Test Plan:
Updated the test to cover all common reductions.

```
pytest torchft/manager_test.py
```

Reviewed By: tushar00jain

Differential Revision: D84879364

Pulled By: d4l3k

fbshipit-source-id: 6c32348466fe920de71183c5fa8014427a8de121
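For quick orientation, here is a minimal, self-contained sketch of the scaling rule this commit settles on: `ReduceOp.AVG` is emulated as a `SUM` across participants followed by a manual divide, and no other reduction is rescaled. The `simulated_allreduce` helper and the hand-built list of per-rank tensors are illustrative stand-ins, not torchft code:

```python
# Illustrative only: the same scaling rule the patched Manager.allreduce uses,
# applied to a plain list of per-rank tensors instead of a process group.
import torch
from torch.distributed import ReduceOp


def simulated_allreduce(tensors: list[torch.Tensor], reduce_op: ReduceOp) -> torch.Tensor:
    stacked = torch.stack(tensors)
    if reduce_op == ReduceOp.AVG:
        # AVG = SUM over all participants, then a manual divide (floats only).
        if not torch.is_floating_point(stacked):
            raise ValueError("average reduce op is only supported for floating point tensors")
        return stacked.sum(dim=0) / len(tensors)
    if reduce_op == ReduceOp.SUM:
        return stacked.sum(dim=0)  # no rescale; the pre-fix code divided this case too
    if reduce_op == ReduceOp.MAX:
        return stacked.amax(dim=0)
    if reduce_op == ReduceOp.MIN:
        return stacked.amin(dim=0)
    if reduce_op == ReduceOp.PRODUCT:
        return stacked.prod(dim=0)
    raise NotImplementedError(f"unhandled reduce_op: {reduce_op}")


ranks = [torch.tensor([10.0]) for _ in range(5)]
print(simulated_allreduce(ranks, ReduceOp.AVG))  # tensor([10.])
print(simulated_allreduce(ranks, ReduceOp.SUM))  # tensor([50.])
```

Before this change the divide also ran on the SUM path, so summed tensors came back scaled down by the participant count.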
1 parent b3be7ad commit e4d99b5

File tree (2 files changed: +36, -9 lines)

torchft/manager.py
torchft/manager_test.py


torchft/manager.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -387,7 +387,7 @@ def allreduce(
         self,
         tensor: torch.Tensor,
         should_quantize: bool = False,
-        reduce_op: ReduceOp = ReduceOp.SUM,
+        reduce_op: ReduceOp = ReduceOp.AVG,
     ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
@@ -416,21 +416,30 @@ def allreduce(
         if not self.is_participating():
             tensor.zero_()
 
+        # special logic for average
+        pg_reduce_op = reduce_op
+        if reduce_op == ReduceOp.AVG:
+            if not torch.is_floating_point(tensor):
+                raise ValueError(
+                    "average reduce op is only supported for floating point tensors"
+                )
+            pg_reduce_op = ReduceOp.SUM
+
         # TODO: increase timeout when waiting when healing
         try:
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             if should_quantize and IS_TRITON_AVAILABLE:
                 work = allreduce_quantized(
                     [tensor],
-                    reduce_op,
+                    pg_reduce_op,
                     self._pg,
                     # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
                     torch.accelerator.current_stream(),
                 )
             else:
                 opts = AllreduceOptions()
-                opts.reduceOp = reduce_op
+                opts.reduceOp = pg_reduce_op
                 work = self._pg.allreduce([tensor], opts)
 
             # schedule grad normalization as a continuation
@@ -440,7 +449,7 @@ def callback(
                 fut: torch.futures.Future[list[torch.Tensor]],
             ) -> torch.Tensor:
                 nonlocal tensor
-                if reduce_op == ReduceOp.SUM:
+                if reduce_op == ReduceOp.AVG:
                     tensor /= num_participants
                 return tensor
```

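The continuation that performs the divide can be exercised in isolation with `torch.futures.Future`; below is a hedged sketch of the pattern the hunks above wire up (the `attach_rescale` name and the hand-built future are illustrative, not the Manager API):

```python
# Sketch of the continuation-based rescale: the process group completes a future
# once the SUM allreduce finishes, and a chained callback divides the buffer by
# the participant count only when the caller asked for ReduceOp.AVG.
import torch
from torch.distributed import ReduceOp


def attach_rescale(
    fut: torch.futures.Future,
    tensor: torch.Tensor,
    reduce_op: ReduceOp,
    num_participants: int,
) -> torch.futures.Future:
    def callback(_: torch.futures.Future) -> torch.Tensor:
        if reduce_op == ReduceOp.AVG:
            tensor /= num_participants  # in-place divide, mirroring the diff
        return tensor

    return fut.then(callback)


# Pretend a SUM allreduce across 5 participants already produced 50.0 in `buf`.
buf = torch.tensor([50.0])
fut = torch.futures.Future()
out = attach_rescale(fut, buf, ReduceOp.AVG, num_participants=5)
fut.set_result([buf])
print(out.wait())  # tensor([10.])
```

The quantized path follows the same mapping, since `pg_reduce_op` rather than the caller's `reduce_op` is what gets forwarded to `allreduce_quantized`.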
torchft/manager_test.py

Lines changed: 23 additions & 5 deletions
```diff
@@ -13,7 +13,7 @@
 from unittest.mock import create_autospec, MagicMock, patch
 
 import torch
-from torch.distributed import TCPStore
+from torch.distributed import ReduceOp, TCPStore
 
 from torchft._torchft import QuorumResult
 from torchft.checkpointing._rwlock import RWLock
@@ -590,10 +590,28 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._pg.allreduce.return_value = _DummyWork(None)
 
         self.assertTrue(manager.is_participating())
-        tensor = torch.tensor([1.0])
-        work = manager.allreduce(tensor)
-        work.wait()
-        torch.testing.assert_close(tensor, torch.tensor([1.0 / 5]))
+
+        for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.long):
+            orig = torch.tensor([10], dtype=dtype)
+
+            if torch.is_floating_point(orig):
+                tensor = orig.clone()
+                manager.allreduce(tensor).wait()
+                torch.testing.assert_close(tensor, orig / 5)
+
+                tensor = orig.clone()
+                manager.allreduce(tensor, reduce_op=ReduceOp.AVG).wait()
+                torch.testing.assert_close(tensor, orig / 5)
+
+            for reduce_op in [
+                ReduceOp.SUM,
+                ReduceOp.MAX,
+                ReduceOp.MIN,
+                ReduceOp.PRODUCT,
+            ]:
+                tensor = orig.clone()
+                manager.allreduce(tensor, reduce_op=reduce_op).wait()
+                torch.testing.assert_close(tensor, orig)
 
         # check healing numerics
         manager._healing = True
```

0 commit comments
