Skip to content

Commit 5679b0c

Browse files
authored
Merge pull request #6686 from chris1011011/refactor-selection
refactor manage nodeQueue in SelectionStrategy
2 parents afdcf6f + e7d2ca6 commit 5679b0c

File tree

8 files changed

+270
-205
lines changed

8 files changed

+270
-205
lines changed

algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentrality.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,11 @@
3333
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3434

3535
import java.util.concurrent.ExecutorService;
36-
import java.util.concurrent.atomic.AtomicLong;
3736
import java.util.function.Consumer;
3837

3938
public class BetweennessCentrality extends Algorithm<HugeAtomicDoubleArray> {
4039

4140
private final Graph graph;
42-
private final AtomicLong nodeQueue = new AtomicLong();
4341
private final long nodeCount;
4442
private final double divisor;
4543
private final ForwardTraverser.Factory traverserFactory;
@@ -75,7 +73,6 @@ public BetweennessCentrality(
7573
@Override
7674
public HugeAtomicDoubleArray compute() {
7775
progressTracker.beginSubTask();
78-
nodeQueue.set(0);
7976
ParallelUtil.run(ParallelUtil.tasks(concurrency, BCTask::new), executorService);
8077
progressTracker.endSubTask();
8178
return centrality;
@@ -113,15 +110,11 @@ public void run() {
113110
);
114111

115112
for (;;) {
116-
// take start node from the queue
117-
long startNodeId = nodeQueue.getAndIncrement();
118-
if (startNodeId >= nodeCount || !terminationFlag.running()) {
113+
long startNodeId = selectionStrategy.next();
114+
if (startNodeId == SelectionStrategy.NONE_SELECTED || !terminationFlag.running()) {
119115
return;
120116
}
121-
// check whether the node is part of the subset
122-
if (!selectionStrategy.select(startNodeId)) {
123-
continue;
124-
}
117+
125118
// reset
126119
getProgressTracker().logProgress();
127120

algo/src/main/java/org/neo4j/gds/betweenness/BetweennessCentralityFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ public BetweennessCentrality build(
5555
var samplingSeed = configuration.samplingSeed();
5656

5757
var strategy = samplingSize.isPresent() && samplingSize.get() < graph.nodeCount()
58-
? new SelectionStrategy.RandomDegree(samplingSize.get(), samplingSeed)
59-
: SelectionStrategy.ALL;
58+
? new RandomDegreeSelectionStrategy(samplingSize.get(), samplingSeed)
59+
: new FullSelectionStrategy();
6060

6161
ForwardTraverser.Factory traverserFactory = configuration.hasRelationshipWeightProperty()
6262
? ForwardTraverser.Factory.weighted()
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.betweenness;
21+
22+
import org.neo4j.gds.api.Graph;
23+
24+
import java.util.concurrent.ExecutorService;
25+
import java.util.concurrent.atomic.AtomicLong;
26+
27+
public class FullSelectionStrategy implements SelectionStrategy {
28+
29+
private final AtomicLong nodeQueue = new AtomicLong();
30+
private long graphSize = 0;
31+
32+
@Override
33+
public void init(Graph graph, ExecutorService executorService, int concurrency) {
34+
this.graphSize = graph.nodeCount();
35+
nodeQueue.set(0);
36+
}
37+
38+
@Override
39+
public long next() {
40+
long nextNodeId = nodeQueue.getAndIncrement();
41+
if (nextNodeId >= graphSize) {
42+
return NONE_SELECTED;
43+
}
44+
return nextNodeId;
45+
}
46+
}
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.betweenness;
21+
22+
import com.carrotsearch.hppc.BitSet;
23+
import com.carrotsearch.hppc.BitSetIterator;
24+
import org.neo4j.gds.api.Graph;
25+
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
26+
import org.neo4j.gds.core.utils.partition.Partition;
27+
import org.neo4j.gds.core.utils.partition.PartitionUtils;
28+
29+
import java.util.Collection;
30+
import java.util.Optional;
31+
import java.util.SplittableRandom;
32+
import java.util.concurrent.ExecutorService;
33+
import java.util.concurrent.atomic.AtomicInteger;
34+
import java.util.concurrent.atomic.AtomicLong;
35+
import java.util.stream.Collectors;
36+
37+
import static com.carrotsearch.hppc.BitSetIterator.NO_MORE;
38+
39+
public class RandomDegreeSelectionStrategy implements SelectionStrategy {
40+
41+
private final long samplingSize;
42+
private final Optional<Long> maybeRandomSeed;
43+
private final AtomicLong nodeQueue = new AtomicLong();
44+
45+
private long graphSize;
46+
private BitSet sampleSet;
47+
48+
public RandomDegreeSelectionStrategy(long samplingSize) {
49+
this(samplingSize, Optional.empty());
50+
}
51+
52+
public RandomDegreeSelectionStrategy(long samplingSize, Optional<Long> maybeRandomSeed) {
53+
this.samplingSize = samplingSize;
54+
this.maybeRandomSeed = maybeRandomSeed;
55+
}
56+
57+
@Override
58+
public void init(Graph graph, ExecutorService executorService, int concurrency) {
59+
assert samplingSize <= graph.nodeCount();
60+
this.sampleSet = new BitSet(graph.nodeCount());
61+
this.graphSize = graph.nodeCount();
62+
nodeQueue.set(0);
63+
var partitions = PartitionUtils.numberAlignedPartitioning(concurrency, graph.nodeCount(), Long.SIZE);
64+
var maxDegree = maxDegree(graph, partitions, executorService, concurrency);
65+
selectNodes(graph, partitions, maxDegree, executorService, concurrency);
66+
}
67+
68+
@Override
69+
public long next() {
70+
long nextNodeId;
71+
while ((nextNodeId = nodeQueue.getAndIncrement()) < graphSize) {
72+
if (sampleSet.get(nextNodeId)) {
73+
return nextNodeId;
74+
}
75+
}
76+
return NONE_SELECTED;
77+
}
78+
79+
private static int maxDegree(
80+
Graph graph,
81+
Collection<Partition> partitions,
82+
ExecutorService executorService,
83+
int concurrency
84+
) {
85+
AtomicInteger maxDegree = new AtomicInteger(0);
86+
87+
var tasks = partitions.stream()
88+
.map(partition -> (Runnable) () -> partition.consume(nodeId -> {
89+
int degree = graph.degree(nodeId);
90+
int current = maxDegree.get();
91+
while (degree > current) {
92+
int newCurrent = maxDegree.compareAndExchange(current, degree);
93+
if (newCurrent == current) {
94+
break;
95+
}
96+
current = newCurrent;
97+
}
98+
})).collect(Collectors.toList());
99+
100+
RunWithConcurrency.builder()
101+
.concurrency(concurrency)
102+
.tasks(tasks)
103+
.executor(executorService)
104+
.run();
105+
106+
return maxDegree.get();
107+
}
108+
109+
private void selectNodes(
110+
Graph graph,
111+
Collection<Partition> partitions,
112+
int maxDegree,
113+
ExecutorService executorService,
114+
int concurrency
115+
) {
116+
var random = maybeRandomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
117+
var selectionSize = new AtomicLong(0);
118+
var tasks = partitions.stream()
119+
.map(partition -> (Runnable) () -> {
120+
var threadLocalRandom = random.split();
121+
var fromNode = partition.startNode();
122+
var toNode = partition.startNode() + partition.nodeCount();
123+
124+
for (long nodeId = fromNode; nodeId < toNode; nodeId++) {
125+
var currentSelectionSize = selectionSize.get();
126+
if (currentSelectionSize >= samplingSize) {
127+
break;
128+
}
129+
int nodeDegree = graph.degree(nodeId);
130+
// probability factor is in range [1, maxDegree] (inclusive both ends)
131+
// the probability of a node being selected is probabilityFactor * (1 / maxDegree)
132+
int probabilityFactor = threadLocalRandom.nextInt(maxDegree) + 1;
133+
if (probabilityFactor <= nodeDegree) {
134+
while (true) {
135+
long actualCurrentSelectionSize = selectionSize.compareAndExchange(
136+
currentSelectionSize,
137+
currentSelectionSize + 1
138+
);
139+
if (currentSelectionSize == actualCurrentSelectionSize) {
140+
sampleSet.set(nodeId);
141+
break;
142+
}
143+
if (actualCurrentSelectionSize >= samplingSize) {
144+
break;
145+
}
146+
currentSelectionSize = actualCurrentSelectionSize;
147+
}
148+
}
149+
}
150+
}).collect(Collectors.toList());
151+
152+
RunWithConcurrency.builder()
153+
.concurrency(concurrency)
154+
.tasks(tasks)
155+
.executor(executorService)
156+
.run();
157+
158+
long actualSelectedNodes = selectionSize.get();
159+
160+
if (actualSelectedNodes < samplingSize) {
161+
// Flip bitset to be able to iterate unset bits.
162+
// The upper range is Graph#nodeCount() since
163+
// BitSet#size() returns a multiple of 64.
164+
// We need to make sure to stay within bounds.
165+
sampleSet.flip(0, graph.nodeCount());
166+
// Potentially iterate the bitset multiple times
167+
// until we have exactly numSeedNodes nodes.
168+
BitSetIterator iterator;
169+
while (actualSelectedNodes < samplingSize) {
170+
iterator = sampleSet.iterator();
171+
var unselectedNode = iterator.nextSetBit();
172+
while (unselectedNode != NO_MORE && actualSelectedNodes < samplingSize) {
173+
if (random.nextDouble() >= 0.5) {
174+
sampleSet.flip(unselectedNode);
175+
actualSelectedNodes++;
176+
}
177+
unselectedNode = iterator.nextSetBit();
178+
}
179+
}
180+
sampleSet.flip(0, graph.nodeCount());
181+
}
182+
}
183+
}

0 commit comments

Comments
 (0)