Skip to content

Commit 8e45487

Browse files
committed
Use WalkPArameters param object in Node2Vec
1 parent a78a5e1 commit 8e45487

File tree

2 files changed

+16
-33
lines changed

2 files changed

+16
-33
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.neo4j.gds.mem.MemoryUsage;
3232
import org.neo4j.gds.ml.core.EmbeddingUtils;
3333
import org.neo4j.gds.traversal.RandomWalkCompanion;
34+
import org.neo4j.gds.traversal.WalkParameters;
3435

3536
import java.util.ArrayList;
3637
import java.util.List;
@@ -45,10 +46,7 @@ public class Node2Vec extends Algorithm<Node2VecModel.Result> {
4546
private final double positiveSamplingFactor;
4647
private final double negativeSamplingExponent;
4748
private final int concurrency;
48-
private final int walksPerNode;
49-
private final int walkLength;
50-
private final double inOutFactor;
51-
private final double returnFactor;
49+
private final WalkParameters walkParameters;
5250
private final List<Long> sourceNodes;
5351
private final Optional<Long> maybeRandomSeed;
5452
private final double initialLearningRate;
@@ -78,10 +76,7 @@ static Node2Vec create(Graph graph, Node2VecBaseConfig config, ProgressTracker p
7876
config.concurrency(),
7977
config.positiveSamplingFactor(),
8078
config.negativeSamplingExponent(),
81-
config.walksPerNode(),
82-
config.walkLength(),
83-
config.inOutFactor(),
84-
config.returnFactor(),
79+
config.walkParameters(),
8580
config.sourceNodes(),
8681
config.randomSeed(),
8782
progressTracker,
@@ -100,10 +95,7 @@ public Node2Vec(
10095
int concurrency,
10196
double positiveSamplingFactor,
10297
double negativeSamplingExponent,
103-
int walksPerNode,
104-
int walkLength,
105-
double inOutFactor,
106-
double returnFactor,
98+
WalkParameters walkParameters,
10799
List<Long> sourceNodes,
108100
Optional<Long> maybeRandomSeed,
109101
ProgressTracker progressTracker,
@@ -122,10 +114,7 @@ public Node2Vec(
122114
this.positiveSamplingFactor = positiveSamplingFactor;
123115
this.negativeSamplingExponent = negativeSamplingExponent;
124116
this.concurrency = concurrency;
125-
this.walksPerNode = walksPerNode;
126-
this.walkLength = walkLength;
127-
this.inOutFactor = inOutFactor;
128-
this.returnFactor = returnFactor;
117+
this.walkParameters = walkParameters;
129118
this.sourceNodes = sourceNodes;
130119
this.maybeRandomSeed = maybeRandomSeed;
131120
this.initialLearningRate = initialLearningRate;
@@ -157,7 +146,7 @@ public Node2VecModel.Result compute() {
157146
negativeSamplingExponent,
158147
concurrency
159148
);
160-
var walks = new CompressedRandomWalks(graph.nodeCount() * walksPerNode);
149+
var walks = new CompressedRandomWalks(graph.nodeCount() * walkParameters.walksPerNode);
161150

162151
progressTracker.beginSubTask("RandomWalk");
163152

@@ -168,10 +157,7 @@ public Node2VecModel.Result compute() {
168157
maybeRandomSeed,
169158
concurrency,
170159
sourceNodes,
171-
walksPerNode,
172-
walkLength,
173-
inOutFactor,
174-
returnFactor,
160+
walkParameters,
175161
DefaultPool.INSTANCE,
176162
progressTracker,
177163
terminationFlag
@@ -222,10 +208,7 @@ private List<Node2VecRandomWalkTask> tasks(
222208
Optional<Long> maybeRandomSeed,
223209
int concurrency,
224210
List<Long> sourceNodes,
225-
int walksPerNode,
226-
int walkLength,
227-
double inOutFactor,
228-
double returnFactor,
211+
WalkParameters walkParameters,
229212
ExecutorService executorService,
230213
ProgressTracker progressTracker,
231214
TerminationFlag terminationFlag
@@ -245,17 +228,17 @@ private List<Node2VecRandomWalkTask> tasks(
245228
tasks.add(new Node2VecRandomWalkTask(
246229
graph.concurrentCopy(),
247230
nextNodeSupplier,
248-
walksPerNode,
231+
walkParameters.walksPerNode,
249232
cumulativeWeightsSupplier,
250233
progressTracker,
251234
terminationFlag,
252235
index,
253236
compressedRandomWalks,
254237
randomWalkPropabilitiesBuilder,
255238
randomSeed,
256-
walkLength,
257-
returnFactor,
258-
inOutFactor
239+
walkParameters.walkLength,
240+
walkParameters.returnFactor,
241+
walkParameters.inOutFactor
259242
));
260243
}
261244
return tasks;

algo/src/main/java/org/neo4j/gds/traversal/WalkParameters.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
* Parameter object holding Random Walk parameters.
2424
*/
2525
public class WalkParameters {
26-
final int walksPerNode;
27-
final int walkLength;
28-
final double returnFactor;
29-
final double inOutFactor;
26+
public final int walksPerNode;
27+
public final int walkLength;
28+
public final double returnFactor;
29+
public final double inOutFactor;
3030

3131
public WalkParameters(int walksPerNode, int walkLength, double returnFactor, double inOutFactor) {
3232
this.walksPerNode = walksPerNode;

0 commit comments

Comments
 (0)