
Commit ef1b195

Merge pull request #10439 from brs96/213-splittable-random-fix
Node2Vec replace ThreadId with taskId
2 parents: 65f1eb6 + 0e08b3e

4 files changed: 56 additions (+), 37 deletions (−)
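
Background on the change: seeding each task's SplittableRandom with Thread.currentThread().getId() ties the random stream to whichever pool thread happens to execute the task. The same randomSeed can therefore produce different embeddings from run to run, and two tasks that land on the same pool thread even share a seed. A stable, monotonically assigned taskId removes both problems. Below is a minimal, self-contained sketch of the difference (plain JDK; SeedDemo and its output format are illustrative, not GDS code):

import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical demo, not part of the commit: compare thread-ID seeding
// (scheduling-dependent) with task-ID seeding (stable across runs).
public class SeedDemo {
    public static void main(String[] args) {
        long randomSeed = 42L;
        ExecutorService pool = Executors.newFixedThreadPool(2);
        for (int taskId = 0; taskId < 4; taskId++) {
            final int id = taskId;
            pool.submit(() -> {
                // Old scheme: varies with thread assignment, and two tasks on
                // the same pool thread collide on the same seed.
                long threadSeeded =
                    new SplittableRandom(randomSeed + Thread.currentThread().getId()).nextLong();
                // New scheme: one fixed stream per task, identical every run.
                long taskSeeded = new SplittableRandom(randomSeed + id).nextLong();
                System.out.printf("task %d: threadSeeded=%d, taskSeeded=%d%n", id, threadSeeded, taskSeeded);
            });
        }
        pool.shutdown();
    }
}

Run it twice: the taskSeeded column is identical across runs, while the threadSeeded column is only stable if the thread assignment happens to repeat.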


algo/src/main/java/org/neo4j/gds/embeddings/node2vec/NegativeSampleProducer.java

Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ public class NegativeSampleProducer {
 
     private final HugeLongArray contextNodeDistribution;
     private final long cumulativeProbability;
-    private SplittableRandom splittableRandom;
+    private final SplittableRandom splittableRandom;
 
     public NegativeSampleProducer(
         HugeLongArray contextNodeDistribution,
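
With the seed now fixed at construction time, the RNG can be created once and never reassigned, hence the final modifier above. A simplified, hypothetical stand-in for the class (real fields and sampling logic omitted):

import java.util.SplittableRandom;

// Sketch only: the shape NegativeSampleProducer takes after the change.
// The (seed, taskId) pair fully determines the sample stream.
class DeterministicSampler {
    private final SplittableRandom splittableRandom;

    DeterministicSampler(long randomSeed, int taskId) {
        this.splittableRandom = new SplittableRandom(randomSeed + taskId);
    }

    long nextSample(long bound) {
        // Uniform draw in [0, bound); reproducible for a given seed and taskId.
        return splittableRandom.nextLong(bound);
    }
}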

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/Node2VecModel.java

Lines changed: 40 additions & 28 deletions

@@ -30,9 +30,11 @@
 import org.neo4j.gds.ml.core.tensor.FloatVector;
 
 import java.util.ArrayList;
+import java.util.List;
 import java.util.Optional;
 import java.util.Random;
 import java.util.SplittableRandom;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.LongUnaryOperator;
 
 import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.addInPlace;
@@ -41,7 +43,6 @@
 
 public class Node2VecModel {
 
-
     private final HugeObjectArray<FloatVector> centerEmbeddings;
     private final HugeObjectArray<FloatVector> contextEmbeddings;
     private final double initialLearningRate;
@@ -136,31 +137,7 @@ Node2VecResult train() {
                 initialLearningRate - iteration * learningRateAlpha
             );
 
-            var tasks = PartitionUtils.degreePartitionWithBatchSize(
-                walks.size(),
-                walks::walkLength,
-                BitUtil.ceilDiv(randomWalkProbabilities.sampleCount(), concurrency.value()),
-                partition -> {
-                    var positiveSampleProducer = new PositiveSampleProducer(
-                        walks.iterator(partition.startNode(), partition.nodeCount()),
-                        randomWalkProbabilities.positiveSamplingProbabilities(),
-                        windowSize,
-                        Optional.of(randomSeed)
-                    );
-
-                    return new TrainingTask(
-                        centerEmbeddings,
-                        contextEmbeddings,
-                        positiveSampleProducer,
-                        randomWalkProbabilities.negativeSamplingDistribution(),
-                        learningRate,
-                        negativeSamplingRate,
-                        embeddingDimension,
-                        progressTracker,
-                        randomSeed
-                    );
-                }
-            );
+            var tasks = createTrainingTasks(learningRate);
 
             RunWithConcurrency.builder()
                 .concurrency(concurrency)
@@ -232,12 +209,13 @@ private TrainingTask(
             int negativeSamplingRate,
             int embeddingDimensions,
             ProgressTracker progressTracker,
-            long randomSeed
+            long randomSeed,
+            int taskId
         ) {
             this.centerEmbeddings = centerEmbeddings;
            this.contextEmbeddings = contextEmbeddings;
            this.positiveSampleProducer = positiveSampleProducer;
-            this.negativeSampleProducer = new NegativeSampleProducer(negativeSamples, randomSeed + Thread.currentThread().getId());
+            this.negativeSampleProducer = new NegativeSampleProducer(negativeSamples, randomSeed + taskId);
            this.learningRate = learningRate;
            this.negativeSamplingRate = negativeSamplingRate;
 
@@ -310,4 +288,38 @@ void addAll(FloatConsumer other) {
         }
     }
 
+    List<TrainingTask> createTrainingTasks(float learningRate){
+        AtomicInteger taskIndex = new AtomicInteger(0);
+        return PartitionUtils.degreePartitionWithBatchSize(
+            walks.size(),
+            walks::walkLength,
+            BitUtil.ceilDiv(randomWalkProbabilities.sampleCount(), concurrency.value()),
+            partition -> {
+
+                var taskId = taskIndex.getAndIncrement();
+                var positiveSampleProducer = new PositiveSampleProducer(
+                    walks.iterator(partition.startNode(), partition.nodeCount()),
+                    randomWalkProbabilities.positiveSamplingProbabilities(),
+                    windowSize,
+                    Optional.of(randomSeed),
+                    taskId
+                );
+
+                return new TrainingTask(
+                    centerEmbeddings,
+                    contextEmbeddings,
+                    positiveSampleProducer,
+                    randomWalkProbabilities.negativeSamplingDistribution(),
+                    learningRate,
+                    negativeSamplingRate,
+                    embeddingDimension,
+                    progressTracker,
+                    randomSeed,
+                    taskId
+                );
+            }
+        );
+
+    }
+
 }
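
The heart of the extracted createTrainingTasks method is the AtomicInteger that hands every produced task a unique, stable index. A self-contained sketch of that pattern (generic Java; Task, createTasks, and the partition count are hypothetical stand-ins for the GDS types):

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntFunction;
import java.util.stream.IntStream;

// Illustration of the createTrainingTasks pattern: the factory lambda runs
// once per partition while the task list is built, so each task captures its
// ID before any concurrent execution begins.
class TaskFactoryDemo {
    record Task(int taskId, long seed) {}

    static List<Task> createTasks(int partitions, long randomSeed) {
        AtomicInteger taskIndex = new AtomicInteger(0);
        IntFunction<Task> factory = partition -> {
            int taskId = taskIndex.getAndIncrement();
            // Same derivation as the diff: mix the stable task index into the seed.
            return new Task(taskId, randomSeed + taskId);
        };
        return IntStream.range(0, partitions).mapToObj(factory).toList();
    }
}

Because the task list is materialized up front, IDs are assigned deterministically in partition order, independent of how RunWithConcurrency later schedules the tasks across threads.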

algo/src/main/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducer.java

Lines changed: 3 additions & 2 deletions

@@ -48,7 +48,8 @@ public class PositiveSampleProducer {
         Iterator<long[]> walks,
         HugeDoubleArray samplingProbabilities,
         int windowSize,
-        Optional<Long> maybeRandomSeed
+        Optional<Long> maybeRandomSeed,
+        int taskId
     ) {
         this.walks = walks;
         this.samplingProbabilities = samplingProbabilities;
@@ -60,7 +61,7 @@ public class PositiveSampleProducer {
         this.centerWordIndex = -1;
         this.contextWordIndex = 1;
         probabilitySupplier = maybeRandomSeed
-            .map(seed -> new SplittableRandom(Thread.currentThread().getId() + seed))
+            .map(seed -> new SplittableRandom(taskId + seed))
             .orElseGet(() -> new SplittableRandom(ThreadLocalRandom.current().nextLong()));
 
     }
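
The seeding decision in isolation, as a sketch (JDK only; the method name create is assumed, the two branches mirror the diff above): a caller that supplies a seed gets a stream fully determined by (seed, taskId); a caller that does not gets an arbitrary, non-reproducible stream.

import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ThreadLocalRandom;

// Illustrative only: reproducible when seeded, arbitrary otherwise.
class ProbabilitySupplierDemo {
    static SplittableRandom create(Optional<Long> maybeRandomSeed, int taskId) {
        return maybeRandomSeed
            // Deterministic branch: taskId replaces the old thread-ID term.
            .map(seed -> new SplittableRandom(taskId + seed))
            // Unseeded branch: deliberately non-reproducible.
            .orElseGet(() -> new SplittableRandom(ThreadLocalRandom.current().nextLong()));
    }
}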

algo/src/test/java/org/neo4j/gds/embeddings/node2vec/PositiveSampleProducerTest.java

Lines changed: 12 additions & 6 deletions

@@ -61,7 +61,8 @@ void doesNotCauseStackOverflow() {
             walks.iterator(0, nbrOfWalks),
             HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
             10,
-            Optional.empty()
+            Optional.empty(),
+            0
         );
 
         var counter = 0L;
@@ -89,7 +90,8 @@ void doesNotCauseStackOverflowDueToBadLuck() {
             walks.iterator(0, nbrOfWalks),
             probabilities,
             10,
-            Optional.empty()
+            Optional.empty(),
+            0
         );
         // does not overflow the stack = passes test
 
@@ -113,7 +115,8 @@ void doesNotAttemptToFetchOutsideBatch() {
             walks.iterator(0, nbrOfWalks / 2),
             HugeDoubleArray.of(LongStream.range(0, nbrOfWalks).mapToDouble((l) -> 1.0).toArray()),
             10,
-            Optional.empty()
+            Optional.empty(),
+            0
         );
 
         var counter = 0L;
@@ -138,7 +141,8 @@ void shouldProducePairsWith(
             walks.iterator(0, walks.size()),
             centerNodeProbabilities,
             windowSize,
-            Optional.empty()
+            Optional.empty(),
+            0
         );
         while (producer.next(buffer)) {
             actualPairs.add(Pair.of(buffer[0], buffer[1]));
@@ -161,7 +165,8 @@ void shouldProducePairsWithBounds() {
             walks.iterator(0, 2),
             centerNodeProbabilities,
             3,
-            Optional.empty()
+            Optional.empty(),
+            0
         );
         while (producer.next(buffer)) {
             actualPairs.add(Pair.of(buffer[0], buffer[1]));
@@ -207,7 +212,8 @@ void shouldRemoveDownsampledWordFromWalk() {
             walks.iterator(0, walks.size()),
             centerNodeProbabilities,
             3,
-            Optional.empty()
+            Optional.empty(),
+            0
         );
 
         while (producer.next(buffer)) {
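
These tests pass a fixed taskId of 0 only to satisfy the new constructor parameter; since they supply Optional.empty(), the producer falls back to a ThreadLocalRandom-derived seed and the taskId value is immaterial there. A hypothetical companion test (JUnit 5, not part of the PR) for the property the change buys, stated directly against SplittableRandom:

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;

import java.util.SplittableRandom;
import org.junit.jupiter.api.Test;

// Equal (seed, taskId) pairs must replay the identical stream; distinct
// taskIds must produce distinct streams.
class DeterministicSeedTest {

    @Test
    void sameSeedAndTaskIdReplaysStream() {
        long randomSeed = 19L;
        int taskId = 1;
        var first = new SplittableRandom(randomSeed + taskId);
        var second = new SplittableRandom(randomSeed + taskId);
        for (int i = 0; i < 100; i++) {
            assertEquals(first.nextLong(), second.nextLong());
        }
    }

    @Test
    void differentTaskIdsDiverge() {
        long randomSeed = 19L;
        // SplittableRandom's output mix is a bijection of its state, so
        // adjacent seeds cannot produce the same first value.
        assertNotEquals(
            new SplittableRandom(randomSeed + 1).nextLong(),
            new SplittableRandom(randomSeed + 2).nextLong()
        );
    }
}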
