Skip to content

Commit b5f2102

Browse files
adamnschbreakanalysis
authored andcommitted
Make root progress available and sane for NC predict
1 parent 1344961 commit b5f2102

File tree

4 files changed

+7
-8
lines changed

4 files changed

+7
-8
lines changed

ml/ml-algo/src/main/java/org/neo4j/gds/ml/nodeClassification/NodeClassificationPredict.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.jetbrains.annotations.Nullable;
2323
import org.neo4j.gds.Algorithm;
2424
import org.neo4j.gds.annotation.ValueClass;
25-
import org.neo4j.gds.api.Graph;
2625
import org.neo4j.gds.core.utils.TerminationFlag;
2726
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
2827
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
@@ -74,8 +73,8 @@ public NodeClassificationPredict(
7473
);
7574
}
7675

77-
public static Task progressTask(Graph graph) {
78-
return Tasks.leaf("Node classification predict", graph.nodeCount());
76+
public static Task progressTask(long nodeCount) {
77+
return Tasks.leaf("Node classification predict", nodeCount);
7978
}
8079

8180
public static MemoryEstimation memoryEstimation(
@@ -138,6 +137,7 @@ public static MemoryEstimation memoryEstimationWithDerivedBatchSize(
138137
@Override
139138
public NodeClassificationResult compute() {
140139
progressTracker.beginSubTask();
140+
progressTracker.setSteps(features.size());
141141
var predictedProbabilities = initProbabilities();
142142
var predictedClasses = predictor.predict(predictedProbabilities);
143143
progressTracker.endSubTask();

ml/ml-algo/src/test/java/org/neo4j/gds/ml/nodeClassification/NodeClassificationPredictTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ void shouldLogProgress() {
222222

223223
var log = Neo4jProxy.testLog();
224224
var progressTracker = new TaskProgressTracker(
225-
NodeClassificationPredict.progressTask(graph),
225+
NodeClassificationPredict.progressTask(graph.nodeCount()),
226226
log,
227227
1,
228228
new JobId(),

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineAlgorithmFactory.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ public Task progressTask(GraphStore graphStore, CONFIG config) {
5252
modelCatalog,
5353
config.modelName(),
5454
config.username()
55-
).customInfo()
56-
.pipeline();
55+
).customInfo().pipeline();
5756

5857
return NodeClassificationPredictPipelineExecutor.progressTask(taskName(), trainingPipeline, graphStore);
5958
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineExecutor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ public static Task progressTask(String taskName, NodePropertyPredictPipeline pip
7171
taskName,
7272
Tasks.iterativeFixed(
7373
"Execute node property steps",
74-
() -> List.of(Tasks.leaf("Step")),
74+
() -> List.of(Tasks.leaf("Step", 10 * graphStore.getUnion().nodeCount())),
7575
pipeline.nodePropertySteps().size()
7676
),
77-
Tasks.leaf("Node classification predict", graphStore.getUnion().nodeCount())
77+
NodeClassificationPredict.progressTask(graphStore.getUnion().nodeCount())
7878
);
7979
}
8080

0 commit comments

Comments
 (0)