# Optional dependency: fall back to serial-only testing when mpi4py is absent.
try:
    import mpi4py.MPI as MPI
except ImportError:
    # Bind the name so later references can be guarded with "MPI is None"
    # instead of raising NameError when mpi4py is not installed.
    MPI = None
    print("Cannot import mpi4py, will only test serial functionality.")
3131
3232
3333class ShmemTest (unittest .TestCase ):
@@ -44,7 +44,7 @@ def setUp(self):
4444 def tearDown (self ):
4545 pass
4646
47- def read_write (self , comm ):
47+ def read_write (self , comm , comm_node = None , comm_node_rank = None ):
4848 """Run a sequence of various access tests."""
4949 rank = 0
5050 procs = 1
@@ -76,7 +76,13 @@ def read_write(self, comm):
7676 # object has no dangling reference counts after leaving the context,
7777 # and will ensure that the shared memory is freed properly.
7878
79- with MPIShared (local .shape , local .dtype , comm ) as shm :
79+ with MPIShared (
80+ local .shape ,
81+ local .dtype ,
82+ comm ,
83+ comm_node = comm_node ,
84+ comm_node_rank = comm_node_rank ,
85+ ) as shm :
8086 for p in range (procs ):
8187 # Every process takes turns writing to the buffer.
8288 setdata = None
@@ -94,7 +100,7 @@ def read_write(self, comm):
94100 try :
95101 # All processes call set(), but only data on rank p matters.
96102 shm .set (setdata , setoffset , fromrank = p )
97- except :
103+ except ( RuntimeError , ValueError ) :
98104 print (
99105 "proc {} threw exception during set()" .format (rank ),
100106 flush = True ,
@@ -117,7 +123,7 @@ def read_write(self, comm):
117123 setoffset [1 ] : setoffset [1 ] + setdata .shape [1 ],
118124 setoffset [2 ] : setoffset [2 ] + setdata .shape [2 ],
119125 ] = setdata
120- except :
126+ except ( RuntimeError , ValueError ) :
121127 print (
122128 "proc {} threw exception during __setitem__" .format (
123129 rank
@@ -221,6 +227,36 @@ def test_comm_self(self):
221227 # Every process does the operations on COMM_SELF
222228 self .read_write (MPI .COMM_SELF )
223229
230+ def test_comm_reuse (self ):
231+ if self .comm is not None :
232+ if self .comm .rank == 0 :
233+ print ("Testing MPIShared with re-used node comm..." , flush = True )
234+ nodecomm = self .comm .Split_type (MPI .COMM_TYPE_SHARED , 0 )
235+ noderank = nodecomm .rank
236+ nodeprocs = nodecomm .size
237+ nodes = self .comm .size // nodeprocs
238+ mynode = self .comm .rank // nodeprocs
239+ rankcomm = self .comm .Split (noderank , mynode )
240+
241+ self .read_write (self .comm , comm_node = nodecomm , comm_node_rank = rankcomm )
242+
243+ if nodes > 1 and nodeprocs > 2 :
244+ # We have at least one node, test passing in an incorrect
245+ # communicator for the node comm.
246+ evenoddcomm = self .comm .Split (self .comm .rank % 2 , self .comm .rank // 2 )
247+ try :
248+ test_shared = MPIShared (
249+ (10 , 5 ),
250+ np .float64 ,
251+ self .comm ,
252+ comm_node = evenoddcomm ,
253+ comm_node_rank = evenoddcomm ,
254+ )
255+ print ("Failed to catch construction with bad node comm" )
256+ self .assertTrue (False )
257+ except ValueError :
258+ print ("Successfully caught construction with bad node comm" )
259+
224260 def test_shape (self ):
225261 good_dims = [
226262 (2 , 5 , 10 ),
@@ -245,7 +281,7 @@ def test_shape(self):
245281 if self .rank == 0 :
246282 print ("successful creation with shape {}" .format (dims ), flush = True )
247283 del shm
248- except Exception :
284+ except ( RuntimeError , ValueError ) :
249285 if self .rank == 0 :
250286 print (
251287 "unsuccessful creation with shape {}" .format (dims ), flush = True
@@ -256,7 +292,7 @@ def test_shape(self):
256292 if self .rank == 0 :
257293 print ("unsuccessful rejection of shape {}" .format (dims ), flush = True )
258294 del shm
259- except Exception :
295+ except ( RuntimeError , ValueError ) :
260296 if self .rank == 0 :
261297 print ("successful rejection of shape {}" .format (dims ), flush = True )
262298
0 commit comments