22
22
import org .junit .jupiter .api .Test ;
23
23
import org .junit .jupiter .params .ParameterizedTest ;
24
24
import org .junit .jupiter .params .provider .ValueSource ;
25
- import org .neo4j .gds .applications .algorithms .machinery .ProgressTrackerCreator ;
26
- import org .neo4j .gds .applications .algorithms .machinery .RequestScopedDependencies ;
25
+ import org .neo4j .gds .SimilarityAlgorithmTasks ;
26
+ import org .neo4j .gds .TestProgressTrackerHelper ;
27
+ import org .neo4j .gds .applications .algorithms .machinery .AlgorithmMachinery ;
27
28
import org .neo4j .gds .applications .algorithms .similarity .SimilarityAlgorithms ;
28
29
import org .neo4j .gds .core .concurrency .Concurrency ;
29
- import org .neo4j .gds .core .utils .logging .LoggerForProgressTrackingAdapter ;
30
- import org .neo4j .gds .core .utils .progress .EmptyTaskRegistryFactory ;
30
+ import org .neo4j .gds .core .concurrency .DefaultPool ;
31
31
import org .neo4j .gds .core .utils .progress .tasks .ProgressTracker ;
32
- import org .neo4j .gds .core .utils .warnings .EmptyUserLogRegistryFactory ;
33
32
import org .neo4j .gds .extension .GdlExtension ;
34
33
import org .neo4j .gds .extension .GdlGraph ;
35
34
import org .neo4j .gds .extension .Inject ;
36
35
import org .neo4j .gds .extension .TestGraph ;
37
- import org .neo4j .gds .logging . GdsTestLog ;
36
+ import org .neo4j .gds .similarity . filtering . NodeFilter ;
38
37
import org .neo4j .gds .similarity .filtering .NodeFilterSpecFactory ;
38
+ import org .neo4j .gds .similarity .nodesim .NodeSimilarity ;
39
39
import org .neo4j .gds .termination .TerminationFlag ;
40
+ import org .neo4j .gds .wcc .WccStub ;
40
41
41
42
import java .util .List ;
42
43
import java .util .stream .Collectors ;
@@ -79,19 +80,19 @@ void should() {
79
80
80
81
var sourceNodeFilter = Stream .of ("a" , "b" , "c" ).map (graph ::toOriginalNodeId ).collect (Collectors .toList ());
81
82
82
- var config = FilteredNodeSimilarityStreamConfigImpl .builder ()
83
+ var params = FilteredNodeSimilarityStreamConfigImpl .builder ()
83
84
.sourceNodeFilter (NodeFilterSpecFactory .create (sourceNodeFilter ))
84
- .build ();
85
+ .build (). toFilteredParameters () ;
85
86
86
87
// no results for nodes that are not specified in the node filter -- nice
87
- var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
88
+ var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , params , ProgressTracker .NULL_TRACKER )
88
89
.streamResult ()
89
90
.filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node1 )))
90
91
.count ();
91
92
assertThat (noOfResultsWithSourceNodeOutsideOfFilter ).isEqualTo (0L );
92
93
93
94
// nodes outside of the node filter are not present as target nodes either -- not nice
94
- var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
95
+ var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , params , ProgressTracker .NULL_TRACKER )
95
96
.streamResult ()
96
97
.filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node2 )))
97
98
.count ();
@@ -104,20 +105,21 @@ void shouldSurviveIoannisObjections() {
104
105
105
106
var sourceNodeFilter = List .of (graph .toOriginalNodeId ("d" ));
106
107
107
- var config = FilteredNodeSimilarityStreamConfigImpl .builder ()
108
+ var params = FilteredNodeSimilarityStreamConfigImpl .builder ()
108
109
.sourceNodeFilter (NodeFilterSpecFactory .create (sourceNodeFilter ))
109
110
.concurrency (1 )
110
- .build ();
111
+ .build ()
112
+ .toFilteredParameters ();
111
113
112
114
// no results for nodes that are not specified in the node filter -- nice
113
- var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
115
+ var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , params , ProgressTracker .NULL_TRACKER )
114
116
.streamResult ()
115
117
.filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node1 )))
116
118
.count ();
117
119
assertThat (noOfResultsWithSourceNodeOutsideOfFilter ).isEqualTo (0L );
118
120
119
121
// nodes outside of the node filter are not present as target nodes either -- not nice
120
- var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
122
+ var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , params , ProgressTracker .NULL_TRACKER )
121
123
.streamResult ()
122
124
.filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node2 )))
123
125
.count ();
@@ -131,23 +133,24 @@ void shouldSurviveIoannisFurtherObjections(boolean enableWcc) {
131
133
132
134
var sourceNodeFilter = List .of (graph .toOriginalNodeId ("d" ));
133
135
134
- var config = FilteredNodeSimilarityStreamConfigImpl .builder ()
136
+ var params = FilteredNodeSimilarityStreamConfigImpl .builder ()
135
137
.sourceNodeFilter (NodeFilterSpecFactory .create (sourceNodeFilter ))
136
138
.concurrency (1 )
137
139
.useComponents (enableWcc )
138
140
.topK (1 )
139
141
.topN (10 )
140
- .build ();
142
+ .build ()
143
+ .toFilteredParameters ();
141
144
142
145
// no results for nodes that are not specified in the node filter -- nice
143
- var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
146
+ var noOfResultsWithSourceNodeOutsideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , params , ProgressTracker .NULL_TRACKER )
144
147
.streamResult ()
145
148
.filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node1 )))
146
149
.count ();
147
150
assertThat (noOfResultsWithSourceNodeOutsideOfFilter ).isEqualTo (0L );
148
151
149
152
// nodes outside of the node filter are not present as target nodes either -- not nice
150
- var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , config , ProgressTracker .NULL_TRACKER )
153
+ var noOfResultsWithTargetNodeOutSideOfFilter = similarityAlgorithms .filteredNodeSimilarity (graph , params , ProgressTracker .NULL_TRACKER )
151
154
.streamResult ()
152
155
.filter (res -> !sourceNodeFilter .contains (graph .toOriginalNodeId (res .node2 )))
153
156
.count ();
@@ -157,24 +160,38 @@ void shouldSurviveIoannisFurtherObjections(boolean enableWcc) {
157
160
@ ParameterizedTest
158
161
@ ValueSource (ints = {1 , 2 })
159
162
void shouldLogProgressAccurately (int concurrencyValue ) {
160
- var log = new GdsTestLog ();
161
- var requestScopedDependencies = RequestScopedDependencies .builder ()
162
- .taskRegistryFactory (EmptyTaskRegistryFactory .INSTANCE )
163
- .terminationFlag (TerminationFlag .RUNNING_TRUE )
164
- .userLogRegistryFactory (EmptyUserLogRegistryFactory .INSTANCE )
165
- .build ();
166
- var progressTrackerCreator = new ProgressTrackerCreator (new LoggerForProgressTrackingAdapter (log ), requestScopedDependencies );
167
- var similarityAlgorithms = new SimilarityAlgorithms (progressTrackerCreator , requestScopedDependencies .terminationFlag ());
163
+
168
164
169
165
var sourceNodeFilter = List .of (graph .toOriginalNodeId ("c" ), graph .toOriginalNodeId ("d" ));
170
166
var concurrency = new Concurrency (concurrencyValue );
171
- var config = FilteredNodeSimilarityStreamConfigImpl .builder ()
167
+ var params = FilteredNodeSimilarityStreamConfigImpl .builder ()
172
168
.sourceNodeFilter (NodeFilterSpecFactory .create (sourceNodeFilter ))
173
169
.concurrency (concurrency .value ())
174
170
.topK (1 )
175
171
.topN (10 )
176
- .build ();
177
- similarityAlgorithms .filteredNodeSimilarity (graph , config );
172
+ .build ()
173
+ .toFilteredParameters ();
174
+
175
+ var progressTrackerWithLog = TestProgressTrackerHelper .create (
176
+ new SimilarityAlgorithmTasks ().filteredNodeSimilarity (graph ,params ),
177
+ new Concurrency (2 )
178
+ );
179
+
180
+ var progressTracker = progressTrackerWithLog .progressTracker ();
181
+ var log = progressTrackerWithLog .log ();
182
+
183
+ var filteredNodeSimilarity = new NodeSimilarity (
184
+ graph ,
185
+ params .nodeSimilarityParameters (),
186
+ DefaultPool .INSTANCE ,
187
+ progressTracker ,
188
+ params .filteringParameters ().sourceFilter ().toNodeFilter (graph ),
189
+ NodeFilter .ALLOW_EVERYTHING ,
190
+ TerminationFlag .RUNNING_TRUE ,
191
+ new WccStub (TerminationFlag .RUNNING_TRUE , new AlgorithmMachinery ())
192
+ );
193
+
194
+ filteredNodeSimilarity .compute ();
178
195
179
196
assertThat (log .getMessages (INFO ))
180
197
.extracting (removingThreadId ())
0 commit comments