Skip to content

Commit 6941b24

Browse files
committed
Embedding initializer is a train parameter
1 parent d54e329 commit 6941b24

File tree

4 files changed

+9
-11
lines changed

4 files changed

+9
-11
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ public class Node2Vec extends Algorithm<Node2VecModel.Result> {
4747
private final List<Long> sourceNodes;
4848
private final Optional<Long> maybeRandomSeed;
4949
private final TrainParameters trainParameters;
50-
private final EmbeddingInitializer embeddingInitializer;
5150

5251

5352
public static MemoryEstimation memoryEstimation(int walksPerNode, int walkLength, int embeddingDimension) {
@@ -70,8 +69,7 @@ static Node2Vec create(Graph graph, Node2VecBaseConfig config, ProgressTracker p
7069
config.sourceNodes(),
7170
config.randomSeed(),
7271
progressTracker,
73-
config.trainParameters(),
74-
config.embeddingInitializer()
72+
config.trainParameters()
7573
);
7674
}
7775

@@ -82,8 +80,7 @@ public Node2Vec(
8280
List<Long> sourceNodes,
8381
Optional<Long> maybeRandomSeed,
8482
ProgressTracker progressTracker,
85-
TrainParameters trainParameters,
86-
EmbeddingInitializer embeddingInitializer
83+
TrainParameters trainParameters
8784
) {
8885
super(progressTracker);
8986
this.graph = graph;
@@ -92,7 +89,6 @@ public Node2Vec(
9289
this.sourceNodes = sourceNodes;
9390
this.maybeRandomSeed = maybeRandomSeed;
9491
this.trainParameters = trainParameters;
95-
this.embeddingInitializer = embeddingInitializer;
9692
}
9793

9894
@Override
@@ -151,7 +147,6 @@ public Node2VecModel.Result compute() {
151147
graph::toOriginalNodeId,
152148
graph.nodeCount(),
153149
trainParameters,
154-
embeddingInitializer,
155150
concurrency,
156151
maybeRandomSeed,
157152
walks,

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecBaseConfig.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ default TrainParameters trainParameters() {
114114
iterations(),
115115
windowSize(),
116116
negativeSamplingRate(),
117-
embeddingDimension()
117+
embeddingDimension(),
118+
embeddingInitializer()
118119
);
119120
}
120121
}

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ public static MemoryEstimation memoryEstimation(int embeddingDimension) {
8080
LongUnaryOperator toOriginalId,
8181
long nodeCount,
8282
TrainParameters trainParameters,
83-
EmbeddingInitializer embeddingInitializer,
8483
int concurrency,
8584
Optional<Long> maybeRandomSeed,
8685
CompressedRandomWalks walks,
@@ -96,7 +95,7 @@ public static MemoryEstimation memoryEstimation(int embeddingDimension) {
9695
trainParameters.windowSize,
9796
trainParameters.negativeSamplingRate,
9897
trainParameters.embeddingDimension,
99-
embeddingInitializer,
98+
trainParameters.embeddingInitializer,
10099
concurrency,
101100
maybeRandomSeed,
102101
walks,

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/TrainParameters.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,23 @@ public class TrainParameters {
2626
final int windowSize;
2727
final int negativeSamplingRate;
2828
final int embeddingDimension;
29+
final EmbeddingInitializer embeddingInitializer;
2930

3031
TrainParameters(
3132
double initialLearningRate,
3233
double minLearningRate,
3334
int iterations,
3435
int windowSize,
3536
int negativeSamplingRate,
36-
int embeddingDimension
37+
int embeddingDimension,
38+
EmbeddingInitializer embeddingInitializer
3739
) {
3840
this.initialLearningRate = initialLearningRate;
3941
this.minLearningRate = minLearningRate;
4042
this.iterations = iterations;
4143
this.windowSize = windowSize;
4244
this.negativeSamplingRate = negativeSamplingRate;
4345
this.embeddingDimension = embeddingDimension;
46+
this.embeddingInitializer = embeddingInitializer;
4447
}
4548
}

0 commit comments

Comments
 (0)