Skip to content

Commit 1344961

Browse files
Make root progress available and sane for NC train
Co-Authored-By: Jacob Sznajdman <[email protected]>
1 parent c2639bc commit 1344961

File tree

8 files changed

+37
-28
lines changed

8 files changed

+37
-28
lines changed

ml/ml-algo/src/main/java/org/neo4j/gds/ml/nodeClassification/NodeClassificationPredictConsumer.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ public void accept(Batch batch) {
8888
currentRow++;
8989
}
9090

91-
progressTracker.logProgress(batch.size());
91+
progressTracker.logSteps(batch.size());
9292
}
9393

9494
}

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrain.java

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -146,17 +146,21 @@ public static MemoryEstimation estimate(
146146
return builder.build();
147147
}
148148

149-
public static List<Task> progressTasks(int validationFolds, int numberOfModelSelectionTrials) {
149+
public static List<Task> progressTasks(NodePropertyPredictionSplitConfig splitConfig, int numberOfModelSelectionTrials, long nodeCount) {
150+
long trainSetSize = splitConfig.trainSetSize(nodeCount);
151+
long testSetSize = splitConfig.testSetSize(nodeCount);
152+
int validationFolds = splitConfig.validationFolds();
153+
150154
return List.of(
151-
Tasks.leaf("Shuffle and split"),
155+
Tasks.leaf("Shuffle and split", validationFolds * trainSetSize + testSetSize),
152156
Tasks.iterativeFixed(
153157
"Select best model",
154-
() -> List.of(Tasks.leaf("Trial", validationFolds)),
158+
() -> List.of(Tasks.leaf("Trial", 5 * validationFolds * trainSetSize)),
155159
numberOfModelSelectionTrials
156160
),
157-
ClassifierTrainer.progressTask("Train best model"),
158-
Tasks.leaf("Evaluate on test data"),
159-
ClassifierTrainer.progressTask("Retrain best model")
161+
ClassifierTrainer.progressTask("Train best model", 5 * trainSetSize),
162+
Tasks.leaf("Evaluate on test data", testSetSize),
163+
ClassifierTrainer.progressTask("Retrain best model", 5 * nodeCount)
160164
);
161165
}
162166

@@ -309,8 +313,10 @@ private void selectBestModel(List<TrainingExamplesSplit> nodeSplits, TrainingSta
309313
int trial = 0;
310314
while (hyperParameterOptimizer.hasNext()) {
311315
progressTracker.beginSubTask("Trial");
316+
progressTracker.setSteps(nodeSplits.size());
312317
var modelParams = hyperParameterOptimizer.next();
313318
progressTracker.logMessage(formatWithLocale("Method: %s, Parameters: %s", modelParams.method(), modelParams.toMap()));
319+
314320
var validationStatsBuilder = new ModelStatsBuilder(nodeSplits.size());
315321
var trainStatsBuilder = new ModelStatsBuilder(nodeSplits.size());
316322
var metricsHandler = ModelSpecificMetricsHandler.of(metrics, validationStatsBuilder);
@@ -323,7 +329,8 @@ private void selectBestModel(List<TrainingExamplesSplit> nodeSplits, TrainingSta
323329

324330
registerMetricScores(validationSet, classifier, validationStatsBuilder::update, ProgressTracker.NULL_TRACKER);
325331
registerMetricScores(trainSet, classifier, trainStatsBuilder::update, ProgressTracker.NULL_TRACKER);
326-
progressTracker.logProgress();
332+
333+
progressTracker.logSteps(1);
327334
}
328335

329336
var candidateStats = ModelCandidateStats.of(
@@ -394,11 +401,8 @@ private void evaluateBestModel(
394401
);
395402
progressTracker.endSubTask("Train best model");
396403

397-
progressTracker.beginSubTask(
398-
"Evaluate on test data",
399-
outerSplit.testSet().size() + outerSplit.trainSet().size()
400-
);
401-
404+
progressTracker.beginSubTask("Evaluate on test data");
405+
progressTracker.setSteps(outerSplit.testSet().size() + outerSplit.trainSet().size());
402406
registerMetricScores(outerSplit.trainSet(), bestClassifier, trainingStatistics::addOuterTrainScore, progressTracker);
403407
var outerTrainMetrics = trainingStatistics.winningModelOuterTrainMetrics();
404408
progressTracker.logMessage(formatWithLocale("Final model metrics on full train set: %s", outerTrainMetrics));

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineAlgorithmFactory.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,8 @@ public String taskName() {
8989
public Task progressTask(GraphStore graphStore, NodeClassificationPipelineTrainConfig config) {
9090
return NodeClassificationTrainPipelineExecutor.progressTask(
9191
taskName(),
92-
PipelineCatalog .getTyped(config.username(), config.pipeline(), NodeClassificationTrainingPipeline.class)
92+
PipelineCatalog .getTyped(config.username(), config.pipeline(), NodeClassificationTrainingPipeline.class),
93+
graphStore.nodeCount()
9394
);
9495
}
9596
}

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineExecutor.java

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -59,18 +59,19 @@ public NodeClassificationTrainPipelineExecutor(
5959
super(pipeline, config, executionContext, graphStore, graphName, progressTracker);
6060
}
6161

62-
public static Task progressTask(String taskName, NodeClassificationTrainingPipeline pipeline) {
62+
public static Task progressTask(String taskName, NodeClassificationTrainingPipeline pipeline, long nodeCount) {
6363
return Tasks.task(
6464
taskName,
6565
new ArrayList<>() {{
6666
add(Tasks.iterativeFixed(
6767
"Execute node property steps",
68-
() -> List.of(Tasks.leaf("Step")),
68+
() -> List.of(Tasks.leaf("Step", 10L * nodeCount)),
6969
pipeline.nodePropertySteps().size()
7070
));
7171
addAll(NodeClassificationTrain.progressTasks(
72-
pipeline.splitConfig().validationFolds(),
73-
pipeline.numberOfModelSelectionTrials()
72+
pipeline.splitConfig(),
73+
pipeline.numberOfModelSelectionTrials(),
74+
nodeCount
7475
));
7576

7677
}}

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainPipelineExecutorTest.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -250,7 +250,7 @@ void shouldLogProgress() {
250250
TestProcedureRunner.applyOnProcedure(db, TestProc.class, caller -> {
251251
var log = Neo4jProxy.testLog();
252252
var progressTracker = new TestProgressTracker(
253-
NodeClassificationTrainPipelineExecutor.progressTask("Node Classification Train Pipeline", pipeline),
253+
NodeClassificationTrainPipelineExecutor.progressTask("Node Classification Train Pipeline", pipeline, graphStore.nodeCount()),
254254
log,
255255
1,
256256
EmptyTaskRegistryFactory.INSTANCE

pipeline/src/test/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrainTest.java

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -406,8 +406,9 @@ void shouldLogProgress() {
406406
var config = createConfig("bananasModel", metrics, 42L);
407407

408408
var progressTask = progressTask(
409-
pipeline.splitConfig().validationFolds(),
410-
pipeline.numberOfModelSelectionTrials()
409+
pipeline.splitConfig(),
410+
pipeline.numberOfModelSelectionTrials(),
411+
graph.nodeCount()
411412
);
412413
var testLog = Neo4jProxy.testLog();
413414
var progressTracker = new TestProgressTracker(progressTask, testLog, 1, EmptyTaskRegistryFactory.INSTANCE);
@@ -440,7 +441,7 @@ void shouldLogProgressWithRange() {
440441
var metrics = ClassificationMetricSpecification.parse("F1(class=1)");
441442
var config = createConfig("bananasModel", metrics, 42L);
442443

443-
var progressTask = progressTask(pipeline.splitConfig().validationFolds(), MAX_TRIALS);
444+
var progressTask = progressTask(pipeline.splitConfig(), MAX_TRIALS, graph.nodeCount());
444445
var testLog = Neo4jProxy.testLog();
445446
var progressTracker = new TestProgressTracker(progressTask, testLog, 1, EmptyTaskRegistryFactory.INSTANCE);
446447

@@ -494,10 +495,10 @@ void seededNodeClassification(int concurrency) {
494495
));
495496
}
496497

497-
private static Task progressTask(int validationFolds, int trials) {
498+
private static Task progressTask(NodePropertyPredictionSplitConfig splitConfig, int trials, long nodeCount) {
498499
return Tasks.task(
499500
"MY DUMMY TASK",
500-
NodeClassificationTrain.progressTasks(validationFolds, trials)
501+
NodeClassificationTrain.progressTasks(splitConfig, trials, nodeCount)
501502
);
502503
}
503504

pipeline/src/test/resources/expectedLogs/node-classification-log

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,11 @@ MY DUMMY TASK :: Train best model :: converged after 5 out of 100 epochs. Initia
3434
MY DUMMY TASK :: Train best model 100%
3535
MY DUMMY TASK :: Train best model :: Finished
3636
MY DUMMY TASK :: Evaluate on test data :: Start
37-
MY DUMMY TASK :: Evaluate on test data 66%
37+
MY DUMMY TASK :: Evaluate on test data 50%
3838
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on full train set: {F1_class_1=0.8235}
39-
MY DUMMY TASK :: Evaluate on test data 100%
39+
MY DUMMY TASK :: Evaluate on test data 75%
4040
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on test set: {F1_class_1=0.7499}
41+
MY DUMMY TASK :: Evaluate on test data 100%
4142
MY DUMMY TASK :: Evaluate on test data :: Finished
4243
MY DUMMY TASK :: Retrain best model :: Start
4344
MY DUMMY TASK :: Retrain best model :: Initial loss 0.6931

pipeline/src/test/resources/expectedLogs/node-classification-with-range-log

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,11 @@ MY DUMMY TASK :: Train best model :: converged after 5 out of 100 epochs. Initia
3434
MY DUMMY TASK :: Train best model 100%
3535
MY DUMMY TASK :: Train best model :: Finished
3636
MY DUMMY TASK :: Evaluate on test data :: Start
37-
MY DUMMY TASK :: Evaluate on test data 66%
37+
MY DUMMY TASK :: Evaluate on test data 50%
3838
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on full train set: {F1_class_1=0.8235}
39-
MY DUMMY TASK :: Evaluate on test data 100%
39+
MY DUMMY TASK :: Evaluate on test data 75%
4040
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on test set: {F1_class_1=0.7499}
41+
MY DUMMY TASK :: Evaluate on test data 100%
4142
MY DUMMY TASK :: Evaluate on test data :: Finished
4243
MY DUMMY TASK :: Retrain best model :: Start
4344
MY DUMMY TASK :: Retrain best model :: Initial loss 0.6931

0 commit comments

Comments
 (0)