11import os
2+
23from unittest .mock import Mock , patch
34
45from dask_pytorch .dispatch import run , dispatch_with_ddp
@@ -31,16 +32,40 @@ def test_run():
3132 output = run (client , fake_pytorch_func )
3233
3334 client .submit .assert_any_call (
34- dispatch_with_ddp , fake_pytorch_func , host , 23456 , 0 , len (workers ), workers = [worker_keys [0 ]]
35+ dispatch_with_ddp ,
36+ pytorch_function = fake_pytorch_func ,
37+ master_addr = host ,
38+ master_port = 23456 ,
39+ rank = 0 ,
40+ world_size = len (workers ),
41+ backend = "nccl" ,
3542 )
3643 client .submit .assert_any_call (
37- dispatch_with_ddp , fake_pytorch_func , host , 23456 , 1 , len (workers ), workers = [worker_keys [1 ]]
44+ dispatch_with_ddp ,
45+ pytorch_function = fake_pytorch_func ,
46+ master_addr = host ,
47+ master_port = 23456 ,
48+ rank = 1 ,
49+ world_size = len (workers ),
50+ backend = "nccl" ,
3851 )
3952 client .submit .assert_any_call (
40- dispatch_with_ddp , fake_pytorch_func , host , 23456 , 2 , len (workers ), workers = [worker_keys [2 ]]
53+ dispatch_with_ddp ,
54+ pytorch_function = fake_pytorch_func ,
55+ master_addr = host ,
56+ master_port = 23456 ,
57+ rank = 2 ,
58+ world_size = len (workers ),
59+ backend = "nccl" ,
4160 )
4261 client .submit .assert_any_call (
43- dispatch_with_ddp , fake_pytorch_func , host , 23456 , 3 , len (workers ), workers = [worker_keys [3 ]]
62+ dispatch_with_ddp ,
63+ pytorch_function = fake_pytorch_func ,
64+ master_addr = host ,
65+ master_port = 23456 ,
66+ rank = 3 ,
67+ world_size = len (workers ),
68+ backend = "nccl" ,
4469 )
4570 assert output == fake_results
4671
@@ -51,7 +76,17 @@ def test_dispatch_with_ddp():
5176 with patch .object (os , "environ" , {}) as environ , patch (
5277 "dask_pytorch.dispatch.dist" , return_value = Mock ()
5378 ) as dist :
54- dispatch_with_ddp (pytorch_func , "master_addr" , 2343 , 1 , 10 , "a" , "b" , foo = "bar" )
79+ dispatch_with_ddp (
80+ pytorch_func ,
81+ "master_addr" ,
82+ 2343 ,
83+ 1 ,
84+ 10 ,
85+ "nccl" ,
86+ "a" ,
87+ "b" ,
88+ foo = "bar" ,
89+ )
5590 assert environ ["MASTER_ADDR" ] == "master_addr"
5691 assert environ ["MASTER_PORT" ] == "2343"
5792 assert environ ["RANK" ] == "1"
0 commit comments