Skip to content

Commit 6b22038

Browse files
committed
Almost purge Node2Vec config from algo code
1 parent 6941b24 commit 6b22038

File tree

6 files changed: 116 additions (+116) and 74 deletions (-74).

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2Vec.java

Lines changed: 27 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -61,26 +61,43 @@ public static MemoryEstimation memoryEstimation(int walksPerNode, int walkLength
6161
.build();
6262
}
6363

64-
static Node2Vec create(Graph graph, Node2VecBaseConfig config, ProgressTracker progressTracker) {
64+
static Node2Vec create(
65+
Graph graph,
66+
int concurrency,
67+
WalkParameters walkParameters,
68+
TrainParameters trainParameters,
69+
ProgressTracker progressTracker
70+
) {
71+
return create(graph, concurrency, Optional.empty(), walkParameters, trainParameters, progressTracker);
72+
}
73+
74+
static Node2Vec create(
75+
Graph graph,
76+
int concurrency,
77+
Optional<Long> maybeRandomSeed,
78+
WalkParameters walkParameters,
79+
TrainParameters trainParameters,
80+
ProgressTracker progressTracker
81+
) {
6582
return new Node2Vec(
6683
graph,
67-
config.concurrency(),
68-
config.walkParameters(),
69-
config.sourceNodes(),
70-
config.randomSeed(),
71-
progressTracker,
72-
config.trainParameters()
84+
concurrency,
85+
List.of(),
86+
maybeRandomSeed,
87+
walkParameters,
88+
trainParameters,
89+
progressTracker
7390
);
7491
}
7592

7693
public Node2Vec(
7794
Graph graph,
7895
int concurrency,
79-
WalkParameters walkParameters,
8096
List<Long> sourceNodes,
8197
Optional<Long> maybeRandomSeed,
82-
ProgressTracker progressTracker,
83-
TrainParameters trainParameters
98+
WalkParameters walkParameters,
99+
TrainParameters trainParameters,
100+
ProgressTracker progressTracker
84101
) {
85102
super(progressTracker);
86103
this.graph = graph;

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecAlgorithmFactory.java

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,15 @@ public Node2Vec build(
4747
ProgressTracker progressTracker
4848
) {
4949
validateConfig(configuration, graph);
50-
return Node2Vec.create(graph, configuration, progressTracker);
50+
return new Node2Vec(
51+
graph,
52+
configuration.concurrency(),
53+
configuration.sourceNodes(),
54+
configuration.randomSeed(),
55+
configuration.walkParameters(),
56+
configuration.trainParameters(),
57+
progressTracker
58+
);
5159
}
5260

5361
@Override

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/TrainParameters.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ public class TrainParameters {
2828
final int embeddingDimension;
2929
final EmbeddingInitializer embeddingInitializer;
3030

31-
TrainParameters(
31+
public TrainParameters(
3232
double initialLearningRate,
3333
double minLearningRate,
3434
int iterations,

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/WalkParameters.java

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,14 @@ public class WalkParameters extends org.neo4j.gds.traversal.WalkParameters {
2323
final double negativeSamplingExponent;
2424
final double positiveSamplingFactor;
2525

26-
WalkParameters(int walksPerNode, int walkLength, double returnFactor, double inOutFactor, double positiveSamplingFactor, double negativeSamplingExponent) {
26+
public WalkParameters(
27+
int walksPerNode,
28+
int walkLength,
29+
double returnFactor,
30+
double inOutFactor,
31+
double positiveSamplingFactor,
32+
double negativeSamplingExponent
33+
) {
2734
super(walksPerNode, walkLength, returnFactor, inOutFactor);
2835
this.negativeSamplingExponent = negativeSamplingExponent;
2936
this.positiveSamplingFactor = positiveSamplingFactor;

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecModelTest.java

Lines changed: 6 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -60,22 +60,16 @@ void testModel() {
6060
random
6161
);
6262

63-
Node2VecStreamConfig defaults = ImmutableNode2VecStreamConfig.builder().build();
63+
var trainParameters = new TrainParameters(0.05, 0.0001, 5, 10, 1, 10, EmbeddingInitializer.NORMALIZED);
6464

6565
int nodeCount = numberOfClusters * clusterSize;
6666

6767
var node2VecModel = new Node2VecModel(
6868
nodeId -> nodeId,
6969
nodeCount,
70-
0.05,
71-
defaults.minLearningRate(),
72-
5,
73-
10,
74-
defaults.windowSize(),
75-
1,
76-
defaults.embeddingInitializer(),
70+
trainParameters,
7771
4,
78-
defaults.randomSeed(),
72+
Optional.empty(),
7973
walks,
8074
probabilitiesBuilder.build(),
8175
ProgressTracker.NULL_TRACKER
@@ -165,20 +159,14 @@ void randomSeed(int iterations) {
165159

166160
CompressedRandomWalks walks = generateRandomWalks(probabilitiesBuilder, numberOfClusters, clusterSize, numberOfWalks, walkLength, random);
167161

168-
Node2VecStreamConfig defaults = ImmutableNode2VecStreamConfig.builder().build();
162+
var trainParameters = new TrainParameters(0.05, 0.0001, iterations, 10, 1, 2, EmbeddingInitializer.NORMALIZED);
169163

170164
int nodeCount = numberOfClusters * clusterSize;
171165

172166
var node2VecModel = new Node2VecModel(
173167
nodeId -> nodeId,
174168
nodeCount,
175-
0.05,
176-
defaults.minLearningRate(),
177-
iterations,
178-
2,
179-
defaults.windowSize(),
180-
1,
181-
defaults.embeddingInitializer(),
169+
trainParameters,
182170
4,
183171
Optional.of(1337L),
184172
walks,
@@ -189,13 +177,7 @@ void randomSeed(int iterations) {
189177
var otherNode2VecModel = new Node2VecModel(
190178
nodeId -> nodeId,
191179
nodeCount,
192-
0.05,
193-
defaults.minLearningRate(),
194-
iterations,
195-
2,
196-
defaults.windowSize(),
197-
1,
198-
defaults.embeddingInitializer(),
180+
trainParameters,
199181
4,
200182
Optional.of(1337L),
201183
walks,

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/Node2VecTest.java

Lines changed: 65 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -100,9 +100,20 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable<String> nod
100100
.graph();
101101

102102
int embeddingDimension = 128;
103+
var trainParameters = new TrainParameters(
104+
0.025,
105+
0.0001,
106+
1,
107+
10,
108+
5,
109+
embeddingDimension,
110+
EmbeddingInitializer.NORMALIZED
111+
);
103112
HugeObjectArray<FloatVector> node2Vec = Node2Vec.create(
104113
graph,
105-
ImmutableNode2VecStreamConfig.builder().embeddingDimension(embeddingDimension).build(),
114+
4,
115+
new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75),
116+
trainParameters,
106117
ProgressTracker.NULL_TRACKER
107118
).compute().embeddings();
108119

@@ -132,11 +143,25 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) {
132143
.embeddingDimension(embeddingDimension)
133144
.build();
134145
var progressTask = new Node2VecAlgorithmFactory<>().progressTask(graph, config);
146+
147+
var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75);
148+
var trainParameters = new TrainParameters(
149+
0.025,
150+
0.0001,
151+
1,
152+
10,
153+
5,
154+
embeddingDimension,
155+
EmbeddingInitializer.NORMALIZED
156+
);
135157
var log = Neo4jProxy.testLog();
136158
var progressTracker = new TestProgressTracker(progressTask, log, 4, EmptyTaskRegistryFactory.INSTANCE);
137159
Node2Vec.create(
138160
graph,
139-
config,
161+
4,
162+
Optional.empty(),
163+
walkParameters,
164+
trainParameters,
140165
progressTracker
141166
).compute();
142167

@@ -170,10 +195,12 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) {
170195
@Test
171196
void shouldEstimateMemory() {
172197
var nodeCount = 1000;
173-
var config = ImmutableNode2VecStreamConfig.builder().build();
174-
var memoryEstimation = Node2Vec.memoryEstimation(config.walksPerNode(), config.walkLength(), config.embeddingDimension());
198+
var walksPerNode = 10;
199+
var walkLength = 80;
200+
var embeddingDimension = 128;
201+
var memoryEstimation = Node2Vec.memoryEstimation(walksPerNode, walkLength, embeddingDimension);
175202

176-
var numberOfRandomWalks = nodeCount * config.walksPerNode() * config.walkLength();
203+
var numberOfRandomWalks = nodeCount * walksPerNode * walkLength;
177204
var randomWalkMemoryUsageLowerBound = numberOfRandomWalks * Long.BYTES;
178205

179206
var estimate = memoryEstimation.estimate(GraphDimensions.of(nodeCount), 1);
@@ -193,12 +220,16 @@ void shouldEstimateMemory() {
193220
void failOnNegativeWeights() {
194221
var graph = GdlFactory.of("CREATE (a)-[:REL {weight: -1}]->(b)").build().getUnion();
195222

196-
var config = ImmutableNode2VecStreamConfig
197-
.builder()
198-
.relationshipWeightProperty("weight")
199-
.build();
223+
var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.001, 0.75);
224+
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, 128, EmbeddingInitializer.NORMALIZED);
200225

201-
var node2Vec = Node2Vec.create(graph, config, ProgressTracker.NULL_TRACKER);
226+
var node2Vec = Node2Vec.create(
227+
graph,
228+
4,
229+
walkParameters,
230+
trainParameters,
231+
ProgressTracker.NULL_TRACKER
232+
);
202233

203234
assertThatThrownBy(node2Vec::compute)
204235
.isInstanceOf(RuntimeException.class)
@@ -214,30 +245,26 @@ void randomSeed(SoftAssertions softly) {
214245
Graph graph = new StoreLoaderBuilder().databaseService(db).build().graph();
215246

216247
int embeddingDimension = 2;
217-
218-
var config = ImmutableNode2VecStreamConfig
219-
.builder()
220-
.embeddingDimension(embeddingDimension)
221-
.iterations(1)
222-
.negativeSamplingRate(1)
223-
.windowSize(1)
224-
.walksPerNode(1)
225-
.walkLength(20)
226-
.walkBufferSize(50)
227-
.randomSeed(1337L)
228-
.build();
248+
var walkParameters = new WalkParameters(1, 20, 1.0, 1.0, 0.001, 0.75);
249+
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 1, 1, embeddingDimension, EmbeddingInitializer.NORMALIZED);
229250

230251
var embeddings = Node2Vec.create(
231252
graph,
232-
config,
253+
4,
254+
Optional.of(1337L),
255+
walkParameters,
256+
trainParameters,
233257
ProgressTracker.NULL_TRACKER
234-
).compute().embeddings();
258+
).compute().embeddings();
235259

236260
var otherEmbeddings = Node2Vec.create(
237261
graph,
238-
config,
262+
4,
263+
Optional.of(1337L),
264+
walkParameters,
265+
trainParameters,
239266
ProgressTracker.NULL_TRACKER
240-
).compute().embeddings();
267+
).compute().embeddings();
241268

242269
for (long node = 0; node < graph.nodeCount(); node++) {
243270
softly.assertThat(otherEmbeddings.get(node)).isEqualTo(embeddings.get(node));
@@ -318,25 +345,26 @@ void shouldBeFairlyConsistentUnderOriginalIds(EmbeddingInitializer embeddingInit
318345
var firstGraph = GraphFactory.create(firstIdMap, firstRelationships);
319346
var secondGraph = GraphFactory.create(secondIdMap, secondRelationships);
320347

321-
var config = ImmutableNode2VecStreamConfig
322-
.builder()
323-
.embeddingInitializer(embeddingInitializer)
324-
.embeddingDimension(embeddingDimension)
325-
.randomSeed(1337L)
326-
.concurrency(1)
327-
.build();
348+
var walkParameters = new WalkParameters(10, 80, 1.0, 1.0, 0.01, 0.75);
349+
var trainParameters = new TrainParameters(0.025, 0.0001, 1, 10, 5, embeddingDimension, embeddingInitializer);
328350

329351
var firstEmbeddings = Node2Vec.create(
330352
firstGraph,
331-
config,
353+
4,
354+
Optional.of(1337L),
355+
walkParameters,
356+
trainParameters,
332357
ProgressTracker.NULL_TRACKER
333-
).compute().embeddings();
358+
).compute().embeddings();
334359

335360
var secondEmbeddings = Node2Vec.create(
336361
secondGraph,
337-
config,
362+
4,
363+
Optional.of(1337L),
364+
walkParameters,
365+
trainParameters,
338366
ProgressTracker.NULL_TRACKER
339-
).compute().embeddings();
367+
).compute().embeddings();
340368

341369
double cosineSum = 0;
342370
for (long originalNodeId = 0; originalNodeId < nodeCount; originalNodeId++) {

0 commit comments

Comments
 (0)