Skip to content

Commit cf99422

Browse files
Filtered Node Similarity
1 parent 35a775d commit cf99422

File tree

19 files changed

+228
-125
lines changed

19 files changed

+228
-125
lines changed

algo-params/similarity-params/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ group = 'org.neo4j.gds'
77
dependencies {
88
compileOnly openGds.jetbrains.annotations
99

10+
implementation project(':algo-params-common')
1011
implementation project(':annotations')
1112
implementation project(':concurrency')
1213
implementation project(':graph-projection-api')

algo-params/similarity-params/src/main/java/org/neo4j/gds/similarity/filteredknn/FilteredKnnParameters.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,18 @@
1919
*/
2020
package org.neo4j.gds.similarity.filteredknn;
2121

22+
import org.neo4j.gds.AlgorithmParameters;
23+
import org.neo4j.gds.core.concurrency.Concurrency;
2224
import org.neo4j.gds.similarity.FilteringParameters;
2325
import org.neo4j.gds.similarity.knn.KnnParameters;
2426

2527
public record FilteredKnnParameters(
2628
KnnParameters knnParameters,
2729
FilteringParameters filteringParameters,
2830
boolean seedTargetNodes
29-
) {}
31+
) implements AlgorithmParameters {
32+
@Override
33+
public Concurrency concurrency() {
34+
return knnParameters.concurrency();
35+
}
36+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.filterednodesim;
21+
22+
import org.neo4j.gds.AlgorithmParameters;
23+
import org.neo4j.gds.core.concurrency.Concurrency;
24+
import org.neo4j.gds.similarity.FilteringParameters;
25+
import org.neo4j.gds.similarity.nodesim.NodeSimilarityParameters;
26+
27+
public record FilteredNodeSimilarityParameters(
28+
NodeSimilarityParameters nodeSimilarityParameters,
29+
FilteringParameters filteringParameters
30+
) implements AlgorithmParameters {
31+
32+
@Override
33+
public Concurrency concurrency() {
34+
return nodeSimilarityParameters.concurrency();
35+
}
36+
}

algo-params/similarity-params/src/main/java/org/neo4j/gds/similarity/knn/KnnParameters.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
*/
2020
package org.neo4j.gds.similarity.knn;
2121

22+
import org.neo4j.gds.AlgorithmParameters;
2223
import org.neo4j.gds.annotation.Parameters;
2324
import org.neo4j.gds.core.concurrency.Concurrency;
2425

@@ -36,7 +37,7 @@ public record KnnParameters(
3637
int minBatchSize,
3738
KnnSampler.SamplerType samplerType,
3839
Optional<Long> randomSeed,
39-
List<KnnNodePropertySpec> nodePropertySpecs) {
40+
List<KnnNodePropertySpec> nodePropertySpecs) implements AlgorithmParameters {
4041

4142
static KnnParameters create(
4243
long nodeCount,

algo-params/similarity-params/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityParameters.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,13 @@
2020
package org.neo4j.gds.similarity.nodesim;
2121

2222
import org.jetbrains.annotations.Nullable;
23+
import org.neo4j.gds.AlgorithmParameters;
2324
import org.neo4j.gds.annotation.Parameters;
25+
import org.neo4j.gds.core.concurrency.Concurrency;
2426

2527
@Parameters
2628
public record NodeSimilarityParameters(
29+
Concurrency concurrency,
2730
MetricSimilarityComputer similarityComputer,
2831
int degreeCutoff,
2932
int upperDegreeCutoff,
@@ -33,7 +36,7 @@ public record NodeSimilarityParameters(
3336
boolean hasRelationshipWeightProperty,
3437
boolean useComponents,
3538
@Nullable String componentProperty
36-
) {
39+
) implements AlgorithmParameters {
3740
boolean hasTopK() {
3841
return normalizedK != 0;
3942
}

algo/src/main/java/org/neo4j/gds/SimilarityAlgorithmTasks.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,23 @@
2323
import org.neo4j.gds.core.utils.progress.tasks.Task;
2424
import org.neo4j.gds.similarity.filteredknn.FilteredKNNTask;
2525
import org.neo4j.gds.similarity.filteredknn.FilteredKnnParameters;
26+
import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityParameters;
2627
import org.neo4j.gds.similarity.knn.KnnParameters;
2728
import org.neo4j.gds.similarity.knn.KnnTask;
29+
import org.neo4j.gds.similarity.nodesim.FilteredNodeSimilarityTask;
2830

2931
public final class SimilarityAlgorithmTasks {
3032

31-
public Task FilteredKnn(Graph graph, FilteredKnnParameters parameters){
33+
public Task filteredKnn(Graph graph, FilteredKnnParameters parameters){
3234
return FilteredKNNTask.create(graph.nodeCount(), parameters);
3335
}
34-
public Task Knn(Graph graph, KnnParameters parameters){
36+
37+
public Task knn(Graph graph, KnnParameters parameters){
3538
return KnnTask.create(graph.nodeCount(), parameters);
3639
}
40+
41+
public Task filteredNodeSimilarity(Graph graph, FilteredNodeSimilarityParameters filteredNodeSimilarityParameters){
42+
return FilteredNodeSimilarityTask.create(graph,filteredNodeSimilarityParameters.nodeSimilarityParameters());
43+
}
44+
3745
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.nodesim;
21+
22+
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.core.utils.progress.tasks.Task;
24+
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
25+
import org.neo4j.gds.wcc.WccTask;
26+
27+
import static org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel.FilteredNodeSimilarity;
28+
29+
public class FilteredNodeSimilarityTask {
30+
31+
private FilteredNodeSimilarityTask() {}
32+
33+
public static Task create(Graph graph, NodeSimilarityParameters parameters) {
34+
return Tasks.task(
35+
FilteredNodeSimilarity.asString(),
36+
filteredNodeSimilarityProgressTask(graph, parameters.runWCC()),
37+
Tasks.leaf("compare node pairs")
38+
);
39+
}
40+
41+
private static Task filteredNodeSimilarityProgressTask(Graph graph, boolean runWcc) {
42+
if (runWcc) {
43+
return Tasks.task(
44+
"prepare",
45+
WccTask.create(graph),
46+
Tasks.leaf("initialize", graph.relationshipCount())
47+
);
48+
}
49+
return Tasks.leaf("prepare", graph.relationshipCount());
50+
}
51+
}

algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarity.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ public class NodeSimilarity extends Algorithm<NodeSimilarityResult> {
7979
public NodeSimilarity(
8080
Graph graph,
8181
NodeSimilarityParameters parameters,
82-
Concurrency concurrency,
8382
ExecutorService executorService,
8483
ProgressTracker progressTracker,
8584
NodeFilter sourceNodeFilter,
@@ -92,7 +91,7 @@ public NodeSimilarity(
9291
this.sortVectors = graph.schema().relationshipSchema().availableTypes().size() > 1;
9392
this.sourceNodeFilter = sourceNodeFilter;
9493
this.targetNodeFilter = targetNodeFilter;
95-
this.concurrency = concurrency;
94+
this.concurrency = parameters.concurrency();
9695
this.parameters = parameters;
9796
this.similarityComputer = parameters.similarityComputer();
9897
this.executorService = executorService;

algo/src/test/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityTest.java

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,22 @@
2222
import org.junit.jupiter.api.Test;
2323
import org.junit.jupiter.params.ParameterizedTest;
2424
import org.junit.jupiter.params.provider.ValueSource;
25-
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
26-
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
25+
import org.neo4j.gds.SimilarityAlgorithmTasks;
26+
import org.neo4j.gds.TestProgressTrackerHelper;
27+
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery;
2728
import org.neo4j.gds.applications.algorithms.similarity.SimilarityAlgorithms;
2829
import org.neo4j.gds.core.concurrency.Concurrency;
29-
import org.neo4j.gds.core.utils.logging.LoggerForProgressTrackingAdapter;
30-
import org.neo4j.gds.core.utils.progress.EmptyTaskRegistryFactory;
30+
import org.neo4j.gds.core.concurrency.DefaultPool;
3131
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
32-
import org.neo4j.gds.core.utils.warnings.EmptyUserLogRegistryFactory;
3332
import org.neo4j.gds.extension.GdlExtension;
3433
import org.neo4j.gds.extension.GdlGraph;
3534
import org.neo4j.gds.extension.Inject;
3635
import org.neo4j.gds.extension.TestGraph;
37-
import org.neo4j.gds.logging.GdsTestLog;
36+
import org.neo4j.gds.similarity.filtering.NodeFilter;
3837
import org.neo4j.gds.similarity.filtering.NodeFilterSpecFactory;
38+
import org.neo4j.gds.similarity.nodesim.NodeSimilarity;
3939
import org.neo4j.gds.termination.TerminationFlag;
40+
import org.neo4j.gds.wcc.WccStub;
4041

4142
import java.util.List;
4243
import java.util.stream.Collectors;
@@ -79,19 +80,19 @@ void should() {
7980

8081
var sourceNodeFilter = Stream.of("a", "b", "c").map(graph::toOriginalNodeId).collect(Collectors.toList());
8182

82-
var config = FilteredNodeSimilarityStreamConfigImpl.builder()
83+
var params = FilteredNodeSimilarityStreamConfigImpl.builder()
8384
.sourceNodeFilter(NodeFilterSpecFactory.create(sourceNodeFilter))
84-
.build();
85+
.build().toFilteredParameters();
8586

8687
// no results for nodes that are not specified in the node filter -- nice
87-
var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, config, ProgressTracker.NULL_TRACKER)
88+
var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, params, ProgressTracker.NULL_TRACKER)
8889
.streamResult()
8990
.filter(res -> !sourceNodeFilter.contains(graph.toOriginalNodeId(res.node1)))
9091
.count();
9192
assertThat(noOfResultsWithSourceNodeOutsideOfFilter).isEqualTo(0L);
9293

9394
// nodes outside of the node filter are not present as target nodes either -- not nice
94-
var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, config, ProgressTracker.NULL_TRACKER)
95+
var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, params, ProgressTracker.NULL_TRACKER)
9596
.streamResult()
9697
.filter(res -> !sourceNodeFilter.contains(graph.toOriginalNodeId(res.node2)))
9798
.count();
@@ -104,20 +105,21 @@ void shouldSurviveIoannisObjections() {
104105

105106
var sourceNodeFilter = List.of(graph.toOriginalNodeId("d"));
106107

107-
var config = FilteredNodeSimilarityStreamConfigImpl.builder()
108+
var params = FilteredNodeSimilarityStreamConfigImpl.builder()
108109
.sourceNodeFilter(NodeFilterSpecFactory.create(sourceNodeFilter))
109110
.concurrency(1)
110-
.build();
111+
.build()
112+
.toFilteredParameters();
111113

112114
// no results for nodes that are not specified in the node filter -- nice
113-
var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, config, ProgressTracker.NULL_TRACKER)
115+
var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, params, ProgressTracker.NULL_TRACKER)
114116
.streamResult()
115117
.filter(res -> !sourceNodeFilter.contains(graph.toOriginalNodeId(res.node1)))
116118
.count();
117119
assertThat(noOfResultsWithSourceNodeOutsideOfFilter).isEqualTo(0L);
118120

119121
// nodes outside of the node filter are not present as target nodes either -- not nice
120-
var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, config, ProgressTracker.NULL_TRACKER)
122+
var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, params, ProgressTracker.NULL_TRACKER)
121123
.streamResult()
122124
.filter(res -> !sourceNodeFilter.contains(graph.toOriginalNodeId(res.node2)))
123125
.count();
@@ -131,23 +133,24 @@ void shouldSurviveIoannisFurtherObjections(boolean enableWcc) {
131133

132134
var sourceNodeFilter = List.of(graph.toOriginalNodeId("d"));
133135

134-
var config = FilteredNodeSimilarityStreamConfigImpl.builder()
136+
var params = FilteredNodeSimilarityStreamConfigImpl.builder()
135137
.sourceNodeFilter(NodeFilterSpecFactory.create(sourceNodeFilter))
136138
.concurrency(1)
137139
.useComponents(enableWcc)
138140
.topK(1)
139141
.topN(10)
140-
.build();
142+
.build()
143+
.toFilteredParameters();
141144

142145
// no results for nodes that are not specified in the node filter -- nice
143-
var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, config, ProgressTracker.NULL_TRACKER)
146+
var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, params, ProgressTracker.NULL_TRACKER)
144147
.streamResult()
145148
.filter(res -> !sourceNodeFilter.contains(graph.toOriginalNodeId(res.node1)))
146149
.count();
147150
assertThat(noOfResultsWithSourceNodeOutsideOfFilter).isEqualTo(0L);
148151

149152
// nodes outside of the node filter are not present as target nodes either -- not nice
150-
var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, config, ProgressTracker.NULL_TRACKER)
153+
var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms.filteredNodeSimilarity(graph, params, ProgressTracker.NULL_TRACKER)
151154
.streamResult()
152155
.filter(res -> !sourceNodeFilter.contains(graph.toOriginalNodeId(res.node2)))
153156
.count();
@@ -157,24 +160,38 @@ void shouldSurviveIoannisFurtherObjections(boolean enableWcc) {
157160
@ParameterizedTest
158161
@ValueSource(ints = {1, 2})
159162
void shouldLogProgressAccurately(int concurrencyValue) {
160-
var log = new GdsTestLog();
161-
var requestScopedDependencies = RequestScopedDependencies.builder()
162-
.taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE)
163-
.terminationFlag(TerminationFlag.RUNNING_TRUE)
164-
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
165-
.build();
166-
var progressTrackerCreator = new ProgressTrackerCreator(new LoggerForProgressTrackingAdapter(log), requestScopedDependencies);
167-
var similarityAlgorithms = new SimilarityAlgorithms(progressTrackerCreator, requestScopedDependencies.terminationFlag());
163+
168164

169165
var sourceNodeFilter = List.of(graph.toOriginalNodeId("c"), graph.toOriginalNodeId("d"));
170166
var concurrency = new Concurrency(concurrencyValue);
171-
var config = FilteredNodeSimilarityStreamConfigImpl.builder()
167+
var params = FilteredNodeSimilarityStreamConfigImpl.builder()
172168
.sourceNodeFilter(NodeFilterSpecFactory.create(sourceNodeFilter))
173169
.concurrency(concurrency.value())
174170
.topK(1)
175171
.topN(10)
176-
.build();
177-
similarityAlgorithms.filteredNodeSimilarity(graph, config);
172+
.build()
173+
.toFilteredParameters();
174+
175+
var progressTrackerWithLog = TestProgressTrackerHelper.create(
176+
new SimilarityAlgorithmTasks().filteredNodeSimilarity(graph,params),
177+
new Concurrency(2)
178+
);
179+
180+
var progressTracker = progressTrackerWithLog.progressTracker();
181+
var log = progressTrackerWithLog.log();
182+
183+
var filteredNodeSimilarity = new NodeSimilarity(
184+
graph,
185+
params.nodeSimilarityParameters(),
186+
DefaultPool.INSTANCE,
187+
progressTracker,
188+
params.filteringParameters().sourceFilter().toNodeFilter(graph),
189+
NodeFilter.ALLOW_EVERYTHING,
190+
TerminationFlag.RUNNING_TRUE,
191+
new WccStub(TerminationFlag.RUNNING_TRUE, new AlgorithmMachinery())
192+
);
193+
194+
filteredNodeSimilarity.compute();
178195

179196
assertThat(log.getMessages(INFO))
180197
.extracting(removingThreadId())

algo/src/test/java/org/neo4j/gds/similarity/knn/KnnTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,7 @@ void shouldLogProgress() {
555555
);
556556

557557
var progressTrackerWithLog = TestProgressTrackerHelper.create(
558-
new SimilarityAlgorithmTasks().Knn(graph, knnParameters),
558+
new SimilarityAlgorithmTasks().knn(graph, knnParameters),
559559
new Concurrency(1)
560560
);
561561

0 commit comments

Comments
 (0)