Skip to content

Commit b7099cc

Browse files
committed
Pass ExecutorService to path-finding algorithms
1 parent 97298d0 commit b7099cc

File tree

21 files changed

+209
-154
lines changed

21 files changed

+209
-154
lines changed

algo/src/main/java/org/neo4j/gds/dag/longestPath/DagLongestPath.java

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
import java.util.Optional;
4444
import java.util.concurrent.ConcurrentHashMap;
4545
import java.util.concurrent.CountedCompleter;
46-
import java.util.concurrent.ForkJoinPool;
4746
import java.util.concurrent.ForkJoinTask;
4847
import java.util.concurrent.atomic.AtomicLong;
4948
import java.util.function.LongFunction;
@@ -115,35 +114,37 @@ private void initializeInDegrees() {
115114
private void traverse() {
116115
this.progressTracker.beginSubTask("Traversal");
117116

118-
ForkJoinPool forkJoinPool = ExecutorServiceUtil.createForkJoinPool(concurrency);
119-
var tasks = ConcurrentHashMap.<ForkJoinTask<Void>>newKeySet();
117+
try(var forkJoinPool = ExecutorServiceUtil.createForkJoinPool(concurrency)) {
118+
var tasks = ConcurrentHashMap.<ForkJoinTask<Void>>newKeySet();
120119

121-
LongFunction<CountedCompleter<Void>> taskProducer =
122-
(nodeId) -> new LongestPathTask(
123-
null,
124-
nodeId,
125-
graph.concurrentCopy(),
126-
inDegrees,
127-
parentsAndDistances
120+
LongFunction<CountedCompleter<Void>> taskProducer =
121+
(nodeId) -> new LongestPathTask(
122+
null,
123+
nodeId,
124+
graph.concurrentCopy(),
125+
inDegrees,
126+
parentsAndDistances
127+
);
128+
129+
ParallelUtil.parallelForEachNode(
130+
nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, nodeId -> {
131+
if (inDegrees.get(nodeId) == 0L) {
132+
tasks.add(taskProducer.apply(nodeId));
133+
parentsAndDistances.set(nodeId, nodeId, 0);
134+
}
135+
// Might not reach 100% if there are cycles in the graph
136+
progressTracker.logProgress();
137+
}
128138
);
129139

130-
ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, nodeId -> {
131-
if (inDegrees.get(nodeId) == 0L) {
132-
tasks.add(taskProducer.apply(nodeId));
133-
parentsAndDistances.set(nodeId, nodeId, 0);
140+
for (ForkJoinTask<Void> task : tasks) {
141+
forkJoinPool.submit(task);
134142
}
135-
// Might not reach 100% if there are cycles in the graph
136-
progressTracker.logProgress();
137-
});
138143

139-
for (ForkJoinTask<Void> task : tasks) {
140-
forkJoinPool.submit(task);
144+
// calling join makes sure the pool waits for all the tasks to complete before shutting down
145+
tasks.forEach(ForkJoinTask::join);
146+
this.progressTracker.endSubTask("Traversal");
141147
}
142-
143-
// calling join makes sure the pool waits for all the tasks to complete before shutting down
144-
tasks.forEach(ForkJoinTask::join);
145-
forkJoinPool.shutdown();
146-
this.progressTracker.endSubTask("Traversal");
147148
}
148149

149150
private static final class LongestPathTask extends CountedCompleter<Void> {

algo/src/main/java/org/neo4j/gds/dag/topologicalsort/TopologicalSort.java

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@
2727
import org.neo4j.gds.core.concurrency.Concurrency;
2828
import org.neo4j.gds.core.concurrency.ExecutorServiceUtil;
2929
import org.neo4j.gds.core.concurrency.ParallelUtil;
30-
import org.neo4j.gds.termination.TerminationFlag;
3130
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
3231
import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator;
3332
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
33+
import org.neo4j.gds.termination.TerminationFlag;
3434

3535
import java.util.Optional;
3636
import java.util.concurrent.ConcurrentHashMap;
3737
import java.util.concurrent.CountedCompleter;
38-
import java.util.concurrent.ForkJoinPool;
3938
import java.util.concurrent.ForkJoinTask;
4039
import java.util.function.LongFunction;
4140

@@ -117,43 +116,45 @@ private void initializeInDegrees() {
117116
private void traverse() {
118117
this.progressTracker.beginSubTask("Traversal");
119118

120-
ForkJoinPool forkJoinPool = ExecutorServiceUtil.createForkJoinPool(concurrency);
121-
var tasks = ConcurrentHashMap.<ForkJoinTask<Void>>newKeySet();
119+
try(var forkJoinPool = ExecutorServiceUtil.createForkJoinPool(concurrency)) {
120+
var tasks = ConcurrentHashMap.<ForkJoinTask<Void>>newKeySet();
121+
122+
var taskProducer =
123+
longestPathDistances
124+
.<LongFunction<CountedCompleter<Void>>>map(pathDistances -> (nodeId) -> new LongestPathTask(
125+
null,
126+
nodeId,
127+
graph.concurrentCopy(),
128+
result,
129+
inDegrees,
130+
pathDistances
131+
))
132+
.orElseGet(() -> (nodeId) -> new TraversalTask(
133+
null,
134+
nodeId,
135+
graph.concurrentCopy(),
136+
result,
137+
inDegrees
138+
));
122139

123-
LongFunction<CountedCompleter<Void>> taskProducer = longestPathDistances.isPresent()
124-
? (nodeId) -> new LongestPathTask(
125-
null,
126-
nodeId,
127-
graph.concurrentCopy(),
128-
result,
129-
inDegrees,
130-
longestPathDistances.get()
131-
)
132-
: (nodeId) -> new TraversalTask(
133-
null,
134-
nodeId,
135-
graph.concurrentCopy(),
136-
result,
137-
inDegrees
138-
);
140+
ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, nodeId -> {
141+
if (inDegrees.get(nodeId) == 0L) {
142+
result.addNode(nodeId);
143+
tasks.add(taskProducer.apply(nodeId));
144+
}
145+
// Might not reach 100% if there are cycles in the graph
146+
progressTracker.logProgress();
147+
});
139148

140-
ParallelUtil.parallelForEachNode(nodeCount, concurrency, TerminationFlag.RUNNING_TRUE, nodeId -> {
141-
if (inDegrees.get(nodeId) == 0L) {
142-
result.addNode(nodeId);
143-
tasks.add(taskProducer.apply(nodeId));
149+
for (ForkJoinTask<Void> task : tasks) {
150+
forkJoinPool.submit(task);
144151
}
145-
// Might not reach 100% if there are cycles in the graph
146-
progressTracker.logProgress();
147-
});
148152

149-
for (ForkJoinTask<Void> task : tasks) {
150-
forkJoinPool.submit(task);
153+
// calling join makes sure the pool waits for all the tasks to complete before shutting down
154+
tasks.forEach(ForkJoinTask::join);
155+
this.progressTracker.endSubTask("Traversal");
151156
}
152157

153-
// calling join makes sure the pool waits for all the tasks to complete before shutting down
154-
tasks.forEach(ForkJoinTask::join);
155-
forkJoinPool.shutdown();
156-
this.progressTracker.endSubTask("Traversal");
157158
}
158159

159160
private static final class TraversalTask extends CountedCompleter<Void> {

algo/src/main/java/org/neo4j/gds/paths/bellmanford/BellmanFord.java

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
import java.util.ArrayList;
4141
import java.util.Optional;
42+
import java.util.concurrent.ExecutorService;
4243
import java.util.concurrent.atomic.AtomicLong;
4344
import java.util.stream.LongStream;
4445
import java.util.stream.Stream;
@@ -51,35 +52,40 @@ public class BellmanFord extends Algorithm<BellmanFordResult> {
5152
private final boolean trackNegativeCycles;
5253
private final boolean trackPaths;
5354
private final Concurrency concurrency;
55+
private final ExecutorService executorService;
5456

5557
public BellmanFord(
5658
Graph graph,
5759
ProgressTracker progressTracker,
5860
long sourceNode,
5961
boolean trackNegativeCycles,
6062
boolean trackPaths,
61-
Concurrency concurrency
63+
Concurrency concurrency,
64+
ExecutorService executorService
6265
) {
6366
super(progressTracker);
6467
this.graph = graph;
6568
this.sourceNode = sourceNode;
6669
this.trackNegativeCycles = trackNegativeCycles;
6770
this.trackPaths = trackPaths;
6871
this.concurrency = concurrency;
72+
this.executorService = executorService;
6973
}
7074

7175
public BellmanFord(
7276
Graph graph,
7377
ProgressTracker progressTracker,
74-
BellmanFordParameters parameters
78+
BellmanFordParameters parameters,
79+
ExecutorService executorService
7580
) {
7681
this(
7782
graph,
7883
progressTracker,
7984
parameters.sourceNode(),
8085
parameters.trackNegativeCycles(),
8186
parameters.trackPaths(),
82-
parameters.concurrency()
87+
parameters.concurrency(),
88+
executorService
8389
);
8490
}
8591

@@ -123,13 +129,15 @@ public BellmanFordResult compute() {
123129
RunWithConcurrency.builder()
124130
.tasks(tasks)
125131
.concurrency(concurrency)
132+
.executor(executorService)
126133
.run();
127134
progressTracker.endSubTask();
128135
progressTracker.beginSubTask();
129136
frontierSize.set(0); // fill global queue again
130137
RunWithConcurrency.builder()
131138
.tasks(tasks)
132139
.concurrency(concurrency)
140+
.executor(executorService)
133141
.run();
134142
progressTracker.endSubTask();
135143
}

algo/src/main/java/org/neo4j/gds/paths/delta/DeltaStepping.java

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
import org.neo4j.gds.core.utils.partition.PartitionUtils;
3636
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3737
import org.neo4j.gds.paths.PathResult;
38-
import org.neo4j.gds.paths.delta.config.AllShortestPathsDeltaBaseConfig;
3938
import org.neo4j.gds.paths.dijkstra.PathFindingResult;
4039

4140
import java.util.Arrays;
@@ -65,24 +64,6 @@ public final class DeltaStepping extends Algorithm<PathFindingResult> {
6564

6665
private final ExecutorService executorService;
6766

68-
@Deprecated(forRemoval = true)
69-
public static DeltaStepping of(
70-
Graph graph,
71-
AllShortestPathsDeltaBaseConfig config,
72-
ExecutorService executorService,
73-
ProgressTracker progressTracker
74-
) {
75-
return new DeltaStepping(
76-
graph,
77-
graph.toMappedNodeId(config.sourceNode()),
78-
config.delta(),
79-
config.concurrency(),
80-
true,
81-
executorService,
82-
progressTracker
83-
);
84-
}
85-
8667
public static DeltaStepping of(
8768
Graph graph,
8869
DeltaSteppingParameters parameters,

algo/src/main/java/org/neo4j/gds/paths/traverse/BFS.java

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,20 @@
2222

2323
import org.neo4j.gds.Algorithm;
2424
import org.neo4j.gds.api.Graph;
25+
import org.neo4j.gds.collections.ha.HugeDoubleArray;
26+
import org.neo4j.gds.collections.ha.HugeLongArray;
2527
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
2628
import org.neo4j.gds.core.concurrency.Concurrency;
27-
import org.neo4j.gds.core.concurrency.DefaultPool;
2829
import org.neo4j.gds.core.concurrency.ParallelUtil;
2930
import org.neo4j.gds.core.utils.paged.HugeAtomicBitSet;
30-
import org.neo4j.gds.collections.ha.HugeDoubleArray;
31-
import org.neo4j.gds.collections.ha.HugeLongArray;
3231
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
3332
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3433
import org.neo4j.gds.termination.TerminationFlag;
3534

3635
import java.util.ArrayList;
3736
import java.util.Collection;
3837
import java.util.List;
38+
import java.util.concurrent.ExecutorService;
3939
import java.util.concurrent.atomic.AtomicLong;
4040

4141
/**
@@ -85,27 +85,30 @@ public final class BFS extends Algorithm<HugeLongArray> {
8585
// each node id in the `traversedNodes`.
8686
private final HugeAtomicBitSet visited;
8787

88+
private final ExecutorService executorService;
8889
private final Concurrency concurrency;
8990

9091
public static BFS create(
9192
Graph graph,
9293
long startNodeId,
9394
ExitPredicate exitPredicate,
9495
Aggregator aggregatorFunction,
96+
long maximumDepth,
97+
ExecutorService executorService,
9598
Concurrency concurrency,
9699
ProgressTracker progressTracker,
97-
long maximumDepth,
98100
TerminationFlag terminationFlag
99101
) {
100102
return create(
101103
graph,
102104
startNodeId,
103105
exitPredicate,
104106
aggregatorFunction,
105-
concurrency,
106-
progressTracker,
107107
DEFAULT_DELTA,
108108
maximumDepth,
109+
executorService,
110+
concurrency,
111+
progressTracker,
109112
terminationFlag
110113
);
111114
}
@@ -115,10 +118,11 @@ static BFS create(
115118
long startNodeId,
116119
ExitPredicate exitPredicate,
117120
Aggregator aggregatorFunction,
118-
Concurrency concurrency,
119-
ProgressTracker progressTracker,
120121
int delta,
121122
long maximumDepth,
123+
ExecutorService executorService,
124+
Concurrency concurrency,
125+
ProgressTracker progressTracker,
122126
TerminationFlag terminationFlag
123127
) {
124128

@@ -136,10 +140,11 @@ static BFS create(
136140
visited,
137141
exitPredicate,
138142
aggregatorFunction,
139-
concurrency,
140-
progressTracker,
141143
delta,
142144
maximumDepth,
145+
executorService,
146+
concurrency,
147+
progressTracker,
143148
terminationFlag
144149
);
145150
}
@@ -152,17 +157,19 @@ private BFS(
152157
HugeAtomicBitSet visited,
153158
ExitPredicate exitPredicate,
154159
Aggregator aggregatorFunction,
155-
Concurrency concurrency,
156-
ProgressTracker progressTracker,
157160
int delta,
158161
long maximumDepth,
162+
ExecutorService executorService,
163+
Concurrency concurrency,
164+
ProgressTracker progressTracker,
159165
TerminationFlag terminationFlag
160166
) {
161167
super(progressTracker);
162168
this.graph = graph;
163169
this.sourceNodeId = sourceNodeId;
164170
this.exitPredicate = exitPredicate;
165171
this.aggregatorFunction = aggregatorFunction;
172+
this.executorService = executorService;
166173
this.concurrency = concurrency;
167174
this.delta = delta;
168175
this.maximumDepth = maximumDepth;
@@ -207,7 +214,7 @@ public HugeLongArray compute() {
207214
if (currentDepth == maximumDepth) {
208215
break;
209216
}
210-
ParallelUtil.run(bfsTaskList, DefaultPool.INSTANCE);
217+
ParallelUtil.run(bfsTaskList, executorService);
211218

212219
if (targetFoundIndex.get() != Long.MAX_VALUE) {
213220
break;

algo/src/main/java/org/neo4j/gds/spanningtree/Prim.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
* weights in the MST.
4242
*/
4343
public class Prim extends Algorithm<SpanningTree> {
44-
private final int EMPTY=-1;
44+
private static final int EMPTY = -1;
4545
private final Graph graph;
4646
private final DoubleUnaryOperator minMax;
4747
private final long startNodeId;

0 commit comments

Comments
 (0)