Skip to content

Commit 92a6c8a

Browse files
authored
Merge pull request #9751 from lassewesth/lpppp4
migrate link prediction stream estimate
2 parents 56d6bf5 + 82041f8 commit 92a6c8a

File tree

7 files changed

+132
-29
lines changed

7 files changed

+132
-29
lines changed

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamProc.java

Lines changed: 6 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -19,12 +19,9 @@
1919
*/
2020
package org.neo4j.gds.ml.linkmodels.pipeline.predict;
2121

22-
import org.neo4j.gds.BaseProc;
23-
import org.neo4j.gds.core.model.ModelCatalog;
24-
import org.neo4j.gds.executor.ExecutionContext;
25-
import org.neo4j.gds.executor.MemoryEstimationExecutor;
26-
import org.neo4j.gds.executor.ProcedureExecutor;
2722
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
23+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
24+
import org.neo4j.gds.procedures.pipelines.StreamResult;
2825
import org.neo4j.procedure.Context;
2926
import org.neo4j.procedure.Description;
3027
import org.neo4j.procedure.Mode;
@@ -36,24 +33,18 @@
3633

3734
import static org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion.ESTIMATE_PREDICT_DESCRIPTION;
3835
import static org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion.PREDICT_DESCRIPTION;
39-
import static org.neo4j.gds.ml.pipeline.PipelineCompanion.preparePipelineConfig;
40-
41-
public class LinkPredictionPipelineStreamProc extends BaseProc {
4236

37+
public class LinkPredictionPipelineStreamProc {
4338
@Context
44-
public ModelCatalog modelCatalog;
39+
public GraphDataScienceProcedures facade;
4540

4641
@Procedure(name = "gds.beta.pipeline.linkPrediction.predict.stream", mode = Mode.READ)
4742
@Description(PREDICT_DESCRIPTION)
4843
public Stream<StreamResult> stream(
4944
@Name(value = "graphName") String graphName,
5045
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
5146
) {
52-
preparePipelineConfig(graphName, configuration);
53-
return new ProcedureExecutor<>(
54-
new LinkPredictionPipelineStreamSpec(),
55-
executionContext()
56-
).compute(graphName, configuration);
47+
return facade.pipelines().linkPrediction().stream(graphName, configuration);
5748
}
5849

5950
@Procedure(name = "gds.beta.pipeline.linkPrediction.predict.stream.estimate", mode = Mode.READ)
@@ -62,17 +53,6 @@ public Stream<MemoryEstimateResult> estimate(
6253
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
6354
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
6455
) {
65-
preparePipelineConfig(graphNameOrConfiguration, algoConfiguration);
66-
return new MemoryEstimationExecutor<>(
67-
new LinkPredictionPipelineStreamSpec(),
68-
executionContext(),
69-
transactionContext()
70-
).computeEstimate(graphNameOrConfiguration, algoConfiguration);
71-
}
72-
73-
@Override
74-
public ExecutionContext executionContext() {
75-
return super.executionContext().withModelCatalog(modelCatalog);
56+
return facade.pipelines().linkPrediction().streamEstimate(graphNameOrConfiguration, algoConfiguration);
7657
}
77-
7858
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamSpec.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
2929
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineExecutor;
3030
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineStreamConfig;
31+
import org.neo4j.gds.procedures.pipelines.StreamResult;
3132

3233
import java.util.Collection;
3334
import java.util.stream.Stream;

procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionFacade.java

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -169,4 +169,31 @@ public Stream<MemoryEstimateResult> mutateEstimate(
169169

170170
return Stream.of(result);
171171
}
172+
173+
public Stream<StreamResult> stream(
174+
String graphNameAsString,
175+
Map<String, Object> configuration
176+
) {
177+
PipelineCompanion.preparePipelineConfig(graphNameAsString, configuration);
178+
179+
var graphName = GraphName.parse(graphNameAsString);
180+
181+
return pipelineApplications.linkPredictionStream(graphName, configuration);
182+
}
183+
184+
public Stream<MemoryEstimateResult> streamEstimate(
185+
Object graphNameOrConfiguration,
186+
Map<String, Object> rawConfiguration
187+
) {
188+
PipelineCompanion.preparePipelineConfig(graphNameOrConfiguration, rawConfiguration);
189+
190+
var configuration = pipelineConfigurationParser.parseLinkPredictionPredictPipelineStreamConfig(rawConfiguration);
191+
192+
var result = pipelineApplications.linkPredictionEstimate(
193+
graphNameOrConfiguration,
194+
configuration
195+
);
196+
197+
return Stream.of(result);
198+
}
172199
}
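procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LinkPredictionPipelineStreamResultBuilder.java

Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change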
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.pipelines;
21+
22+
import org.neo4j.gds.api.Graph;
23+
import org.neo4j.gds.api.GraphStore;
24+
import org.neo4j.gds.applications.algorithms.machinery.StreamResultBuilder;
25+
import org.neo4j.gds.logging.Log;
26+
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
27+
28+
import java.util.Optional;
29+
import java.util.stream.Stream;
30+
31+
class LinkPredictionPipelineStreamResultBuilder implements StreamResultBuilder<LinkPredictionResult, StreamResult> {
32+
private final Log log;
33+
private final TrainedLPPipelineModel trainedLPPipelineModel;
34+
private final LinkPredictionPredictPipelineStreamConfig configuration;
35+
36+
public LinkPredictionPipelineStreamResultBuilder(
37+
Log log,
38+
TrainedLPPipelineModel trainedLPPipelineModel,
39+
LinkPredictionPredictPipelineStreamConfig configuration
40+
) {
41+
this.log = log;
42+
this.trainedLPPipelineModel = trainedLPPipelineModel;
43+
this.configuration = configuration;
44+
}
45+
46+
@Override
47+
public Stream<StreamResult> build(
48+
Graph graph,
49+
GraphStore graphStore,
50+
Optional<LinkPredictionResult> result
51+
) {
52+
if (result.isEmpty()) return Stream.empty();
53+
54+
var linkPredictionResult = result.get();
55+
56+
var model = trainedLPPipelineModel.get(
57+
configuration.modelName(),
58+
configuration.username()
59+
);
60+
61+
var lpGraphStoreFilter = LPGraphStoreFilterFactory.generate(
62+
log, model.trainConfig(),
63+
configuration,
64+
graphStore
65+
);
66+
67+
var filteredGraph = graphStore.getGraph(lpGraphStoreFilter.predictNodeLabels());
68+
69+
return linkPredictionResult.stream().map(predictedLink -> new StreamResult(
70+
filteredGraph.toOriginalNodeId(predictedLink.sourceId()),
71+
filteredGraph.toOriginalNodeId(predictedLink.targetId()),
72+
predictedLink.probability()
73+
));
74+
}
75+
}

procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineApplications.java

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -389,6 +389,24 @@ MutateResult linkPredictionMutate(GraphName graphName, Map<String, Object> rawCo
389389
);
390390
}
391391

392+
Stream<StreamResult> linkPredictionStream(GraphName graphName, Map<String, Object> rawConfiguration) {
393+
var configuration = pipelineConfigurationParser.parseLinkPredictionPredictPipelineStreamConfig(rawConfiguration);
394+
var label = new StandardLabel("LinkPredictionPipelineStream");
395+
var computation = constructLinkPredictionComputation(configuration, label);
396+
var resultBuilder = new LinkPredictionPipelineStreamResultBuilder(log, trainedLPPipelineModel, configuration);
397+
398+
return algorithmProcessingTemplate.processAlgorithmForStream(
399+
Optional.empty(),
400+
graphName,
401+
configuration,
402+
Optional.empty(),
403+
label,
404+
() -> linkPredictionMemoryEstimation(configuration),
405+
computation,
406+
resultBuilder
407+
);
408+
}
409+
392410
MemoryEstimateResult nodeClassificationPredictEstimate(
393411
Object graphNameOrConfiguration,
394412
NodeClassificationPredictPipelineBaseConfig configuration

procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/PipelineConfigurationParser.java

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -72,6 +72,10 @@ LinkPredictionSplitConfig parseLinkPredictionSplitConfig(Map<String, Object> raw
7272
);
7373
}
7474

75+
LinkPredictionPredictPipelineStreamConfig parseLinkPredictionPredictPipelineStreamConfig(Map<String, Object> configuration) {
76+
return parseConfiguration(LinkPredictionPredictPipelineStreamConfig::of, configuration);
77+
}
78+
7579
TunableTrainerConfig parseLogisticRegressionTrainerConfig(Map<String, Object> configuration) {
7680
return parseTrainerConfiguration(
7781
configuration,
Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,9 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.ml.linkmodels.pipeline.predict;
20+
package org.neo4j.gds.procedures.pipelines;
2121

22-
@SuppressWarnings("unused")
2322
public final class StreamResult {
24-
2523
public final long node1;
2624
public final long node2;
2725
public final double probability;

0 commit comments

Comments
 (0)