Merged
17 changes: 10 additions & 7 deletions nipype/pipeline/engine/utils.py
@@ -753,7 +753,7 @@ def _merge_graphs(
     # nodes of the supergraph.
     supernodes = supergraph.nodes()
     ids = [n._hierarchy + n._id for n in supernodes]
-    if len(np.unique(ids)) != len(ids):
+    if len(set(ids)) != len(ids):
         # This should trap the problem of miswiring when multiple iterables are
         # used at the same level. The use of the template below for naming
         # updates to nodes is the general solution.
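Note (reviewer sketch, not part of the diff): `np.unique` sorts its input and materializes a new array just to count distinct ids, while `set` only hashes each string once. A minimal illustration with a hypothetical `ids` list:

```python
# Duplicate-id check: both forms agree; the set-based one avoids the
# O(n log n) sort and the NumPy round-trip on a list of Python strings.
import numpy as np

ids = ["wf.a", "wf.b", "wf.a"]  # hypothetical node ids with one duplicate
assert (len(np.unique(ids)) != len(ids)) == (len(set(ids)) != len(ids))
```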
@@ -1100,11 +1100,12 @@ def make_field_func(*pair):
         old_edge_dict = jedge_dict[jnode]
         # the edge source node replicates
         expansions = defaultdict(list)
-        for node in graph_in.nodes():
+        for node in graph_in:
             for src_id in list(old_edge_dict.keys()):
                 # Drop the original JoinNodes; only concerned with
                 # generated Nodes
-                if hasattr(node, "joinfield") and node.itername == src_id:
+                itername = node.itername
+                if hasattr(node, "joinfield") and itername == src_id:
                     continue
                 # Patterns:
                 # - src_id : Non-iterable node
@@ -1113,10 +1114,12 @@ def make_field_func(*pair):
                 # - src_id.[a-z]I.[a-z]\d+ :
                 #   Non-IdentityInterface w/ iterables
                 # - src_idJ\d+ : JoinNode(IdentityInterface)
-                if re.match(
-                    src_id + r"((\.[a-z](I\.[a-z])?|J)\d+)?$", node.itername
-                ):
-                    expansions[src_id].append(node)
+                if itername.startswith(src_id):
+                    itername = itername[len(src_id):]
+                    if re.fullmatch(
+                        r"((\.[a-z](I\.[a-z])?|J)\d+)?", itername
+                    ):
+                        expansions[src_id].append(node)
         for in_id, in_nodes in list(expansions.items()):
             logger.debug(
                 "The join node %s input %s was expanded" " to %d nodes.",
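Note (reviewer sketch, not in the PR): the old code interpolated `src_id` into the pattern, so every distinct source id produced a distinct regex for `re.match` to compile-or-cache, and the engine had to walk the literal prefix itself. The new code tests the prefix with `str.startswith` and runs one constant pattern over just the suffix via `re.fullmatch`. Two smaller wins ride along: iterating `graph_in` directly skips building a `.nodes()` view each pass, and `node.itername` is computed once per node instead of once per candidate `src_id`. The rewrite is equivalent provided `src_id` contains no regex metacharacters, which the literal `startswith` comparison assumes:

```python
import re

SUFFIX = re.compile(r"((\.[a-z](I\.[a-z])?|J)\d+)?")  # the PR's constant pattern

def matches_old(src_id, itername):
    return re.match(src_id + r"((\.[a-z](I\.[a-z])?|J)\d+)?$", itername) is not None

def matches_new(src_id, itername):
    # Prefix check first, then the anchored suffix pattern on the remainder.
    return itername.startswith(src_id) and bool(SUFFIX.fullmatch(itername[len(src_id):]))

# Covers all documented patterns plus two non-matching names.
for name in ("node", "node.a3", "node.aI.b12", "nodeJ2", "node2", "other"):
    assert matches_old("node", name) == matches_new("node", name)
```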
61 changes: 41 additions & 20 deletions nipype/pipeline/engine/workflows.py
@@ -59,6 +59,9 @@ def __init__(self, name, base_dir=None):
         super(Workflow, self).__init__(name, base_dir)
         self._graph = nx.DiGraph()

+        self._nodes_cache = set()
+        self._nested_workflows_cache = set()
+
     # PUBLIC API
     def clone(self, name):
         """Clone a workflow
@@ -141,7 +144,7 @@ def connect(self, *args, **kwargs):
             self.disconnect(connection_list)
             return

-        newnodes = []
+        newnodes = set()
         for srcnode, destnode, _ in connection_list:
             if self in [srcnode, destnode]:
                 msg = (
@@ -151,9 +154,9 @@ def connect(self, *args, **kwargs):

                 raise IOError(msg)
             if (srcnode not in newnodes) and not self._has_node(srcnode):
-                newnodes.append(srcnode)
+                newnodes.add(srcnode)
             if (destnode not in newnodes) and not self._has_node(destnode):
-                newnodes.append(destnode)
+                newnodes.add(destnode)
         if newnodes:
             self._check_nodes(newnodes)
             for node in newnodes:
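Note: with `newnodes` as a `set`, each `not in` guard drops from a linear scan to a single hash lookup, and duplicates are impossible by construction. One behavioral shift worth flagging: iteration order over `newnodes` below becomes arbitrary, which the surrounding code presumably does not depend on. A toy illustration (not nipype code):

```python
new_list, new_set = [], set()
for name in ("a", "b", "a"):
    if name not in new_list:   # linear scan per lookup
        new_list.append(name)
    new_set.add(name)          # constant time; the set dedups itself
assert set(new_list) == new_set == {"a", "b"}
```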
@@ -163,15 +166,16 @@ def connect(self, *args, **kwargs):
         connected_ports = {}
         for srcnode, destnode, connects in connection_list:
             if destnode not in connected_ports:
-                connected_ports[destnode] = []
+                connected_ports[destnode] = set()
             # check to see which ports of destnode are already
             # connected.
             if not disconnect and (destnode in self._graph.nodes()):
                 for edge in self._graph.in_edges(destnode):
                     data = self._graph.get_edge_data(*edge)
-                    for sourceinfo, destname in data["connect"]:
-                        if destname not in connected_ports[destnode]:
-                            connected_ports[destnode] += [destname]
+                    connected_ports[destnode].update(
+                        destname
+                        for _, destname in data["connect"]
+                    )
             for source, dest in connects:
                 # Currently datasource/sink/grabber.io modules
                 # determine their inputs/outputs depending on
@@ -226,7 +230,7 @@ def connect(self, *args, **kwargs):
                     )
                 if sourcename and not srcnode._check_outputs(sourcename):
                     not_found.append(["out", srcnode.name, sourcename])
-                connected_ports[destnode] += [dest]
+                connected_ports[destnode].add(dest)
         infostr = []
         for info in not_found:
             infostr += [
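Note: `connected_ports[destnode]` likewise becomes a `set`, so collecting the already-connected input ports is a single `update` over a generator instead of a membership-tested append, and the later bookkeeping is a plain `.add(dest)`. The unused `sourceinfo` binding is dropped in favor of `_`. On a toy edge payload (hypothetical shapes, not nipype data):

```python
data = {"connect": [(("n1", "out"), "in_a"),
                    (("n2", "out"), "in_b"),
                    (("n3", "out"), "in_a")]}  # duplicate destination port

ports = set()
ports.update(destname for _, destname in data["connect"])  # dedups "in_a"
ports.add("in_c")
assert ports == {"in_a", "in_b", "in_c"}
```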
@@ -269,6 +273,8 @@ def connect(self, *args, **kwargs):
                 "(%s, %s): new edge data: %s", srcnode, destnode, str(edge_data)
             )

+        self._update_node_cache()
+
     def disconnect(self, *args):
         """Disconnect nodes
         See the docstring for connect for format.
@@ -314,6 +320,8 @@ def disconnect(self, *args):
             else:
                 self._graph.add_edges_from([(srcnode, dstnode, edge_data)])

+        self._update_node_cache()
+
     def add_nodes(self, nodes):
         """ Add nodes to a workflow

@@ -346,6 +354,7 @@ def add_nodes(self, nodes):
             if node._hierarchy is None:
                 node._hierarchy = self.name
         self._graph.add_nodes_from(newnodes)
+        self._update_node_cache()

     def remove_nodes(self, nodes):
         """ Remove nodes from a workflow
@@ -356,6 +365,7 @@ def remove_nodes(self, nodes):
             A list of EngineBase-based objects
         """
         self._graph.remove_nodes_from(nodes)
+        self._update_node_cache()

     # Input-Output access
     @property
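Note: this hunk and the three before it apply one pattern — every public mutator of `self._graph` (`connect`, `disconnect`, `add_nodes`, `remove_nodes`) now ends with `self._update_node_cache()`, so readers never observe a stale cache. A toy reduction of the pattern (stand-in names, not nipype code):

```python
class ToyWorkflow:
    """Reduced mutate-then-sync pattern; the real class wraps an nx.DiGraph."""

    def __init__(self):
        self._graph = set()                  # stand-in for the node container
        self._nodes_cache = set()

    def _update_node_cache(self):
        self._nodes_cache = set(self._graph)

    def add_nodes(self, nodes):
        self._graph.update(nodes)
        self._update_node_cache()            # every mutator ends like this

wf = ToyWorkflow()
wf.add_nodes(["n1", "n2"])
assert wf._nodes_cache == {"n1", "n2"}
```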
@@ -903,21 +913,32 @@ def _set_node_input(self, node, param, source, sourceinfo):
         node.set_input(param, deepcopy(newval))

     def _get_all_nodes(self):
-        allnodes = []
-        for node in self._graph.nodes():
-            if isinstance(node, Workflow):
-                allnodes.extend(node._get_all_nodes())
-            else:
-                allnodes.append(node)
+        allnodes = [
+            *self._nodes_cache.difference(self._nested_workflows_cache)
+        ]  # all nodes that are not workflows
+        for node in self._nested_workflows_cache:
+            allnodes.extend(node._get_all_nodes())
         return allnodes

+    def _update_node_cache(self):
+        nodes = set(self._graph)
+
+        added_nodes = nodes.difference(self._nodes_cache)
+        removed_nodes = self._nodes_cache.difference(nodes)
+
+        self._nodes_cache = nodes
+        self._nested_workflows_cache.difference_update(removed_nodes)
+
+        for node in added_nodes:
+            if isinstance(node, Workflow):
+                self._nested_workflows_cache.add(node)
+
     def _has_node(self, wanted_node):
-        for node in self._graph.nodes():
-            if wanted_node == node:
+        if wanted_node in self._nodes_cache:
             return True
+        for node in self._nested_workflows_cache:
+            if node._has_node(wanted_node):
+                return True
-            if isinstance(node, Workflow):
-                if node._has_node(wanted_node):
-                    return True
         return False

     def _create_flat_graph(self):
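Note: `_update_node_cache` diffs the fresh node set against the previous cache, so only nodes added since the last sync pay the `isinstance` check, and `_has_node` collapses to one hash lookup plus recursion into nested workflows only. A self-contained mini-model of the scheme (strings stand in for nipype nodes; an assumed simplification, not the PR's code):

```python
class MiniWorkflow:
    """Toy model of the cached membership scheme; not nipype's Workflow."""

    def __init__(self):
        self.graph_nodes = set()             # stands in for self._graph
        self._nodes_cache = set()
        self._nested_workflows_cache = set()

    def _update_node_cache(self):
        nodes = set(self.graph_nodes)
        added = nodes - self._nodes_cache    # only new nodes are type-checked
        removed = self._nodes_cache - nodes
        self._nodes_cache = nodes
        self._nested_workflows_cache -= removed
        for node in added:
            if isinstance(node, MiniWorkflow):
                self._nested_workflows_cache.add(node)

    def _has_node(self, wanted):
        if wanted in self._nodes_cache:      # O(1) instead of a graph walk
            return True
        return any(wf._has_node(wanted) for wf in self._nested_workflows_cache)

inner, outer = MiniWorkflow(), MiniWorkflow()
inner.graph_nodes.add("n1")
inner._update_node_cache()
outer.graph_nodes.add(inner)
outer._update_node_cache()
assert outer._has_node("n1") and outer._has_node(inner) and not outer._has_node("n2")
```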
@@ -949,7 +970,7 @@ def _generate_flatgraph(self):
             raise Exception(
                 ("Workflow: %s is not a directed acyclic graph " "(DAG)") % self.name
             )
-        nodes = list(nx.topological_sort(self._graph))
+        nodes = list(self._graph.nodes)
         for node in nodes:
             logger.debug("processing node: %s", node)
             if isinstance(node, Workflow):
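Note (hedged): `nx.topological_sort` visits every edge as well as every node, while a plain node listing is linear in the node count. The swap assumes `_generate_flatgraph` does not depend on visiting nodes in dependency order; the DAG check above still rejects cyclic graphs either way. A quick sanity sketch:

```python
import networkx as nx

g = nx.DiGraph([("a", "b"), ("b", "c")])
# Same node set, different (and, for flattening, presumably irrelevant) order.
assert set(nx.topological_sort(g)) == set(g.nodes) == {"a", "b", "c"}
```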