Skip to content

Commit 9cb8785

Browse files
committed
Node2Vec also gets walkBufferSize
1 parent 6b22038 commit 9cb8785

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ public class Node2Vec extends Algorithm<Node2VecModel.Result> {
4747
private final List<Long> sourceNodes;
4848
private final Optional<Long> maybeRandomSeed;
4949
private final TrainParameters trainParameters;
50+
private final int walkBufferSize;
5051

5152

5253
public static MemoryEstimation memoryEstimation(int walksPerNode, int walkLength, int embeddingDimension) {
@@ -84,6 +85,7 @@ static Node2Vec create(
8485
concurrency,
8586
List.of(),
8687
maybeRandomSeed,
88+
1000,
8789
walkParameters,
8890
trainParameters,
8991
progressTracker
@@ -95,6 +97,7 @@ public Node2Vec(
9597
int concurrency,
9698
List<Long> sourceNodes,
9799
Optional<Long> maybeRandomSeed,
100+
int walkBufferSize,
98101
WalkParameters walkParameters,
99102
TrainParameters trainParameters,
100103
ProgressTracker progressTracker
@@ -103,6 +106,7 @@ public Node2Vec(
103106
this.graph = graph;
104107
this.concurrency = concurrency;
105108
this.walkParameters = walkParameters;
109+
this.walkBufferSize = walkBufferSize;
106110
this.sourceNodes = sourceNodes;
107111
this.maybeRandomSeed = maybeRandomSeed;
108112
this.trainParameters = trainParameters;
@@ -140,6 +144,7 @@ public Node2VecModel.Result compute() {
140144
concurrency,
141145
sourceNodes,
142146
walkParameters,
147+
walkBufferSize,
143148
DefaultPool.INSTANCE,
144149
progressTracker,
145150
terminationFlag
@@ -185,6 +190,7 @@ private List<Node2VecRandomWalkTask> walkTasks(
185190
int concurrency,
186191
List<Long> sourceNodes,
187192
WalkParameters walkParameters,
193+
int walkBufferSize,
188194
ExecutorService executorService,
189195
ProgressTracker progressTracker,
190196
TerminationFlag terminationFlag
@@ -211,6 +217,7 @@ private List<Node2VecRandomWalkTask> walkTasks(
211217
index,
212218
compressedRandomWalks,
213219
randomWalkPropabilitiesBuilder,
220+
walkBufferSize,
214221
randomSeed,
215222
walkParameters.walkLength,
216223
walkParameters.returnFactor,

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactory.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ public Node2Vec build(
5252
configuration.concurrency(),
5353
configuration.sourceNodes(),
5454
configuration.randomSeed(),
55+
configuration.walkBufferSize(),
5556
configuration.walkParameters(),
5657
configuration.trainParameters(),
5758
progressTracker

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecRandomWalkTask.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ final class Node2VecRandomWalkTask implements Runnable {
3838
private final CompressedRandomWalks compressedRandomWalks;
3939
private final RandomWalkProbabilities.Builder randomWalkProbabilitiesBuilder;
4040
private final RandomWalkSampler sampler;
41+
private final int walkBufferSize;
4142
private int walks;
4243
private int maxWalkLength;
4344
private long maxIndex;
@@ -52,6 +53,7 @@ final class Node2VecRandomWalkTask implements Runnable {
5253
AtomicLong walkIndex,
5354
CompressedRandomWalks compressedRandomWalks,
5455
RandomWalkProbabilities.Builder randomWalkProbabilitiesBuilder,
56+
int walkBufferSize,
5557
long randomSeed,
5658
int walkLength,
5759
double returnFactor,
@@ -65,6 +67,7 @@ final class Node2VecRandomWalkTask implements Runnable {
6567
this.walkIndex = walkIndex;
6668
this.compressedRandomWalks = compressedRandomWalks;
6769
this.randomWalkProbabilitiesBuilder = randomWalkProbabilitiesBuilder;
70+
this.walkBufferSize = walkBufferSize;
6871

6972
this.sampler = RandomWalkSampler.create(
7073
graph,
@@ -85,7 +88,7 @@ private boolean consumePath(long[] path) {
8588
randomWalkProbabilitiesBuilder.registerWalk(path);
8689
compressedRandomWalks.add(index, path);
8790
maxWalkLength = Math.max(path.length, maxWalkLength);
88-
if (walks++ == 1000) { //this is just to get the same
91+
if (walks++ == walkBufferSize) {
8992
walks = 0;
9093
return this.terminationFlag.running();
9194
}

0 commit comments

Comments
 (0)