@@ -100,9 +100,20 @@ void embeddingsShouldHaveTheConfiguredDimension(String msg, Iterable<String> nod
100
100
.graph ();
101
101
102
102
int embeddingDimension = 128 ;
103
+ var trainParameters = new TrainParameters (
104
+ 0.025 ,
105
+ 0.0001 ,
106
+ 1 ,
107
+ 10 ,
108
+ 5 ,
109
+ embeddingDimension ,
110
+ EmbeddingInitializer .NORMALIZED
111
+ );
103
112
HugeObjectArray <FloatVector > node2Vec = Node2Vec .create (
104
113
graph ,
105
- ImmutableNode2VecStreamConfig .builder ().embeddingDimension (embeddingDimension ).build (),
114
+ 4 ,
115
+ new WalkParameters (10 , 80 , 1.0 , 1.0 , 0.001 , 0.75 ),
116
+ trainParameters ,
106
117
ProgressTracker .NULL_TRACKER
107
118
).compute ().embeddings ();
108
119
@@ -132,11 +143,25 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) {
132
143
.embeddingDimension (embeddingDimension )
133
144
.build ();
134
145
var progressTask = new Node2VecAlgorithmFactory <>().progressTask (graph , config );
146
+
147
+ var walkParameters = new WalkParameters (10 , 80 , 1.0 , 1.0 , 0.001 , 0.75 );
148
+ var trainParameters = new TrainParameters (
149
+ 0.025 ,
150
+ 0.0001 ,
151
+ 1 ,
152
+ 10 ,
153
+ 5 ,
154
+ embeddingDimension ,
155
+ EmbeddingInitializer .NORMALIZED
156
+ );
135
157
var log = Neo4jProxy .testLog ();
136
158
var progressTracker = new TestProgressTracker (progressTask , log , 4 , EmptyTaskRegistryFactory .INSTANCE );
137
159
Node2Vec .create (
138
160
graph ,
139
- config ,
161
+ 4 ,
162
+ Optional .empty (),
163
+ walkParameters ,
164
+ trainParameters ,
140
165
progressTracker
141
166
).compute ();
142
167
@@ -170,10 +195,12 @@ void shouldLogProgress(boolean relationshipWeights, int expectedProgresses) {
170
195
@ Test
171
196
void shouldEstimateMemory () {
172
197
var nodeCount = 1000 ;
173
- var config = ImmutableNode2VecStreamConfig .builder ().build ();
174
- var memoryEstimation = Node2Vec .memoryEstimation (config .walksPerNode (), config .walkLength (), config .embeddingDimension ());
198
+ var walksPerNode = 10 ;
199
+ var walkLength = 80 ;
200
+ var embeddingDimension = 128 ;
201
+ var memoryEstimation = Node2Vec .memoryEstimation (walksPerNode , walkLength , embeddingDimension );
175
202
176
- var numberOfRandomWalks = nodeCount * config . walksPerNode () * config . walkLength () ;
203
+ var numberOfRandomWalks = nodeCount * walksPerNode * walkLength ;
177
204
var randomWalkMemoryUsageLowerBound = numberOfRandomWalks * Long .BYTES ;
178
205
179
206
var estimate = memoryEstimation .estimate (GraphDimensions .of (nodeCount ), 1 );
@@ -193,12 +220,16 @@ void shouldEstimateMemory() {
193
220
void failOnNegativeWeights () {
194
221
var graph = GdlFactory .of ("CREATE (a)-[:REL {weight: -1}]->(b)" ).build ().getUnion ();
195
222
196
- var config = ImmutableNode2VecStreamConfig
197
- .builder ()
198
- .relationshipWeightProperty ("weight" )
199
- .build ();
223
+ var walkParameters = new WalkParameters (10 , 80 , 1.0 , 1.0 , 0.001 , 0.75 );
224
+ var trainParameters = new TrainParameters (0.025 , 0.0001 , 1 , 1 , 1 , 128 , EmbeddingInitializer .NORMALIZED );
200
225
201
- var node2Vec = Node2Vec .create (graph , config , ProgressTracker .NULL_TRACKER );
226
+ var node2Vec = Node2Vec .create (
227
+ graph ,
228
+ 4 ,
229
+ walkParameters ,
230
+ trainParameters ,
231
+ ProgressTracker .NULL_TRACKER
232
+ );
202
233
203
234
assertThatThrownBy (node2Vec ::compute )
204
235
.isInstanceOf (RuntimeException .class )
@@ -214,30 +245,26 @@ void randomSeed(SoftAssertions softly) {
214
245
Graph graph = new StoreLoaderBuilder ().databaseService (db ).build ().graph ();
215
246
216
247
int embeddingDimension = 2 ;
217
-
218
- var config = ImmutableNode2VecStreamConfig
219
- .builder ()
220
- .embeddingDimension (embeddingDimension )
221
- .iterations (1 )
222
- .negativeSamplingRate (1 )
223
- .windowSize (1 )
224
- .walksPerNode (1 )
225
- .walkLength (20 )
226
- .walkBufferSize (50 )
227
- .randomSeed (1337L )
228
- .build ();
248
+ var walkParameters = new WalkParameters (1 , 20 , 1.0 , 1.0 , 0.001 , 0.75 );
249
+ var trainParameters = new TrainParameters (0.025 , 0.0001 , 1 , 1 , 1 , embeddingDimension , EmbeddingInitializer .NORMALIZED );
229
250
230
251
var embeddings = Node2Vec .create (
231
252
graph ,
232
- config ,
253
+ 4 ,
254
+ Optional .of (1337L ),
255
+ walkParameters ,
256
+ trainParameters ,
233
257
ProgressTracker .NULL_TRACKER
234
- ).compute ().embeddings ();
258
+ ).compute ().embeddings ();
235
259
236
260
var otherEmbeddings = Node2Vec .create (
237
261
graph ,
238
- config ,
262
+ 4 ,
263
+ Optional .of (1337L ),
264
+ walkParameters ,
265
+ trainParameters ,
239
266
ProgressTracker .NULL_TRACKER
240
- ).compute ().embeddings ();
267
+ ).compute ().embeddings ();
241
268
242
269
for (long node = 0 ; node < graph .nodeCount (); node ++) {
243
270
softly .assertThat (otherEmbeddings .get (node )).isEqualTo (embeddings .get (node ));
@@ -318,25 +345,26 @@ void shouldBeFairlyConsistentUnderOriginalIds(EmbeddingInitializer embeddingInit
318
345
var firstGraph = GraphFactory .create (firstIdMap , firstRelationships );
319
346
var secondGraph = GraphFactory .create (secondIdMap , secondRelationships );
320
347
321
- var config = ImmutableNode2VecStreamConfig
322
- .builder ()
323
- .embeddingInitializer (embeddingInitializer )
324
- .embeddingDimension (embeddingDimension )
325
- .randomSeed (1337L )
326
- .concurrency (1 )
327
- .build ();
348
+ var walkParameters = new WalkParameters (10 , 80 , 1.0 , 1.0 , 0.01 , 0.75 );
349
+ var trainParameters = new TrainParameters (0.025 , 0.0001 , 1 , 10 , 5 , embeddingDimension , embeddingInitializer );
328
350
329
351
var firstEmbeddings = Node2Vec .create (
330
352
firstGraph ,
331
- config ,
353
+ 4 ,
354
+ Optional .of (1337L ),
355
+ walkParameters ,
356
+ trainParameters ,
332
357
ProgressTracker .NULL_TRACKER
333
- ).compute ().embeddings ();
358
+ ).compute ().embeddings ();
334
359
335
360
var secondEmbeddings = Node2Vec .create (
336
361
secondGraph ,
337
- config ,
362
+ 4 ,
363
+ Optional .of (1337L ),
364
+ walkParameters ,
365
+ trainParameters ,
338
366
ProgressTracker .NULL_TRACKER
339
- ).compute ().embeddings ();
367
+ ).compute ().embeddings ();
340
368
341
369
double cosineSum = 0 ;
342
370
for (long originalNodeId = 0 ; originalNodeId < nodeCount ; originalNodeId ++) {
0 commit comments