Skip to content

Commit ee0ed99

Browse files
authored
Fix the expanding logic of SLURM_JOB_NODELIST and add unit tests for parallel training. (#913)
* Add unit tests for `cluster` and `env`. * Fix the expanding logic of `SLURM_JOB_NODELIST`.
1 parent 689ffa4 commit ee0ed99

File tree

5 files changed

+162
-34
lines changed

5 files changed

+162
-34
lines changed

deepmd/cluster/slurm.py

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
https://github.com/deepsense-ai/tensorflow_on_slurm ####
66
"""
77

8-
import re
8+
import hostlist
99
import os
1010

1111
from deepmd.cluster import local
12-
from typing import List, Tuple, Optional, Iterable
12+
from typing import List, Tuple, Optional
1313

1414
__all__ = ["get_resource"]
1515

@@ -31,7 +31,7 @@ def get_resource() -> Tuple[str, List[str], Optional[List[int]]]:
3131
ValueError
3232
if current nodename is not found in node list
3333
"""
34-
nodelist = _expand_nodelist(os.environ["SLURM_JOB_NODELIST"])
34+
nodelist = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])
3535
nodename = os.environ["SLURMD_NODENAME"]
3636
num_nodes_env = os.getenv("SLURM_JOB_NUM_NODES")
3737
if num_nodes_env:
@@ -49,34 +49,3 @@ def get_resource() -> Tuple[str, List[str], Optional[List[int]]]:
4949
)
5050
gpus = local.get_gpus()
5151
return nodename, nodelist, gpus
52-
53-
54-
def _pad_zeros(iterable: Iterable, length: int):
55-
return (str(t).rjust(length, "0") for t in iterable)
56-
57-
58-
def _expand_ids(ids: str) -> List[str]:
59-
result = []
60-
for _id in ids.split(","):
61-
if "-" in _id:
62-
str_end = _id.split("-")[1]
63-
begin, end = [int(token) for token in _id.split("-")]
64-
result.extend(_pad_zeros(range(begin, end + 1), len(str_end)))
65-
else:
66-
result.append(_id)
67-
return result
68-
69-
70-
def _expand_nodelist(nodelist: str) -> List[str]:
71-
result = []
72-
interval_list = nodelist.split(",")
73-
for interval in interval_list:
74-
match = re.search(r"(.*)\[(.*)\]", interval)
75-
if match:
76-
prefix = match.group(1)
77-
ids = match.group(2)
78-
ids_list = _expand_ids(ids)
79-
result.extend([f"{prefix}{_id}" for _id in ids_list])
80-
else:
81-
result.append(interval)
82-
return result

deepmd/env.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"MODEL_VERSION",
3333
"SHARED_LIB_MODULE",
3434
"default_tf_session_config",
35+
"reset_default_tf_session_config",
3536
"op_module",
3637
"op_grads_module",
3738
]

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ numpy
22
scipy
33
pyyaml
44
dargs >= 0.2.6
5+
python-hostlist >= 1.21
56
typing_extensions; python_version < "3.7"

source/tests/test_cluster.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import unittest
2+
3+
from deepmd.cluster import local, slurm
4+
from unittest import mock
5+
6+
7+
kHostName = 'compute-b24-1'
8+
9+
10+
class FakePopen(object):
11+
def __init__(self, stdout=b'', stderr=b'', returncode=0):
12+
self._stdout = stdout
13+
self._stderr = stderr
14+
self._returncode = returncode
15+
16+
def communicate(self):
17+
return self._stdout, self._stderr
18+
19+
@property
20+
def returncode(self):
21+
return self._returncode
22+
23+
24+
class TestGPU(unittest.TestCase):
25+
@mock.patch('subprocess.Popen')
26+
def test_none(self, mock_Popen):
27+
mock_Popen.return_value.__enter__.return_value = FakePopen(b'0', b'')
28+
gpus = local.get_gpus()
29+
self.assertIsNone(gpus)
30+
31+
@mock.patch('subprocess.Popen')
32+
def test_valid(self, mock_Popen):
33+
mock_Popen.return_value.__enter__.return_value = FakePopen(b'2', b'')
34+
gpus = local.get_gpus()
35+
self.assertEqual(gpus, [0, 1])
36+
37+
@mock.patch('subprocess.Popen')
38+
def test_error(self, mock_Popen):
39+
mock_Popen.return_value.__enter__.return_value = \
40+
FakePopen(stderr=b'!', returncode=1)
41+
with self.assertRaises(RuntimeError) as cm:
42+
_ = local.get_gpus()
43+
self.assertIn('Failed to detect', str(cm.exception))
44+
45+
46+
class TestLocal(unittest.TestCase):
47+
@mock.patch('socket.gethostname')
48+
def test_resource(self, mock_gethostname):
49+
mock_gethostname.return_value = kHostName
50+
nodename, nodelist, _ = local.get_resource()
51+
self.assertEqual(nodename, kHostName)
52+
self.assertEqual(nodelist, [kHostName])
53+
54+
55+
class TestSlurm(unittest.TestCase):
56+
@mock.patch.dict('os.environ', values={
57+
'SLURM_JOB_NODELIST': kHostName,
58+
'SLURMD_NODENAME': kHostName,
59+
'SLURM_JOB_NUM_NODES': '1'
60+
})
61+
def test_single(self):
62+
nodename, nodelist, _ = slurm.get_resource()
63+
self.assertEqual(nodename, kHostName)
64+
self.assertEqual(nodelist, [kHostName])
65+
66+
@mock.patch.dict('os.environ', values={
67+
'SLURM_JOB_NODELIST': 'compute-b24-[1-3,5-9],compute-b25-[4,8]',
68+
'SLURMD_NODENAME': 'compute-b24-2',
69+
'SLURM_JOB_NUM_NODES': '10'
70+
})
71+
def test_multiple(self):
72+
nodename, nodelist, _ = slurm.get_resource()
73+
self.assertEqual(nodename, 'compute-b24-2')
74+
self.assertEqual(nodelist, [
75+
'compute-b24-1',
76+
'compute-b24-2',
77+
'compute-b24-3',
78+
'compute-b24-5',
79+
'compute-b24-6',
80+
'compute-b24-7',
81+
'compute-b24-8',
82+
'compute-b24-9',
83+
'compute-b25-4',
84+
'compute-b25-8'
85+
])
86+
87+
def test_illegal(self):
88+
environ = {
89+
'SLURM_JOB_NODELIST': 'compute-b24-[3-5]',
90+
'SLURMD_NODENAME': 'compute-b24-4'
91+
}
92+
with mock.patch.dict('os.environ', environ):
93+
with self.assertRaises(RuntimeError) as cm:
94+
_ = slurm.get_resource()
95+
self.assertIn('Could not get SLURM number', str(cm.exception))
96+
97+
environ = {
98+
'SLURM_JOB_NODELIST': 'compute-b24-1,compute-b25-2',
99+
'SLURMD_NODENAME': 'compute-b25-2',
100+
'SLURM_JOB_NUM_NODES': '4'
101+
}
102+
with mock.patch.dict('os.environ', environ):
103+
with self.assertRaises(ValueError) as cm:
104+
_ = slurm.get_resource()
105+
self.assertIn('Number of slurm nodes 2', str(cm.exception))
106+
107+
environ = {
108+
'SLURM_JOB_NODELIST': 'compute-b24-1,compute-b25-3',
109+
'SLURMD_NODENAME': 'compute-b25-2',
110+
'SLURM_JOB_NUM_NODES': '2'
111+
}
112+
with mock.patch.dict('os.environ', environ):
113+
with self.assertRaises(ValueError) as cm:
114+
_ = slurm.get_resource()
115+
self.assertIn('Nodename(compute-b25-2', str(cm.exception))

source/tests/test_env.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import unittest
2+
3+
from deepmd import env
4+
from unittest import mock
5+
6+
7+
class TestTFThreadCount(unittest.TestCase):
8+
@mock.patch.dict('os.environ', values={})
9+
def test_empty(self):
10+
intra, inter = env.get_tf_default_nthreads()
11+
self.assertEqual(intra, 0)
12+
self.assertEqual(inter, 0)
13+
14+
@mock.patch.dict('os.environ', values={
15+
'TF_INTRA_OP_PARALLELISM_THREADS': '5',
16+
'TF_INTER_OP_PARALLELISM_THREADS': '3'
17+
})
18+
def test_given(self):
19+
intra, inter = env.get_tf_default_nthreads()
20+
self.assertEqual(intra, 5)
21+
self.assertEqual(inter, 3)
22+
23+
24+
class TestTFSessionConfig(unittest.TestCase):
25+
def test_default(self):
26+
shared = env.default_tf_session_config
27+
new = env.get_tf_session_config()
28+
self.assertNotEqual(id(shared), id(new))
29+
30+
@mock.patch('deepmd.env.get_tf_default_nthreads')
31+
def test_get(self, mock_method):
32+
mock_method.return_value = (5, 3)
33+
config = env.get_tf_session_config()
34+
self.assertEqual(config.intra_op_parallelism_threads, 5)
35+
self.assertEqual(config.inter_op_parallelism_threads, 3)
36+
37+
def test_reset(self):
38+
shared = env.default_tf_session_config
39+
env.reset_default_tf_session_config(True)
40+
self.assertEqual(shared.device_count['GPU'], 0)
41+
env.reset_default_tf_session_config(False)
42+
self.assertEqual(len(shared.device_count), 0)

0 commit comments

Comments (0)