Skip to content

Commit a78a5e1

Browse files
committed
Almost purge RW config from algo code
1 parent e19a9aa commit a78a5e1

File tree

4 files changed

+120
-138
lines changed

4 files changed

+120
-138
lines changed

algo/src/main/java/org/neo4j/gds/traversal/RandomWalk.java

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,18 @@ public final class RandomWalk extends Algorithm<Stream<long[]>> {
5454

5555
public static RandomWalk create(
5656
Graph graph,
57-
RandomWalkBaseConfig config,
57+
int concurrency,
5858
WalkParameters walkParameters,
59+
List<Long> sourceNodes,
60+
int walkBufferSize,
61+
Optional<Long> randomSeed,
5962
ProgressTracker progressTracker,
6063
ExecutorService executorService
6164
) {
6265
if (graph.hasRelationshipProperty()) {
6366
EmbeddingUtils.validateRelationshipWeightPropertyValue(
6467
graph,
65-
config.concurrency(),
68+
concurrency,
6669
weight -> weight >= 0,
6770
"RandomWalk only supports non-negative weights.",
6871
executorService
@@ -71,15 +74,12 @@ public static RandomWalk create(
7174

7275
return new RandomWalk(
7376
graph,
74-
config.concurrency(),
77+
concurrency,
7578
executorService,
76-
config.walkBufferSize(),
77-
walkParameters.walksPerNode,
78-
walkParameters.walkLength,
79-
walkParameters.returnFactor,
80-
walkParameters.inOutFactor,
81-
config.sourceNodes(),
82-
config.randomSeed(),
79+
walkParameters,
80+
sourceNodes,
81+
walkBufferSize,
82+
randomSeed,
8383
progressTracker
8484
);
8585
}
@@ -88,12 +88,9 @@ private RandomWalk(
8888
Graph graph,
8989
int concurrency,
9090
ExecutorService executorService,
91-
int walkBufferSize,
92-
int walksPerNode,
93-
int walkLength,
94-
double returnFactor,
95-
double inOutFactor,
91+
WalkParameters walkParameters,
9692
List<Long> sourceNodes,
93+
int walkBufferSize,
9794
Optional<Long> maybeRandomSeed,
9895
ProgressTracker progressTracker
9996
) {
@@ -115,10 +112,7 @@ private RandomWalk(
115112
nextNodeSupplier,
116113
cumulativeWeightSupplier,
117114
walks,
118-
walksPerNode,
119-
walkLength,
120-
returnFactor,
121-
inOutFactor,
115+
walkParameters,
122116
randomSeed,
123117
progressTracker,
124118
externalTerminationFlag

algo/src/main/java/org/neo4j/gds/traversal/RandomWalkAlgorithmFactory.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,16 @@ public RandomWalk build(
4545
RandomWalkBaseConfig configuration,
4646
ProgressTracker progressTracker
4747
) {
48-
return RandomWalk.create(graph, configuration, configuration.walkParameters(), progressTracker, DefaultPool.INSTANCE);
48+
return RandomWalk.create(
49+
graph,
50+
configuration.concurrency(),
51+
configuration.walkParameters(),
52+
configuration.sourceNodes(),
53+
configuration.walkBufferSize(),
54+
configuration.randomSeed(),
55+
progressTracker,
56+
DefaultPool.INSTANCE
57+
);
4958
}
5059

5160
@Override

algo/src/main/java/org/neo4j/gds/traversal/RandomWalkTaskSupplier.java

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,7 @@ class RandomWalkTaskSupplier implements Supplier<RandomWalkTask> {
3232
private final NextNodeSupplier nextNodeSupplier;
3333
private final RandomWalkSampler.CumulativeWeightSupplier cumulativeWeightSupplier;
3434
private final BlockingQueue<long[]> walks;
35-
private final int walksPerNode;
36-
private final int walkLength;
37-
private final double returnFactor;
38-
private final double inOutFactor;
35+
private final WalkParameters walkParameters;
3936
private final long randomSeed;
4037
private final ProgressTracker progressTracker;
4138
private final TerminationFlag terminationFlag;
@@ -45,10 +42,7 @@ class RandomWalkTaskSupplier implements Supplier<RandomWalkTask> {
4542
NextNodeSupplier nextNodeSupplier,
4643
RandomWalkSampler.CumulativeWeightSupplier cumulativeWeightSupplier,
4744
BlockingQueue<long[]> walks,
48-
int walksPerNode,
49-
int walkLength,
50-
double returnFactor,
51-
double inOutFactor,
45+
WalkParameters walkParameters,
5246
long randomSeed,
5347
ProgressTracker progressTracker,
5448
TerminationFlag terminationFlag
@@ -57,10 +51,7 @@ class RandomWalkTaskSupplier implements Supplier<RandomWalkTask> {
5751
this.nextNodeSupplier = nextNodeSupplier;
5852
this.cumulativeWeightSupplier = cumulativeWeightSupplier;
5953
this.walks = walks;
60-
this.walksPerNode = walksPerNode;
61-
this.walkLength = walkLength;
62-
this.returnFactor = returnFactor;
63-
this.inOutFactor = inOutFactor;
54+
this.walkParameters = walkParameters;
6455
this.randomSeed = randomSeed;
6556
this.progressTracker = progressTracker;
6657
this.terminationFlag = terminationFlag;
@@ -73,10 +64,10 @@ public RandomWalkTask get() {
7364
nextNodeSupplier,
7465
cumulativeWeightSupplier,
7566
walks,
76-
walksPerNode,
77-
walkLength,
78-
returnFactor,
79-
inOutFactor,
67+
walkParameters.walksPerNode,
68+
walkParameters.walkLength,
69+
walkParameters.returnFactor,
70+
walkParameters.inOutFactor,
8071
randomSeed,
8172
progressTracker,
8273
terminationFlag

0 commit comments

Comments
 (0)