Skip to content

Commit bfc0090

Browse files
Node Similarity
1 parent cf99422 commit bfc0090

File tree

7 files changed

+139
-94
lines changed

7 files changed

+139
-94
lines changed

algo/src/main/java/org/neo4j/gds/SimilarityAlgorithmTasks.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.neo4j.gds.similarity.knn.KnnParameters;
2828
import org.neo4j.gds.similarity.knn.KnnTask;
2929
import org.neo4j.gds.similarity.nodesim.FilteredNodeSimilarityTask;
30+
import org.neo4j.gds.similarity.nodesim.NodeSimilarityParameters;
31+
import org.neo4j.gds.similarity.nodesim.NodeSimilarityTask;
3032

3133
public final class SimilarityAlgorithmTasks {
3234

@@ -38,8 +40,12 @@ public Task knn(Graph graph, KnnParameters parameters){
3840
return KnnTask.create(graph.nodeCount(), parameters);
3941
}
4042

41-
public Task filteredNodeSimilarity(Graph graph, FilteredNodeSimilarityParameters filteredNodeSimilarityParameters){
42-
return FilteredNodeSimilarityTask.create(graph,filteredNodeSimilarityParameters.nodeSimilarityParameters());
43+
public Task filteredNodeSimilarity(Graph graph, FilteredNodeSimilarityParameters parameters){
44+
return FilteredNodeSimilarityTask.create(graph,parameters.nodeSimilarityParameters());
45+
}
46+
47+
/**
 * Builds the progress task for the Node Similarity algorithm by delegating
 * to {@code NodeSimilarityTask.create}.
 *
 * @param graph the graph handed on to the task factory
 * @param parameters the Node Similarity parameters handed on to the task factory
 * @return the task returned by {@code NodeSimilarityTask.create(graph, parameters)}
 */
public Task nodeSimilarity(Graph graph, NodeSimilarityParameters parameters){
48+
return NodeSimilarityTask.create(graph,parameters);
4349
}
4450

4551
}
algo/src/main/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTask.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.similarity.nodesim;
21+
22+
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel;
24+
import org.neo4j.gds.core.utils.progress.tasks.Task;
25+
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
26+
import org.neo4j.gds.wcc.WccTask;
27+
28+
/**
 * Static factory for the progress {@link Task} tree of the Node Similarity algorithm.
 * The class has a private constructor and exposes only the static {@code create} method.
 */
public final class NodeSimilarityTask {
29+
30+
// Private constructor: the class is used only through its static factory method.
private NodeSimilarityTask() {}
31+
32+
/**
 * Creates the task tree for a Node Similarity run.
 * <p>
 * The root task is named after {@code AlgorithmLabel.NodeSimilarity.asString()} and has two
 * children, in this order:
 * <ol>
 *   <li>a "prepare" step: when {@code parameters.runWCC()} is true it is a composite task made of
 *       {@code WccTask.create(graph)} followed by an "initialize" leaf; otherwise it is a single
 *       "prepare" leaf. Both leaves are given {@code graph.relationshipCount()} as second argument
 *       (presumably the leaf's work volume; confirm against the {@code Tasks.leaf} API);</li>
 *   <li>a "compare node pairs" leaf, created without that second argument.</li>
 * </ol>
 *
 * @param graph the graph whose relationship count sizes the preparation leaves and which is
 *              passed to {@code WccTask.create}
 * @param parameters the algorithm parameters; only {@code runWCC()} is read here
 * @return the root task described above
 */
public static Task create(Graph graph, NodeSimilarityParameters parameters) {
33+
return Tasks.task(
34+
AlgorithmLabel.NodeSimilarity.asString(),
35+
parameters.runWCC()
36+
? Tasks.task(
37+
"prepare",
38+
WccTask.create(graph),
39+
Tasks.leaf("initialize", graph.relationshipCount())
40+
)
41+
: Tasks.leaf("prepare", graph.relationshipCount()),
42+
Tasks.leaf("compare node pairs")
43+
);
44+
}
45+
46+
}

algo/src/test/java/org/neo4j/gds/similarity/filterednodesim/FilteredNodeSimilarityTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class FilteredNodeSimilarityTest {
7676

7777
@Test
7878
void should() {
79-
var similarityAlgorithms = new SimilarityAlgorithms(null, TerminationFlag.RUNNING_TRUE);
79+
var similarityAlgorithms = new SimilarityAlgorithms(TerminationFlag.RUNNING_TRUE);
8080

8181
var sourceNodeFilter = Stream.of("a", "b", "c").map(graph::toOriginalNodeId).collect(Collectors.toList());
8282

@@ -101,7 +101,7 @@ void should() {
101101

102102
@Test
103103
void shouldSurviveIoannisObjections() {
104-
var similarityAlgorithms = new SimilarityAlgorithms(null, TerminationFlag.RUNNING_TRUE);
104+
var similarityAlgorithms = new SimilarityAlgorithms(TerminationFlag.RUNNING_TRUE);
105105

106106
var sourceNodeFilter = List.of(graph.toOriginalNodeId("d"));
107107

@@ -129,7 +129,7 @@ void shouldSurviveIoannisObjections() {
129129
@ParameterizedTest
130130
@ValueSource(booleans = {true, false})
131131
void shouldSurviveIoannisFurtherObjections(boolean enableWcc) {
132-
var similarityAlgorithms = new SimilarityAlgorithms(null, TerminationFlag.RUNNING_TRUE);
132+
var similarityAlgorithms = new SimilarityAlgorithms(TerminationFlag.RUNNING_TRUE);
133133

134134
var sourceNodeFilter = List.of(graph.toOriginalNodeId("d"));
135135

algo/src/test/java/org/neo4j/gds/similarity/nodesim/NodeSimilarityTest.java

Lines changed: 72 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
import org.junit.jupiter.params.provider.MethodSource;
2626
import org.junit.jupiter.params.provider.ValueSource;
2727
import org.neo4j.gds.Orientation;
28-
import org.neo4j.gds.TestProgressTracker;
28+
import org.neo4j.gds.SimilarityAlgorithmTasks;
29+
import org.neo4j.gds.TestProgressTrackerHelper;
2930
import org.neo4j.gds.TestSupport;
3031
import org.neo4j.gds.api.Graph;
3132
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery;
3233
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
3334
import org.neo4j.gds.applications.algorithms.machinery.RequestScopedDependencies;
3435
import org.neo4j.gds.applications.algorithms.similarity.SimilarityAlgorithms;
36+
import org.neo4j.gds.applications.algorithms.similarity.SimilarityAlgorithmsBusinessFacade;
3537
import org.neo4j.gds.core.concurrency.Concurrency;
3638
import org.neo4j.gds.core.concurrency.DefaultPool;
3739
import org.neo4j.gds.core.utils.logging.LoggerForProgressTrackingAdapter;
@@ -784,16 +786,28 @@ void shouldIgnoreParallelEdges(Orientation orientation, int concurrency) {
784786

785787
@Test
786788
void shouldLogMessages() {
787-
var log = new GdsTestLog();
788-
var requestScopedDependencies = RequestScopedDependencies.builder()
789-
.taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE)
790-
.terminationFlag(TerminationFlag.RUNNING_TRUE)
791-
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
792-
.build();
793-
var progressTrackerCreator = new ProgressTrackerCreator(new LoggerForProgressTrackingAdapter(log), requestScopedDependencies);
794-
var similarityAlgorithms = new SimilarityAlgorithms(progressTrackerCreator, requestScopedDependencies.terminationFlag());
795789

796-
similarityAlgorithms.nodeSimilarity(naturalGraph, NodeSimilarityBaseConfigImpl.builder().build());
790+
var params = NodeSimilarityBaseConfigImpl.builder().build().toParameters();
791+
var progressTrackerWithLog = TestProgressTrackerHelper.create(
792+
new SimilarityAlgorithmTasks().nodeSimilarity(naturalGraph, params),
793+
new Concurrency(2)
794+
);
795+
796+
var progressTracker = progressTrackerWithLog.progressTracker();
797+
var log = progressTrackerWithLog.log();
798+
799+
var nodeSimilarity = new NodeSimilarity(
800+
naturalGraph,
801+
params,
802+
DefaultPool.INSTANCE,
803+
progressTracker,
804+
NodeFilter.ALLOW_EVERYTHING,
805+
NodeFilter.ALLOW_EVERYTHING,
806+
TerminationFlag.RUNNING_TRUE,
807+
new WccStub(TerminationFlag.RUNNING_TRUE,new AlgorithmMachinery())
808+
);
809+
810+
nodeSimilarity.compute();
797811

798812
assertThat(log.getMessages(INFO))
799813
.extracting(removingThreadId())
@@ -814,9 +828,11 @@ void shouldNotLogMessagesWhenLoggingIsDisabled() {
814828
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
815829
.build();
816830
var progressTrackerCreator = new ProgressTrackerCreator(new LoggerForProgressTrackingAdapter(log), requestScopedDependencies);
817-
var similarityAlgorithms = new SimilarityAlgorithms(progressTrackerCreator, requestScopedDependencies.terminationFlag());
831+
var similarityAlgorithms = new SimilarityAlgorithms(requestScopedDependencies.terminationFlag());
818832

819-
similarityAlgorithms.nodeSimilarity(naturalGraph, NodeSimilarityBaseConfigImpl.builder().logProgress(false).build());
833+
var similarityBusiness = new SimilarityAlgorithmsBusinessFacade(similarityAlgorithms,progressTrackerCreator);
834+
835+
similarityBusiness.nodeSimilarity(naturalGraph, NodeSimilarityBaseConfigImpl.builder().logProgress(false).build());
820836

821837
assertThat(log.getMessages(INFO))
822838
.as("When progress logging is disabled we only log `start` and `finished`.")
@@ -834,22 +850,27 @@ void shouldNotLogMessagesWhenLoggingIsDisabled() {
834850
@ParameterizedTest(name = "concurrency = {0}")
835851
@ValueSource(ints = {1, 2})
836852
void shouldLogProgress(int concurrencyValue) {
837-
var log = new GdsTestLog();
838-
var requestScopedDependencies = RequestScopedDependencies.builder()
839-
.taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE)
840-
.terminationFlag(TerminationFlag.RUNNING_TRUE)
841-
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
842-
.build();
843-
var similarityAlgorithms = new SimilarityAlgorithms(null, requestScopedDependencies.terminationFlag());
844-
845-
var configuration = NodeSimilarityStreamConfigImpl.builder().build();
846-
var progressTracker = new TestProgressTracker(
847-
similarityAlgorithms.constructNodeSimilarityTask(naturalGraph, configuration),
848-
new LoggerForProgressTrackingAdapter(log),
849-
new Concurrency(concurrencyValue),
850-
EmptyTaskRegistryFactory.INSTANCE
853+
var params = NodeSimilarityBaseConfigImpl.builder().build().toParameters();
854+
var progressTrackerWithLog = TestProgressTrackerHelper.create(
855+
new SimilarityAlgorithmTasks().nodeSimilarity(naturalGraph, params),
856+
new Concurrency(concurrencyValue)
851857
);
852-
similarityAlgorithms.nodeSimilarity(naturalGraph, configuration, progressTracker).streamResult().count();
858+
859+
var progressTracker = progressTrackerWithLog.progressTracker();
860+
var log = progressTrackerWithLog.log();
861+
862+
var nodeSimilarity = new NodeSimilarity(
863+
naturalGraph,
864+
params,
865+
DefaultPool.INSTANCE,
866+
progressTracker,
867+
NodeFilter.ALLOW_EVERYTHING,
868+
NodeFilter.ALLOW_EVERYTHING,
869+
TerminationFlag.RUNNING_TRUE,
870+
new WccStub(TerminationFlag.RUNNING_TRUE,new AlgorithmMachinery())
871+
);
872+
873+
nodeSimilarity.compute();
853874

854875
var progresses = progressTracker.getProgresses();
855876

@@ -869,24 +890,33 @@ void shouldLogProgress(int concurrencyValue) {
869890

870891
@Test
871892
void shouldLogProgressForWccOptimization() {
872-
var log = new GdsTestLog();
873-
var requestScopedDependencies = RequestScopedDependencies.builder()
874-
.taskRegistryFactory(EmptyTaskRegistryFactory.INSTANCE)
875-
.terminationFlag(TerminationFlag.RUNNING_TRUE)
876-
.userLogRegistryFactory(EmptyUserLogRegistryFactory.INSTANCE)
877-
.build();
878-
var similarityAlgorithms = new SimilarityAlgorithms(null, requestScopedDependencies.terminationFlag());
879893

880-
var configuration = NodeSimilarityStreamConfigImpl.builder()
894+
var params = NodeSimilarityBaseConfigImpl
895+
.builder()
881896
.useComponents(true)
882-
.build();
883-
var progressTracker = new TestProgressTracker(
884-
similarityAlgorithms.constructNodeSimilarityTask(naturalGraph, configuration),
885-
new LoggerForProgressTrackingAdapter(log),
886-
new Concurrency(4),
887-
EmptyTaskRegistryFactory.INSTANCE
897+
.build()
898+
.toParameters();
899+
900+
var progressTrackerWithLog = TestProgressTrackerHelper.create(
901+
new SimilarityAlgorithmTasks().nodeSimilarity(naturalGraph, params),
902+
new Concurrency(4)
903+
);
904+
905+
var progressTracker = progressTrackerWithLog.progressTracker();
906+
var log = progressTrackerWithLog.log();
907+
908+
var nodeSimilarity = new NodeSimilarity(
909+
naturalGraph,
910+
params,
911+
DefaultPool.INSTANCE,
912+
progressTracker,
913+
NodeFilter.ALLOW_EVERYTHING,
914+
NodeFilter.ALLOW_EVERYTHING,
915+
TerminationFlag.RUNNING_TRUE,
916+
new WccStub(TerminationFlag.RUNNING_TRUE,new AlgorithmMachinery())
888917
);
889-
similarityAlgorithms.nodeSimilarity(naturalGraph, configuration, progressTracker).streamResult().count();
918+
919+
nodeSimilarity.compute();
890920

891921
var progresses = progressTracker.getProgresses();
892922

applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithms.java

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,9 @@
2020
package org.neo4j.gds.applications.algorithms.similarity;
2121

2222
import org.neo4j.gds.api.Graph;
23-
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmLabel;
2423
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery;
25-
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
2624
import org.neo4j.gds.core.concurrency.DefaultPool;
2725
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
28-
import org.neo4j.gds.core.utils.progress.tasks.Task;
29-
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
3026
import org.neo4j.gds.similarity.filteredknn.FilteredKnn;
3127
import org.neo4j.gds.similarity.filteredknn.FilteredKnnParameters;
3228
import org.neo4j.gds.similarity.filteredknn.FilteredKnnResult;
@@ -41,22 +37,19 @@
4137
import org.neo4j.gds.similarity.knn.SimilarityFunction;
4238
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;
4339
import org.neo4j.gds.similarity.nodesim.NodeSimilarity;
44-
import org.neo4j.gds.similarity.nodesim.NodeSimilarityBaseConfig;
40+
import org.neo4j.gds.similarity.nodesim.NodeSimilarityParameters;
4541
import org.neo4j.gds.similarity.nodesim.NodeSimilarityResult;
4642
import org.neo4j.gds.termination.TerminationFlag;
4743
import org.neo4j.gds.wcc.WccStub;
48-
import org.neo4j.gds.wcc.WccTask;
4944

5045
import java.util.Optional;
5146

5247
public class SimilarityAlgorithms {
5348
private final AlgorithmMachinery algorithmMachinery = new AlgorithmMachinery();
5449

55-
private final ProgressTrackerCreator progressTrackerCreator;
5650
private final TerminationFlag terminationFlag;
5751

58-
public SimilarityAlgorithms(ProgressTrackerCreator progressTrackerCreator, TerminationFlag terminationFlag) {
59-
this.progressTrackerCreator = progressTrackerCreator;
52+
public SimilarityAlgorithms(TerminationFlag terminationFlag) {
6053
this.terminationFlag = terminationFlag;
6154
}
6255

@@ -136,43 +129,18 @@ KnnResult knn(Graph graph, KnnParameters parameters, ProgressTracker progressTra
136129
);
137130
}
138131

139-
public NodeSimilarityResult nodeSimilarity(Graph graph, NodeSimilarityBaseConfig configuration) {
140-
var task = constructNodeSimilarityTask(graph, configuration);
141132

142-
var progressTracker = progressTrackerCreator.createProgressTracker(
143-
task,
144-
configuration.jobId(),
145-
configuration.concurrency(),
146-
configuration.logProgress()
147-
);
148-
149-
return nodeSimilarity(graph, configuration, progressTracker);
150-
}
151-
152-
public Task constructNodeSimilarityTask(Graph graph, NodeSimilarityBaseConfig configuration) {
153-
return Tasks.task(
154-
AlgorithmLabel.NodeSimilarity.asString(),
155-
configuration.useComponents().computeComponents()
156-
? Tasks.task(
157-
"prepare",
158-
WccTask.create(graph),
159-
Tasks.leaf("initialize", graph.relationshipCount())
160-
)
161-
: Tasks.leaf("prepare", graph.relationshipCount()),
162-
Tasks.leaf("compare node pairs")
163-
);
164-
}
165133

166134
public NodeSimilarityResult nodeSimilarity(
167135
Graph graph,
168-
NodeSimilarityBaseConfig configuration,
136+
NodeSimilarityParameters parameters,
169137
ProgressTracker progressTracker
170138
) {
171139
var wccStub = new WccStub(terminationFlag, algorithmMachinery);
172140

173141
var algorithm = new NodeSimilarity(
174142
graph,
175-
configuration.toParameters(),
143+
parameters,
176144
DefaultPool.INSTANCE,
177145
progressTracker,
178146
NodeFilter.ALLOW_EVERYTHING,
@@ -185,7 +153,7 @@ public NodeSimilarityResult nodeSimilarity(
185153
algorithm,
186154
progressTracker,
187155
true,
188-
configuration.concurrency()
156+
parameters.concurrency()
189157
);
190158
}
191159

applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityAlgorithmsBusinessFacade.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.neo4j.gds.SimilarityAlgorithmTasks;
2323
import org.neo4j.gds.api.Graph;
2424
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
25-
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2625
import org.neo4j.gds.similarity.filteredknn.FilteredKnnBaseConfig;
2726
import org.neo4j.gds.similarity.filteredknn.FilteredKnnResult;
2827
import org.neo4j.gds.similarity.filterednodesim.FilteredNodeSimilarityBaseConfig;
@@ -69,15 +68,11 @@ KnnResult knn(Graph graph, KnnBaseConfig configuration) {
6968
}
7069

7170
public NodeSimilarityResult nodeSimilarity(Graph graph, NodeSimilarityBaseConfig configuration) {
72-
return similarityAlgorithms.nodeSimilarity(graph,configuration);
71+
var parameters = configuration.toParameters();
72+
var task = tasks.nodeSimilarity(graph, parameters);
73+
var progressTracker = progressTrackerCreator.createProgressTracker(task,configuration);
74+
return similarityAlgorithms.nodeSimilarity(graph, parameters, progressTracker);
7375
}
7476

75-
public NodeSimilarityResult nodeSimilarity(
76-
Graph graph,
77-
NodeSimilarityBaseConfig configuration,
78-
ProgressTracker progressTracker
79-
) {
80-
return similarityAlgorithms.nodeSimilarity(graph,configuration,progressTracker);
81-
}
8277

8378
}

applications/algorithms/similarity/src/main/java/org/neo4j/gds/applications/algorithms/similarity/SimilarityApplications.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public static SimilarityApplications create(
5959
WriteContext writeContext
6060
) {
6161
var estimationModeFacade = new SimilarityAlgorithmsEstimationModeBusinessFacade(algorithmEstimationTemplate);
62-
var similarityAlgorithms = new SimilarityAlgorithms(progressTrackerCreator, requestScopedDependencies.terminationFlag());
62+
var similarityAlgorithms = new SimilarityAlgorithms(requestScopedDependencies.terminationFlag());
6363

6464
var businessFacade = new SimilarityAlgorithmsBusinessFacade(
6565
similarityAlgorithms,

0 commit comments

Comments
 (0)