@@ -168,17 +168,33 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
168168 auto connected_clusters = GetConnectedClusters (graph_viewer_, ng_clusters);
169169
170170 int no_of_clusters = 0 ;
171+ std::vector<NodeIndex> prev_cluster;
172+ bool try_next_cluster = false ;
171173
172174 for (auto this_cluster : connected_clusters) {
175+ bool omit_subgraph = false ;
176+ if (try_next_cluster) {
177+ // no need to check previous cluster
178+ for (auto idx : prev_cluster) {
179+ if ((std::find (this_cluster.begin (), this_cluster.end (), idx)) == this_cluster.end ()) {
180+ this_cluster.emplace_back (idx);
181+ }
182+ }
183+ try_next_cluster = false ;
184+ }
185+
173186 // If subgraph has less then three, graph is considered trivial unless its an epctx cluster
174- if (this_cluster.size () < 3 ) {
187+ if (!try_next_cluster && this_cluster.size () < 3 ) {
175188 bool is_epctx_node = false ;
176189 for (auto node_idx : this_cluster) {
177190 if (graph_viewer_.GetNode (node_idx)->OpType () == " EPContext" )
178191 is_epctx_node = true ;
179192 }
180- if (!is_epctx_node)
181- continue ;
193+ if (!is_epctx_node) {
194+ omit_subgraph = true ;
195+ prev_cluster = this_cluster;
196+ try_next_cluster = true ;
197+ }
182198 }
183199
184200 std::vector<std::string> cluster_graph_inputs, cluster_inputs, cluster_outputs;
@@ -190,7 +206,7 @@ std::vector<std::unique_ptr<ComputeCapability>> GetCapability::Execute() {
190206 cluster_inputs,
191207 cluster_outputs);
192208
193- bool omit_subgraph = false ;
209+
194210 // Omitting zero dim subgraphs
195211 for (auto index : this_cluster) {
196212 const Node* node = graph_viewer_.GetNode (index);
0 commit comments