
Commit c98e4b4

Merge pull request #9 from tskisner/comm_reuse
Support passing in pre-created node and node-rank communicators
2 parents ba9ce2e + a737123

File tree

2 files changed: +95, -17 lines


pshmem/shmem.py

Lines changed: 52 additions & 10 deletions
@@ -12,8 +12,7 @@
 
 
 class MPIShared(object):
-    """
-    Create a shared memory buffer that is replicated across nodes.
+    """Create a shared memory buffer that is replicated across nodes.
 
     For the given array dimensions and datatype, the original communicator
     is split into groups of processes that can share memory (i.e. that are
@@ -32,13 +31,19 @@ class MPIShared(object):
     If comm is None, a simple local numpy array is used.
 
     Args:
-        shape (tuple): the dimensions of the array.
-        dtype (np.dtype): the data type of the array.
-        comm (MPI.Comm): the full communicator to use. This may span
+        shape (tuple): The dimensions of the array.
+        dtype (np.dtype): The data type of the array.
+        comm (MPI.Comm): The full communicator to use. This may span
             multiple nodes, and each node will have a copy of the data.
+        comm_node (MPI.Comm): The communicator of processes within the
+            same node. If None, the node communicator will be created.
+        comm_node_rank (MPI.Comm): The communicator of processes with
+            the same rank across all nodes. If None, this will be
+            created.
+
     """
 
-    def __init__(self, shape, dtype, comm):
+    def __init__(self, shape, dtype, comm, comm_node=None, comm_node_rank=None):
         # Copy the datatype in order to support arguments that are aliases,
         # like "numpy.float64".
         self._dtype = np.dtype(dtype)
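
For orientation, here is a rough usage sketch of the new keyword arguments, mirroring the construction pattern used in the test suite further down. The variable names and the top-level import of MPIShared from the pshmem package are illustrative assumptions, not part of this diff.

from mpi4py import MPI
import numpy as np
from pshmem import MPIShared

comm = MPI.COMM_WORLD

# Split the world communicator once and re-use the pieces for every
# MPIShared buffer, instead of paying for a new split per instance.
node_comm = comm.Split_type(MPI.COMM_TYPE_SHARED, 0)
node_rank_comm = comm.Split(node_comm.rank, comm.rank // node_comm.size)

shm = MPIShared(
    (100, 50),
    np.float64,
    comm,
    comm_node=node_comm,
    comm_node_rank=node_rank_comm,
)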
@@ -79,14 +84,43 @@ def __init__(self, shape, dtype, comm):
         if self._comm is not None:
             from mpi4py import MPI
 
-            self._nodecomm = self._comm.Split_type(MPI.COMM_TYPE_SHARED, 0)
+            self._free_comm_node = False
+            if comm_node is None:
+                # Create it
+                self._nodecomm = self._comm.Split_type(MPI.COMM_TYPE_SHARED, 0)
+                self._free_comm_node = True
+            else:
+                # Check it
+                if self._procs % comm_node.size != 0:
+                    msg = "Node communicator size ({}) does not divide ".format(
+                        comm_node.size
+                    )
+                    msg += "evenly into the total number of processes ({})".format(
+                        self._procs
+                    )
+                    raise ValueError(msg)
+                self._nodecomm = comm_node
             self._noderank = self._nodecomm.rank
             self._nodeprocs = self._nodecomm.size
             self._nodes = self._procs // self._nodeprocs
             if self._nodes * self._nodeprocs < self._procs:
                 self._nodes += 1
             self._mynode = self._rank // self._nodeprocs
-            self._rankcomm = self._comm.Split(self._noderank, self._mynode)
+
+            self._free_comm_node_rank = False
+            if comm_node_rank is None:
+                # Create it
+                self._rankcomm = self._comm.Split(self._noderank, self._mynode)
+                self._free_comm_node_rank = True
+            else:
+                # Check it
+                if comm_node_rank.size != self._nodes:
+                    msg = "Node rank communicator size ({}) does not match ".format(
+                        comm_node_rank.size
+                    )
+                    msg += "the number of nodes ({})".format(self._nodes)
+                    raise ValueError(msg)
+                self._rankcomm = comm_node_rank
 
             # Consider a corner case of the previous calculation. Imagine that
             # the number of processes is not evenly divisible by the number of
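
In plain terms, the two checks above require that a supplied comm_node tile the full communicator evenly and that comm_node_rank contain exactly one process per node. A minimal standalone sketch of the same validation (a hypothetical helper for illustration, not part of the package):

def check_reused_comms(comm, comm_node, comm_node_rank):
    # The node communicator must evenly divide the total process count.
    if comm.size % comm_node.size != 0:
        raise ValueError("comm_node size does not divide the world size evenly")
    nodes = comm.size // comm_node.size
    # The node-rank communicator must have exactly one member per node.
    if comm_node_rank.size != nodes:
        raise ValueError("comm_node_rank size does not equal the number of nodes")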
@@ -291,10 +325,18 @@ def close(self):
             self._win.Free()
         self._win = None
         # Free other communicators if needed
-        if hasattr(self, "_rankcomm") and (self._rankcomm is not None):
+        if (
+            hasattr(self, "_rankcomm")
+            and (self._rankcomm is not None)
+            and self._free_comm_node_rank
+        ):
             self._rankcomm.Free()
             self._rankcomm = None
-        if hasattr(self, "_nodecomm") and (self._nodecomm is not None):
+        if (
+            hasattr(self, "_nodecomm")
+            and (self._nodecomm is not None)
+            and self._free_comm_node
+        ):
             self._nodecomm.Free()
             self._nodecomm = None
         return
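
A consequence of the new _free_comm_node and _free_comm_node_rank flags is that close() only frees communicators the class created itself; anything passed in remains owned by the caller. A hedged sketch of the expected caller-side cleanup (variable names are illustrative, continuing the example above):

shm = MPIShared(
    (100, 50), np.float64, comm, comm_node=node_comm, comm_node_rank=node_rank_comm
)
try:
    pass  # ... read and write the shared buffer ...
finally:
    shm.close()            # frees the MPI window, but not the passed-in communicators
    node_rank_comm.Free()  # the caller is still responsible for these
    node_comm.Free()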

pshmem/test.py

Lines changed: 43 additions & 7 deletions
@@ -27,7 +27,7 @@
 try:
     import mpi4py.MPI as MPI
 except ImportError:
-    raise ImportError("Cannot import mpi4py, will only test serial functionality.")
+    print("Cannot import mpi4py, will only test serial functionality.")
 
 
 class ShmemTest(unittest.TestCase):
@@ -44,7 +44,7 @@ def setUp(self):
     def tearDown(self):
         pass
 
-    def read_write(self, comm):
+    def read_write(self, comm, comm_node=None, comm_node_rank=None):
         """Run a sequence of various access tests."""
         rank = 0
         procs = 1
@@ -76,7 +76,13 @@ def read_write(self, comm):
         # object has no dangling reference counts after leaving the context,
         # and will ensure that the shared memory is freed properly.
 
-        with MPIShared(local.shape, local.dtype, comm) as shm:
+        with MPIShared(
+            local.shape,
+            local.dtype,
+            comm,
+            comm_node=comm_node,
+            comm_node_rank=comm_node_rank,
+        ) as shm:
             for p in range(procs):
                 # Every process takes turns writing to the buffer.
                 setdata = None
@@ -94,7 +100,7 @@ def read_write(self, comm):
                 try:
                     # All processes call set(), but only data on rank p matters.
                     shm.set(setdata, setoffset, fromrank=p)
-                except:
+                except (RuntimeError, ValueError):
                     print(
                         "proc {} threw exception during set()".format(rank),
                         flush=True,
@@ -117,7 +123,7 @@ def read_write(self, comm):
                         setoffset[1] : setoffset[1] + setdata.shape[1],
                         setoffset[2] : setoffset[2] + setdata.shape[2],
                     ] = setdata
-                except:
+                except (RuntimeError, ValueError):
                     print(
                         "proc {} threw exception during __setitem__".format(
                             rank
@@ -221,6 +227,36 @@ def test_comm_self(self):
         # Every process does the operations on COMM_SELF
         self.read_write(MPI.COMM_SELF)
 
+    def test_comm_reuse(self):
+        if self.comm is not None:
+            if self.comm.rank == 0:
+                print("Testing MPIShared with re-used node comm...", flush=True)
+            nodecomm = self.comm.Split_type(MPI.COMM_TYPE_SHARED, 0)
+            noderank = nodecomm.rank
+            nodeprocs = nodecomm.size
+            nodes = self.comm.size // nodeprocs
+            mynode = self.comm.rank // nodeprocs
+            rankcomm = self.comm.Split(noderank, mynode)
+
+            self.read_write(self.comm, comm_node=nodecomm, comm_node_rank=rankcomm)
+
+            if nodes > 1 and nodeprocs > 2:
+                # We have at least one node, test passing in an incorrect
+                # communicator for the node comm.
+                evenoddcomm = self.comm.Split(self.comm.rank % 2, self.comm.rank // 2)
+                try:
+                    test_shared = MPIShared(
+                        (10, 5),
+                        np.float64,
+                        self.comm,
+                        comm_node=evenoddcomm,
+                        comm_node_rank=evenoddcomm,
+                    )
+                    print("Failed to catch construction with bad node comm")
+                    self.assertTrue(False)
+                except ValueError:
+                    print("Successfully caught construction with bad node comm")
+
     def test_shape(self):
         good_dims = [
             (2, 5, 10),
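
To see why the even/odd split is rejected in the test above, consider, say, 2 nodes of 4 processes each (8 total): evenoddcomm then has size 4. Treated as comm_node it implies 8 // 4 = 2 nodes, so the same size-4 communicator fails the comm_node_rank check (4 != 2) and the constructor raises ValueError, which is exactly what the test expects to catch.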
@@ -245,7 +281,7 @@ def test_shape(self):
                 if self.rank == 0:
                     print("successful creation with shape {}".format(dims), flush=True)
                 del shm
-            except Exception:
+            except (RuntimeError, ValueError):
                 if self.rank == 0:
                     print(
                         "unsuccessful creation with shape {}".format(dims), flush=True
@@ -256,7 +292,7 @@ def test_shape(self):
                 if self.rank == 0:
                     print("unsuccessful rejection of shape {}".format(dims), flush=True)
                 del shm
-            except Exception:
+            except (RuntimeError, ValueError):
                 if self.rank == 0:
                     print("successful rejection of shape {}".format(dims), flush=True)
 