Skip to content

Commit 369d7d1

Browse files
committed
fix
1 parent 97b39f5 commit 369d7d1

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

infinity/utils/dist.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout
 34 34     global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
 35 35     local_rank = global_rank % num_gpus
 36 36     torch.cuda.set_device(local_rank)
    37 +   print(f"global_rank:{global_rank} local_rank:{local_rank} num_gpus:{num_gpus}")
 37 38
 38 39     # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
 39 40     """
@@ -42,7 +43,9 @@ def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout
 42 43     print(f'[dist initialize] mp method={method}')
 43 44     mp.set_start_method(method)
 44 45     """
 45    -   tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))
    46 +   tdist.init_process_group(backend=backend,
    47 +                            device_id=torch.device(f'cuda:{local_rank}'),
    48 +                            timeout=datetime.timedelta(seconds=timeout_minutes * 60))
 46 49
 47 50     global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
 48 51     __local_rank = local_rank

0 commit comments

Comments (0)