Skip to content

Commit fb74214

Browse files
committed
[Test] Fix sync buffer hook
1 parent 42aa0eb commit fb74214

File tree

1 file changed

+11
-19
lines changed

1 file changed

+11
-19
lines changed
Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
import os
32
from unittest.mock import MagicMock
43

54
import torch
65
import torch.distributed as torch_dist
76
import torch.nn as nn
7+
from torch.testing._internal.common_distributed import DistributedTestBase
88

99
from mmengine.dist import all_gather
1010
from mmengine.hooks import SyncBuffersHook
1111
from mmengine.registry import MODELS
12-
from mmengine.testing._internal import MultiProcessTestCase
1312
from mmengine.testing.runner_test_case import RunnerTestCase, ToyModel
1413

1514

@@ -23,22 +22,14 @@ def __init__(self, data_preprocessor=None):
2322
def init_weights(self):
2423
for buffer in self.buffers():
2524
buffer.fill_(
26-
torch.tensor(int(os.environ['RANK']), dtype=torch.float32))
25+
torch.tensor(torch_dist.get_rank(), dtype=torch.float32))
2726
return super().init_weights()
2827

2928

30-
class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase):
31-
32-
def setUp(self) -> None:
33-
super().setUp()
34-
self._spawn_processes()
35-
36-
def prepare_subprocess(self):
37-
MODELS.register_module(module=ToyModuleWithNorm, force=True)
38-
super(MultiProcessTestCase, self).setUp()
29+
class TestSyncBuffersHook(DistributedTestBase, RunnerTestCase):
3930

4031
def test_sync_buffers_hook(self):
41-
self.setup_dist_env()
32+
self.create_pg('cuda')
4233
runner = MagicMock()
4334
runner.model = ToyModuleWithNorm()
4435
runner.model.init_weights()
@@ -53,9 +44,12 @@ def test_sync_buffers_hook(self):
5344
for buffer in runner.model.buffers():
5445
buffer1, buffer2 = all_gather(buffer)
5546
self.assertTrue(torch.allclose(buffer1, buffer2))
47+
torch_dist.destroy_process_group()
5648

5749
def test_with_runner(self):
58-
self.setup_dist_env()
50+
MODELS.register_module(module=ToyModuleWithNorm, force=True)
51+
self.create_pg('cuda')
52+
RunnerTestCase.setUp(self)
5953
cfg = self.epoch_based_cfg
6054
cfg.model = dict(type='ToyModuleWithNorm')
6155
cfg.launch = 'pytorch'
@@ -67,8 +61,6 @@ def test_with_runner(self):
6761
buffer1, buffer2 = all_gather(buffer)
6862
self.assertTrue(torch.allclose(buffer1, buffer2))
6963

70-
def setup_dist_env(self):
71-
super().setup_dist_env()
72-
os.environ['RANK'] = str(self.rank)
73-
torch_dist.init_process_group(
74-
backend='gloo', rank=self.rank, world_size=self.world_size)
64+
@property
65+
def world_size(self) -> int:
66+
return 2

0 commit comments

Comments
 (0)