11# Copyright (c) OpenMMLab. All rights reserved.
2- import os
32from unittest .mock import MagicMock
43
54import torch
65import torch .distributed as torch_dist
76import torch .nn as nn
7+ from torch .testing ._internal .common_distributed import DistributedTestBase
88
99from mmengine .dist import all_gather
1010from mmengine .hooks import SyncBuffersHook
1111from mmengine .registry import MODELS
12- from mmengine .testing ._internal import MultiProcessTestCase
1312from mmengine .testing .runner_test_case import RunnerTestCase , ToyModel
1413
1514
@@ -23,22 +22,14 @@ def __init__(self, data_preprocessor=None):
2322 def init_weights (self ):
2423 for buffer in self .buffers ():
2524 buffer .fill_ (
26- torch .tensor (int ( os . environ [ 'RANK' ] ), dtype = torch .float32 ))
25+ torch .tensor (torch_dist . get_rank ( ), dtype = torch .float32 ))
2726 return super ().init_weights ()
2827
2928
30- class TestSyncBuffersHook (MultiProcessTestCase , RunnerTestCase ):
31-
32- def setUp (self ) -> None :
33- super ().setUp ()
34- self ._spawn_processes ()
35-
36- def prepare_subprocess (self ):
37- MODELS .register_module (module = ToyModuleWithNorm , force = True )
38- super (MultiProcessTestCase , self ).setUp ()
29+ class TestSyncBuffersHook (DistributedTestBase , RunnerTestCase ):
3930
4031 def test_sync_buffers_hook (self ):
41- self .setup_dist_env ( )
32+ self .create_pg ( 'cuda' )
4233 runner = MagicMock ()
4334 runner .model = ToyModuleWithNorm ()
4435 runner .model .init_weights ()
@@ -53,9 +44,12 @@ def test_sync_buffers_hook(self):
5344 for buffer in runner .model .buffers ():
5445 buffer1 , buffer2 = all_gather (buffer )
5546 self .assertTrue (torch .allclose (buffer1 , buffer2 ))
47+ torch_dist .destroy_process_group ()
5648
5749 def test_with_runner (self ):
58- self .setup_dist_env ()
50+ MODELS .register_module (module = ToyModuleWithNorm , force = True )
51+ self .create_pg ('cuda' )
52+ RunnerTestCase .setUp (self )
5953 cfg = self .epoch_based_cfg
6054 cfg .model = dict (type = 'ToyModuleWithNorm' )
6155 cfg .launch = 'pytorch'
@@ -67,8 +61,6 @@ def test_with_runner(self):
6761 buffer1 , buffer2 = all_gather (buffer )
6862 self .assertTrue (torch .allclose (buffer1 , buffer2 ))
6963
70- def setup_dist_env (self ):
71- super ().setup_dist_env ()
72- os .environ ['RANK' ] = str (self .rank )
73- torch_dist .init_process_group (
74- backend = 'gloo' , rank = self .rank , world_size = self .world_size )
64+ @property
65+ def world_size (self ) -> int :
66+ return 2
0 commit comments