Skip to content

Commit c9a985d

Browse files
Use symmetry for more pruning
Co-authored-by: Veselin Nikolov <[email protected]>
1 parent 821a606 commit c9a985d

File tree

3 files changed

+24
-24
lines changed

3 files changed

+24
-24
lines changed

algo/src/main/java/org/neo4j/gds/hdbscan/BoruvkaMST.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ public final class BoruvkaMST extends Algorithm<GeometricMSTResult> {
4545
private long edgeCount = 0;
4646
private double totalEdgeSum = 0d;
4747

48-
4948
private BoruvkaMST(
5049
Distances distances,
5150
KdTree kdTree,
@@ -135,7 +134,7 @@ private void performIteration() {
135134
concurrency,
136135
terminationFlag,
137136
(q) -> {
138-
var qComp = unionFind.setIdOf(q);
137+
var qComp = unionFind.setIdOf(q);
139138
if (filterNodesOnCoreValue(q,qComp)) {
140139
traversalStep(q, kdTree.root(), qComp, 0);
141140
}
@@ -161,16 +160,16 @@ boolean prune(KdNode kdNode, long componentId, double lowerBoundOnDistance){
161160
return currentComponentBest < lowerBoundOnDistance;
162161
}
163162

164-
boolean tryUpdate(long qComp, long q,long r, double distance){
165-
return closestDistanceTracker.tryToAssign(qComp,q,r,distance);
163+
boolean tryUpdate(long qComp, long rComp, long q,long r, double distance){
164+
return closestDistanceTracker.consider(qComp,rComp,q,r,distance);
166165
}
167166

168167
double baseCase(long q,long r, long qComp){
169168
var rComp = unionFind.setIdOf(r);
170169
if (rComp != qComp && filterNodesOnCoreValue(r, qComp)) {
171170
var rqDistance = distances.computeDistanceUnsquared(q,r);
172171
var adaptedDistance = Math.max(Math.max(coreValues.get(r), coreValues.get(q)), rqDistance);
173-
if (tryUpdate(qComp,q,r,adaptedDistance)){
172+
if (tryUpdate(qComp,rComp,q,r,adaptedDistance)){
174173
return adaptedDistance;
175174
}
176175
}

algo/src/main/java/org/neo4j/gds/hdbscan/ClosestDistanceInformationTracker.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,10 @@ void resetComponent(long u) {
9696
componentOutsideBestNode.set(u, -1);
9797
}
9898

99-
void consider(long comp1, long comp2, long p1, long p2, double distance) {
100-
tryToAssign(comp1, p1, p2, distance);
99+
boolean consider(long comp1, long comp2, long p1, long p2, double distance) {
100+
var assigned = tryToAssign(comp1, p1, p2, distance);
101101
tryToAssign(comp2, p2, p1, distance);
102+
return assigned;
102103
}
103104

104105
synchronized boolean tryToAssign(long comp, long pInside, long pOutside, double distance) {

algo/src/test/java/org/neo4j/gds/hdbscan/BoruvkaAlgorithmFunctionsTest.java

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ class BoruvkaAlgorithmFunctionsTest {
3636
@Test
3737
void singleComponentShouldWorkOnLeaf(){
3838
KdNode kdNode = KdNode.createLeaf(0, 0, 4, null);
39-
var kdTree =new KdTree(
39+
var kdTree = new KdTree(
4040
HugeLongArray.of(0,1,2,3),
4141
null,
4242
kdNode,
4343
1
4444
);
4545

46-
var boruvkaMST = BoruvkaMST.createWithZeroCores(
46+
var boruvkaMST = BoruvkaMST.createWithZeroCores(
4747
null,
4848
kdTree,
4949
4,
@@ -62,14 +62,14 @@ void singleComponentShouldWorkOnLeaf(){
6262

6363
@Test
6464
void singleComponentOrShouldWork(){
65-
var kdTree =new KdTree(
65+
var kdTree = new KdTree(
6666
HugeLongArray.of(0,1,2,3),
6767
null,
6868
null,
6969
1
7070
);
7171

72-
var boruvkaMST = BoruvkaMST.createWithZeroCores(
72+
var boruvkaMST = BoruvkaMST.createWithZeroCores(
7373
null,
7474
kdTree,
7575
4,
@@ -102,14 +102,14 @@ void singleComponentShouldWorkOnSplitNode(){
102102
kdNode.leftChild(left);
103103
kdNode.rightChild(right);
104104

105-
var kdTree =new KdTree(
105+
var kdTree = new KdTree(
106106
HugeLongArray.of(0,1,2,3,4,5,6,7),
107107
null,
108108
kdNode,
109109
1
110110
);
111111

112-
var boruvkaMST = BoruvkaMST.createWithZeroCores(
112+
var boruvkaMST = BoruvkaMST.createWithZeroCores(
113113
null,
114114
kdTree,
115115
8,
@@ -136,7 +136,7 @@ void singleComponentShouldWorkOnSplitNode(){
136136

137137
@Test
138138
void baseCaseShouldWork(){
139-
DoubleArrayNodePropertyValues nodeProps=new DoubleArrayNodePropertyValues() {
139+
DoubleArrayNodePropertyValues nodeProps = new DoubleArrayNodePropertyValues() {
140140
@Override
141141
public double[] doubleArrayValue(long nodeId) {
142142
return new double[]{nodeId};
@@ -154,9 +154,9 @@ public long nodeCount() {
154154
when(coreResult.createCoreArray()).thenReturn(HugeDoubleArray.of(0,0,10,10,0,0,0,0,0,0));
155155
when(coreResult.neighboursOf(anyLong())).thenReturn(new Neighbour[0]);
156156

157-
var distances =new DoubleArrayDistances(nodeProps);
157+
var distances = new DoubleArrayDistances(nodeProps);
158158

159-
var boruvkaMST = BoruvkaMST.create(
159+
var boruvkaMST = BoruvkaMST.create(
160160
distances,
161161
kdTree,
162162
coreResult,
@@ -174,16 +174,16 @@ public long nodeCount() {
174174

175175
@Test
176176
void baseCaseShouldIgnoreSameComponents(){
177-
DoubleArrayNodePropertyValues nodeProps=mock(DoubleArrayNodePropertyValues.class);
177+
DoubleArrayNodePropertyValues nodeProps = mock(DoubleArrayNodePropertyValues.class);
178178
var kdTree = mock(KdTree.class);
179179

180180
var coreResult = mock(CoreResult.class);
181181
when(coreResult.createCoreArray()).thenReturn(HugeDoubleArray.of(0,0,10,10));
182182
when(coreResult.neighboursOf(anyLong())).thenReturn(new Neighbour[0]);
183183

184-
var distances =new DoubleArrayDistances(nodeProps);
184+
var distances = new DoubleArrayDistances(nodeProps);
185185

186-
var boruvkaMST = BoruvkaMST.create(
186+
var boruvkaMST = BoruvkaMST.create(
187187
distances,
188188
kdTree,
189189
coreResult,
@@ -199,13 +199,13 @@ void baseCaseShouldIgnoreSameComponents(){
199199
@Test
200200
void shouldPruneProperly(){
201201

202-
DoubleArrayNodePropertyValues nodeProps=mock(DoubleArrayNodePropertyValues.class);
202+
DoubleArrayNodePropertyValues nodeProps = mock(DoubleArrayNodePropertyValues.class);
203203

204-
var distances =new DoubleArrayDistances(nodeProps);
204+
var distances = new DoubleArrayDistances(nodeProps);
205205
var kdRoot = KdNode.createLeaf(0,0,2,mock(AABB.class));
206-
var kdTree =new KdTree(HugeLongArray.of(0,1,2),distances,kdRoot,1);
206+
var kdTree = new KdTree(HugeLongArray.of(0,1,2),distances,kdRoot,1);
207207

208-
var boruvkaMST = BoruvkaMST.createWithZeroCores(
208+
var boruvkaMST = BoruvkaMST.createWithZeroCores(
209209
distances,
210210
kdTree,
211211
3,
@@ -214,7 +214,7 @@ void shouldPruneProperly(){
214214
);
215215

216216
//prune based on distance
217-
boruvkaMST.tryUpdate(2,2,3,5);
217+
boruvkaMST.tryUpdate(2,1,2,1,5);
218218
assertThat(boruvkaMST.prune(kdRoot,2,100)).isTrue();
219219
assertThat(boruvkaMST.prune(kdRoot,2,4)).isFalse();
220220

0 commit comments

Comments
 (0)