Skip to content

Commit b9edfe5

Browse files
committed
Refine parameter objects for Node2Vec
1 parent 8e45487 commit b9edfe5

File tree

8 files changed

+150
-55
lines changed

8 files changed

+150
-55
lines changed

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java

Lines changed: 10 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@
3131
import org.neo4j.gds.mem.MemoryUsage;
3232
import org.neo4j.gds.ml.core.EmbeddingUtils;
3333
import org.neo4j.gds.traversal.RandomWalkCompanion;
34-
import org.neo4j.gds.traversal.WalkParameters;
3534

3635
import java.util.ArrayList;
3736
import java.util.List;
@@ -43,18 +42,11 @@
4342
public class Node2Vec extends Algorithm<Node2VecModel.Result> {
4443

4544
private final Graph graph;
46-
private final double positiveSamplingFactor;
47-
private final double negativeSamplingExponent;
4845
private final int concurrency;
4946
private final WalkParameters walkParameters;
5047
private final List<Long> sourceNodes;
5148
private final Optional<Long> maybeRandomSeed;
52-
private final double initialLearningRate;
53-
private final double minLearningRate;
54-
private final int iterations;
55-
private final int embeddingDimension;
56-
private final int windowSize;
57-
private final int negativeSamplingRate;
49+
private final TrainParameters trainParameters;
5850
private final Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer;
5951

6052

@@ -74,55 +66,32 @@ static Node2Vec create(Graph graph, Node2VecBaseConfig config, ProgressTracker p
7466
return new Node2Vec(
7567
graph,
7668
config.concurrency(),
77-
config.positiveSamplingFactor(),
78-
config.negativeSamplingExponent(),
7969
config.walkParameters(),
8070
config.sourceNodes(),
8171
config.randomSeed(),
8272
progressTracker,
83-
config.initialLearningRate(),
84-
config.minLearningRate(),
85-
config.iterations(),
86-
config.embeddingDimension(),
87-
config.windowSize(),
88-
config.negativeSamplingRate(),
73+
config.trainParameters(),
8974
config.embeddingInitializer()
9075
);
9176
}
9277

9378
public Node2Vec(
9479
Graph graph,
9580
int concurrency,
96-
double positiveSamplingFactor,
97-
double negativeSamplingExponent,
9881
WalkParameters walkParameters,
9982
List<Long> sourceNodes,
10083
Optional<Long> maybeRandomSeed,
10184
ProgressTracker progressTracker,
102-
// train params
103-
double initialLearningRate,
104-
double minLearningRate,
105-
int iterations,
106-
int embeddingDimension,
107-
int windowSize,
108-
int negativeSamplingRate,
85+
TrainParameters trainParameters,
10986
Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer
110-
11187
) {
11288
super(progressTracker);
11389
this.graph = graph;
114-
this.positiveSamplingFactor = positiveSamplingFactor;
115-
this.negativeSamplingExponent = negativeSamplingExponent;
11690
this.concurrency = concurrency;
11791
this.walkParameters = walkParameters;
11892
this.sourceNodes = sourceNodes;
11993
this.maybeRandomSeed = maybeRandomSeed;
120-
this.initialLearningRate = initialLearningRate;
121-
this.minLearningRate = minLearningRate;
122-
this.iterations = iterations;
123-
this.embeddingDimension = embeddingDimension;
124-
this.windowSize = windowSize;
125-
this.negativeSamplingRate = negativeSamplingRate;
94+
this.trainParameters = trainParameters;
12695
this.embeddingInitializer = embeddingInitializer;
12796
}
12897

@@ -142,15 +111,15 @@ public Node2VecModel.Result compute() {
142111

143112
var probabilitiesBuilder = new RandomWalkProbabilities.Builder(
144113
graph.nodeCount(),
145-
positiveSamplingFactor,
146-
negativeSamplingExponent,
147-
concurrency
114+
concurrency,
115+
walkParameters.positiveSamplingFactor,
116+
walkParameters.negativeSamplingExponent
148117
);
149118
var walks = new CompressedRandomWalks(graph.nodeCount() * walkParameters.walksPerNode);
150119

151120
progressTracker.beginSubTask("RandomWalk");
152121

153-
var tasks = tasks(
122+
var tasks = walkTasks(
154123
walks,
155124
probabilitiesBuilder,
156125
graph,
@@ -181,12 +150,7 @@ public Node2VecModel.Result compute() {
181150
var node2VecModel = new Node2VecModel(
182151
graph::toOriginalNodeId,
183152
graph.nodeCount(),
184-
initialLearningRate,
185-
minLearningRate,
186-
iterations,
187-
embeddingDimension,
188-
windowSize,
189-
negativeSamplingRate,
153+
trainParameters,
190154
embeddingInitializer,
191155
concurrency,
192156
maybeRandomSeed,
@@ -201,7 +165,7 @@ public Node2VecModel.Result compute() {
201165
return result;
202166
}
203167

204-
private List<Node2VecRandomWalkTask> tasks(
168+
private List<Node2VecRandomWalkTask> walkTasks(
205169
CompressedRandomWalks compressedRandomWalks,
206170
RandomWalkProbabilities.Builder randomWalkPropabilitiesBuilder,
207171
Graph graph,

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecBaseConfig.java

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,4 +132,30 @@ default int iterations() {
132132
default List<Long> sourceNodes() {
133133
return List.of();
134134
}
135+
136+
/**
 * Collects this configuration's walk settings into one {@code WalkParameters} object.
 *
 * Arguments are passed positionally in the order walksPerNode, walkLength, returnFactor,
 * inOutFactor, positiveSamplingFactor, negativeSamplingExponent. That order matches the
 * six-argument constructor of the node2vec {@code WalkParameters} class added in this commit.
 * {@code Node2Vec.create(...)} forwards the result to the {@code Node2Vec} constructor.
 *
 * NOTE(review): presumably {@code @Configuration.Ignore} keeps this method from being treated
 * as a user-supplied configuration key and {@code @Value.Derived} marks it as computed from
 * the other attributes - confirm against the annotation documentation.
 *
 * @return a new {@code WalkParameters} built from the accessors of this configuration
 */
@Configuration.Ignore
137+
@Value.Derived
138+
default WalkParameters walkParameters() {
139+
return new WalkParameters(
140+
walksPerNode(),
141+
walkLength(),
142+
returnFactor(),
143+
inOutFactor(),
144+
// The two sampling values are both doubles; keep positive before negative to match the constructor.
positiveSamplingFactor(),
145+
negativeSamplingExponent()
146+
);
147+
}
148+
149+
/**
 * Collects this configuration's training settings into one {@code TrainParameters} object.
 *
 * Arguments are passed positionally in the order initialLearningRate, minLearningRate,
 * iterations, windowSize, negativeSamplingRate, embeddingDimension, which is the parameter
 * order of the {@code TrainParameters} constructor added in this commit.
 * {@code Node2Vec.create(...)} forwards the result to the {@code Node2Vec} constructor.
 *
 * NOTE(review): presumably {@code @Configuration.Ignore} keeps this method from being treated
 * as a user-supplied configuration key and {@code @Value.Derived} marks it as computed from
 * the other attributes - confirm against the annotation documentation.
 *
 * @return a new {@code TrainParameters} built from the accessors of this configuration
 */
@Configuration.Ignore
150+
@Value.Derived
151+
default TrainParameters trainParameters() {
152+
return new TrainParameters(
153+
initialLearningRate(),
154+
minLearningRate(),
155+
// The remaining four arguments are all ints, so the compiler cannot detect a swapped order.
iterations(),
156+
windowSize(),
157+
negativeSamplingRate(),
158+
embeddingDimension()
159+
);
160+
}
135161
}

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 30 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,15 +76,44 @@ public static MemoryEstimation memoryEstimation(int embeddingDimension) {
7676
.build();
7777
}
7878

79+
/**
 * Convenience constructor that accepts the training settings as a single
 * {@code TrainParameters} object and delegates to the expanded constructor,
 * reading each value from the corresponding field of {@code trainParameters}.
 * All other arguments are forwarded unchanged.
 *
 * @param toOriginalId            forwarded unchanged to the expanded constructor
 * @param nodeCount               forwarded unchanged to the expanded constructor
 * @param trainParameters         source of the six training values unpacked below
 * @param embeddingInitializer    forwarded unchanged to the expanded constructor
 * @param concurrency             forwarded unchanged to the expanded constructor
 * @param maybeRandomSeed         forwarded unchanged to the expanded constructor
 * @param walks                   forwarded unchanged to the expanded constructor
 * @param randomWalkProbabilities forwarded unchanged to the expanded constructor
 * @param progressTracker         forwarded unchanged to the expanded constructor
 */
Node2VecModel(
80+
LongUnaryOperator toOriginalId,
81+
long nodeCount,
82+
TrainParameters trainParameters,
83+
Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer,
84+
int concurrency,
85+
Optional<Long> maybeRandomSeed,
86+
CompressedRandomWalks walks,
87+
RandomWalkProbabilities randomWalkProbabilities,
88+
ProgressTracker progressTracker
89+
) {
90+
// The expanded constructor takes (..., iterations, windowSize, negativeSamplingRate, embeddingDimension, ...):
// this commit moved embeddingDimension behind negativeSamplingRate there. All four are ints, so the
// order below must be kept in sync by hand.
this(
91+
toOriginalId,
92+
nodeCount,
93+
trainParameters.initialLearningRate,
94+
trainParameters.minLearningRate,
95+
trainParameters.iterations,
96+
trainParameters.windowSize,
97+
trainParameters.negativeSamplingRate,
98+
trainParameters.embeddingDimension,
99+
embeddingInitializer,
100+
concurrency,
101+
maybeRandomSeed,
102+
walks,
103+
randomWalkProbabilities,
104+
progressTracker
105+
);
106+
}
107+
79108
Node2VecModel(
80109
LongUnaryOperator toOriginalId,
81110
long nodeCount,
82111
double initialLearningRate,
83112
double minLearningRate,
84113
int iterations,
85-
int embeddingDimension,
86114
int windowSize,
87115
int negativeSamplingRate,
116+
int embeddingDimension,
88117
Node2VecBaseConfig.EmbeddingInitializer embeddingInitializer,
89118
int concurrency,
90119
Optional<Long> maybeRandomSeed,

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/RandomWalkProbabilities.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,9 +62,9 @@ class Builder {
6262

6363
Builder(
6464
long nodeCount,
65+
int concurrency,
6566
double positiveSamplingFactor,
66-
double negativeSamplingExponent,
67-
int concurrency
67+
double negativeSamplingExponent
6868
) {
6969
this.nodeCount = nodeCount;
7070
this.concurrency = concurrency;
algo/src/main/java/org/neo4j/gds/embeddings/node2vec/TrainParameters.java

Lines changed: 45 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.node2vec;
21+
22+
/**
 * Immutable holder for the six training settings of Node2Vec.
 *
 * Instances are created by {@code Node2VecBaseConfig.trainParameters()}, carried by
 * {@code Node2Vec}, and unpacked field by field in the {@code Node2VecModel} constructor.
 * The class is public, but its fields and constructor have no access modifier, so they
 * are only reachable from within the {@code org.neo4j.gds.embeddings.node2vec} package.
 */
public class TrainParameters {
23+
// All fields are final and assigned exactly once in the constructor below.
final double initialLearningRate;
24+
final double minLearningRate;
25+
final int iterations;
26+
final int windowSize;
27+
final int negativeSamplingRate;
28+
final int embeddingDimension;
29+
30+
/**
 * Stores each argument in the field of the same name; no validation is performed here.
 * The four int parameters are positional and type-identical, so callers must take care
 * to keep the order iterations, windowSize, negativeSamplingRate, embeddingDimension.
 */
TrainParameters(
31+
double initialLearningRate,
32+
double minLearningRate,
33+
int iterations,
34+
int windowSize,
35+
int negativeSamplingRate,
36+
int embeddingDimension
37+
) {
38+
this.initialLearningRate = initialLearningRate;
39+
this.minLearningRate = minLearningRate;
40+
this.iterations = iterations;
41+
this.windowSize = windowSize;
42+
this.negativeSamplingRate = negativeSamplingRate;
43+
this.embeddingDimension = embeddingDimension;
44+
}
45+
}
algo/src/main/java/org/neo4j/gds/embeddings/node2vec/WalkParameters.java

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.embeddings.node2vec;
21+
22+
/**
 * Node2Vec-specific walk settings: the inherited walk parameters plus the two sampling
 * values that {@code Node2Vec.compute()} hands to {@code RandomWalkProbabilities.Builder}.
 *
 * This class has the same simple name as its superclass, which is why the superclass is
 * referenced by its fully qualified name. Code in this package that refers to
 * {@code WalkParameters} without an import resolves to this subclass.
 */
public class WalkParameters extends org.neo4j.gds.traversal.WalkParameters {
23+
// Declared in the opposite order to the constructor parameters (positive, then negative).
final double negativeSamplingExponent;
24+
final double positiveSamplingFactor;
25+
26+
/**
 * Passes the first four arguments to the superclass constructor and stores the two
 * sampling values in the fields of the same name. No validation is performed here.
 */
WalkParameters(int walksPerNode, int walkLength, double returnFactor, double inOutFactor, double positiveSamplingFactor, double negativeSamplingExponent) {
27+
super(walksPerNode, walkLength, returnFactor, inOutFactor);
28+
this.negativeSamplingExponent = negativeSamplingExponent;
29+
this.positiveSamplingFactor = positiveSamplingFactor;
30+
}
31+
}

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/NegativeSampleProducerTest.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,9 +34,9 @@ class NegativeSampleProducerTest {
3434
void shouldProduceSamplesAccordingToNodeDistribution() {
3535
var builder = new RandomWalkProbabilities.Builder(
3636
2,
37+
4,
3738
0.001,
38-
0.75,
39-
4
39+
0.75
4040
);
4141

4242
builder

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -46,9 +46,9 @@ void testModel() {
4646

4747
var probabilitiesBuilder = new RandomWalkProbabilities.Builder(
4848
numberOfClusters * clusterSize,
49+
4,
4950
0.001,
50-
0.75,
51-
4
51+
0.75
5252
);
5353

5454
CompressedRandomWalks walks = generateRandomWalks(
@@ -158,9 +158,9 @@ void randomSeed(int iterations) {
158158

159159
var probabilitiesBuilder = new RandomWalkProbabilities.Builder(
160160
numberOfClusters * clusterSize,
161+
4,
161162
0.001,
162-
0.75,
163-
4
163+
0.75
164164
);
165165

166166
CompressedRandomWalks walks = generateRandomWalks(probabilitiesBuilder, numberOfClusters, clusterSize, numberOfWalks, walkLength, random);

0 commit comments

Comments
 (0)