Skip to content

Commit 8f63b8d

Browse files
authored
Remove all restrictions to MPI rank zero (#353)
1 parent 00b0d6c commit 8f63b8d

File tree

1 file changed

+41
-61
lines changed

1 file changed

+41
-61
lines changed

pylammpsmpi/mpi/lmpmpi.py

Lines changed: 41 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,14 @@ def convert_data(val, type, length, width):
7979
val = _gather_data_from_all_processors(
8080
data=job.numpy.extract_compute(*filtered_args)
8181
)
82-
if MPI.COMM_WORLD.rank == 0:
83-
length = job.get_natoms()
84-
return convert_data(val=val, type=type, length=length, width=width)
82+
length = job.get_natoms()
83+
return convert_data(val=val, type=type, length=length, width=width)
8584
else: # Todo
8685
raise ValueError("Local style is currently not supported")
8786

8887

8988
def get_version(job, funct_args):
90-
if MPI.COMM_WORLD.rank == 0:
91-
return job.version()
89+
return job.version()
9290

9391

9492
def get_file(job, funct_args):
@@ -107,53 +105,48 @@ def commands_string(job, funct_args):
107105

108106

109107
def extract_setting(job, funct_args):
110-
if MPI.COMM_WORLD.rank == 0:
111-
return job.extract_setting(*funct_args)
108+
return job.extract_setting(*funct_args)
112109

113110

114111
def extract_global(job, funct_args):
115-
if MPI.COMM_WORLD.rank == 0:
116-
return job.extract_global(*funct_args)
112+
return job.extract_global(*funct_args)
117113

118114

119115
def extract_box(job, funct_args):
120-
if MPI.COMM_WORLD.rank == 0:
121-
return job.extract_box(*funct_args)
116+
return job.extract_box(*funct_args)
122117

123118

124119
def extract_atom(job, funct_args):
125-
if MPI.COMM_WORLD.rank == 0:
126-
# extract atoms return an internal data type
127-
# this has to be reformatted
128-
name = str(funct_args[0])
129-
if name not in atom_properties:
130-
return []
120+
# extract atoms return an internal data type
121+
# this has to be reformatted
122+
name = str(funct_args[0])
123+
if name not in atom_properties:
124+
return []
131125

132-
# this block prevents error when trying to access values
133-
# that do not exist
134-
try:
135-
val = job.extract_atom(name, atom_properties[name]["type"])
136-
except ValueError:
137-
return []
138-
# this is per atom quantity - so get
139-
# number of atoms - first dimension
140-
natoms = job.get_natoms()
141-
# second dim is from dict
142-
dim = atom_properties[name]["dim"]
143-
data = []
144-
if dim > 1:
145-
for i in range(int(natoms)):
146-
dummy = [val[i][x] for x in range(dim)]
147-
data.append(dummy)
148-
else:
149-
data = [val[x] for x in range(int(natoms))]
126+
# this block prevents error when trying to access values
127+
# that do not exist
128+
try:
129+
val = job.extract_atom(name, atom_properties[name]["type"])
130+
except ValueError:
131+
return []
132+
# this is per atom quantity - so get
133+
# number of atoms - first dimension
134+
natoms = job.get_natoms()
135+
# second dim is from dict
136+
dim = atom_properties[name]["dim"]
137+
data = []
138+
if dim > 1:
139+
for i in range(int(natoms)):
140+
dummy = [val[i][x] for x in range(dim)]
141+
data.append(dummy)
142+
else:
143+
data = [val[x] for x in range(int(natoms))]
150144

151-
return np.array(data)
145+
return np.array(data)
152146

153147

154148
def extract_fix(job, funct_args):
155-
if MPI.COMM_WORLD.rank == 0:
156-
return job.extract_fix(*funct_args)
149+
return job.extract_fix(*funct_args)
157150

158151

159152
def extract_variable(job, funct_args):
@@ -163,9 +156,8 @@ def extract_variable(job, funct_args):
163156
data = _gather_data_from_all_processors(
164157
data=job.numpy.extract_variable(*funct_args)
165158
)
166-
if MPI.COMM_WORLD.rank == 0:
167-
return np.array(data)
168-
elif MPI.COMM_WORLD.rank == 0:
159+
return np.array(data)
160+
else:
169161
# if type is 1 - reformat file
170162
try:
171163
data = job.extract_variable(*funct_args)
@@ -175,8 +167,7 @@ def extract_variable(job, funct_args):
175167

176168

177169
def get_natoms(job, funct_args):
178-
if MPI.COMM_WORLD.rank == 0:
179-
return job.get_natoms()
170+
return job.get_natoms()
180171

181172

182173
def set_variable(job, funct_args):
@@ -303,33 +294,27 @@ def set_fix_external_callback(job, funct_args):
303294

304295

305296
def get_neighlist(job, funct_args):
306-
if MPI.COMM_WORLD.rank == 0:
307-
return job.get_neighlist(*funct_args)
297+
return job.get_neighlist(*funct_args)
308298

309299

310300
def find_pair_neighlist(job, funct_args):
311-
if MPI.COMM_WORLD.rank == 0:
312-
return job.find_pair_neighlist(*funct_args)
301+
return job.find_pair_neighlist(*funct_args)
313302

314303

315304
def find_fix_neighlist(job, funct_args):
316-
if MPI.COMM_WORLD.rank == 0:
317-
return job.find_fix_neighlist(*funct_args)
305+
return job.find_fix_neighlist(*funct_args)
318306

319307

320308
def find_compute_neighlist(job, funct_args):
321-
if MPI.COMM_WORLD.rank == 0:
322-
return job.find_compute_neighlist(*funct_args)
309+
return job.find_compute_neighlist(*funct_args)
323310

324311

325312
def get_neighlist_size(job, funct_args):
326-
if MPI.COMM_WORLD.rank == 0:
327-
return job.get_neighlist_size(*funct_args)
313+
return job.get_neighlist_size(*funct_args)
328314

329315

330316
def get_neighlist_element_neighbors(job, funct_args):
331-
if MPI.COMM_WORLD.rank == 0:
332-
return job.get_neighlist_element_neighbors(*funct_args)
317+
return job.get_neighlist_element_neighbors(*funct_args)
333318

334319

335320
def get_thermo(job, funct_args):
@@ -443,12 +428,7 @@ def select_cmd(argument):
443428

444429
def _gather_data_from_all_processors(data):
445430
data_gather = MPI.COMM_WORLD.gather(data, root=0)
446-
if MPI.COMM_WORLD.rank == 0:
447-
data = []
448-
for vl in data_gather:
449-
for v in vl:
450-
data.append(v)
451-
return data
431+
return [v for vl in data_gather for v in vl]
452432

453433

454434
def _run_lammps_mpi(argument_lst):

0 commit comments

Comments
 (0)