31
31
import org .neo4j .gds .mem .MemoryUsage ;
32
32
import org .neo4j .gds .ml .core .EmbeddingUtils ;
33
33
import org .neo4j .gds .traversal .RandomWalkCompanion ;
34
+ import org .neo4j .gds .traversal .WalkParameters ;
34
35
35
36
import java .util .ArrayList ;
36
37
import java .util .List ;
@@ -45,10 +46,7 @@ public class Node2Vec extends Algorithm<Node2VecModel.Result> {
45
46
private final double positiveSamplingFactor ;
46
47
private final double negativeSamplingExponent ;
47
48
private final int concurrency ;
48
- private final int walksPerNode ;
49
- private final int walkLength ;
50
- private final double inOutFactor ;
51
- private final double returnFactor ;
49
+ private final WalkParameters walkParameters ;
52
50
private final List <Long > sourceNodes ;
53
51
private final Optional <Long > maybeRandomSeed ;
54
52
private final double initialLearningRate ;
@@ -78,10 +76,7 @@ static Node2Vec create(Graph graph, Node2VecBaseConfig config, ProgressTracker p
78
76
config .concurrency (),
79
77
config .positiveSamplingFactor (),
80
78
config .negativeSamplingExponent (),
81
- config .walksPerNode (),
82
- config .walkLength (),
83
- config .inOutFactor (),
84
- config .returnFactor (),
79
+ config .walkParameters (),
85
80
config .sourceNodes (),
86
81
config .randomSeed (),
87
82
progressTracker ,
@@ -100,10 +95,7 @@ public Node2Vec(
100
95
int concurrency ,
101
96
double positiveSamplingFactor ,
102
97
double negativeSamplingExponent ,
103
- int walksPerNode ,
104
- int walkLength ,
105
- double inOutFactor ,
106
- double returnFactor ,
98
+ WalkParameters walkParameters ,
107
99
List <Long > sourceNodes ,
108
100
Optional <Long > maybeRandomSeed ,
109
101
ProgressTracker progressTracker ,
@@ -122,10 +114,7 @@ public Node2Vec(
122
114
this .positiveSamplingFactor = positiveSamplingFactor ;
123
115
this .negativeSamplingExponent = negativeSamplingExponent ;
124
116
this .concurrency = concurrency ;
125
- this .walksPerNode = walksPerNode ;
126
- this .walkLength = walkLength ;
127
- this .inOutFactor = inOutFactor ;
128
- this .returnFactor = returnFactor ;
117
+ this .walkParameters = walkParameters ;
129
118
this .sourceNodes = sourceNodes ;
130
119
this .maybeRandomSeed = maybeRandomSeed ;
131
120
this .initialLearningRate = initialLearningRate ;
@@ -157,7 +146,7 @@ public Node2VecModel.Result compute() {
157
146
negativeSamplingExponent ,
158
147
concurrency
159
148
);
160
- var walks = new CompressedRandomWalks (graph .nodeCount () * walksPerNode );
149
+ var walks = new CompressedRandomWalks (graph .nodeCount () * walkParameters . walksPerNode );
161
150
162
151
progressTracker .beginSubTask ("RandomWalk" );
163
152
@@ -168,10 +157,7 @@ public Node2VecModel.Result compute() {
168
157
maybeRandomSeed ,
169
158
concurrency ,
170
159
sourceNodes ,
171
- walksPerNode ,
172
- walkLength ,
173
- inOutFactor ,
174
- returnFactor ,
160
+ walkParameters ,
175
161
DefaultPool .INSTANCE ,
176
162
progressTracker ,
177
163
terminationFlag
@@ -222,10 +208,7 @@ private List<Node2VecRandomWalkTask> tasks(
222
208
Optional <Long > maybeRandomSeed ,
223
209
int concurrency ,
224
210
List <Long > sourceNodes ,
225
- int walksPerNode ,
226
- int walkLength ,
227
- double inOutFactor ,
228
- double returnFactor ,
211
+ WalkParameters walkParameters ,
229
212
ExecutorService executorService ,
230
213
ProgressTracker progressTracker ,
231
214
TerminationFlag terminationFlag
@@ -245,17 +228,17 @@ private List<Node2VecRandomWalkTask> tasks(
245
228
tasks .add (new Node2VecRandomWalkTask (
246
229
graph .concurrentCopy (),
247
230
nextNodeSupplier ,
248
- walksPerNode ,
231
+ walkParameters . walksPerNode ,
249
232
cumulativeWeightsSupplier ,
250
233
progressTracker ,
251
234
terminationFlag ,
252
235
index ,
253
236
compressedRandomWalks ,
254
237
randomWalkPropabilitiesBuilder ,
255
238
randomSeed ,
256
- walkLength ,
257
- returnFactor ,
258
- inOutFactor
239
+ walkParameters . walkLength ,
240
+ walkParameters . returnFactor ,
241
+ walkParameters . inOutFactor
259
242
));
260
243
}
261
244
return tasks ;
0 commit comments