Skip to content

Commit 0ad0372

Browse files
Split Evaluate on test task into two tasks
Co-Authored-By: Adam Schill Collberg<[email protected]> Co-Authored-By: Mats Rydberg <[email protected]>
1 parent b5f2102 commit 0ad0372

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrain.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ public static List<Task> progressTasks(NodePropertyPredictionSplitConfig splitCo
159159
numberOfModelSelectionTrials
160160
),
161161
ClassifierTrainer.progressTask("Train best model", 5 * trainSetSize),
162+
Tasks.leaf("Evaluate on train data", trainSetSize),
162163
Tasks.leaf("Evaluate on test data", testSetSize),
163164
ClassifierTrainer.progressTask("Retrain best model", 5 * nodeCount)
164165
);
@@ -401,16 +402,18 @@ private void evaluateBestModel(
401402
);
402403
progressTracker.endSubTask("Train best model");
403404

404-
progressTracker.beginSubTask("Evaluate on test data");
405-
progressTracker.setSteps(outerSplit.testSet().size() + outerSplit.trainSet().size());
405+
progressTracker.beginSubTask("Evaluate on train data");
406+
progressTracker.setSteps(outerSplit.trainSet().size());
406407
registerMetricScores(outerSplit.trainSet(), bestClassifier, trainingStatistics::addOuterTrainScore, progressTracker);
407408
var outerTrainMetrics = trainingStatistics.winningModelOuterTrainMetrics();
408409
progressTracker.logMessage(formatWithLocale("Final model metrics on full train set: %s", outerTrainMetrics));
410+
progressTracker.endSubTask("Evaluate on train data");
409411

412+
progressTracker.beginSubTask("Evaluate on test data");
413+
progressTracker.setSteps(outerSplit.testSet().size());
410414
registerMetricScores(outerSplit.testSet(), bestClassifier, trainingStatistics::addTestScore, progressTracker);
411415
var testMetrics = trainingStatistics.winningModelTestMetrics();
412416
progressTracker.logMessage(formatWithLocale("Final model metrics on test set: %s", testMetrics));
413-
414417
progressTracker.endSubTask("Evaluate on test data");
415418
}
416419

pipeline/src/test/resources/expectedLogs/node-classification-log

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ MY DUMMY TASK :: Train best model :: Epoch 5 with loss 0.6128
3333
MY DUMMY TASK :: Train best model :: converged after 5 out of 100 epochs. Initial loss: 0.6931, Last loss: 0.6128.
3434
MY DUMMY TASK :: Train best model 100%
3535
MY DUMMY TASK :: Train best model :: Finished
36+
MY DUMMY TASK :: Evaluate on train data :: Start
37+
MY DUMMY TASK :: Evaluate on train data 100%
38+
MY DUMMY TASK :: Evaluate on train data :: Final model metrics on full train set: {F1_class_1=0.8235}
39+
MY DUMMY TASK :: Evaluate on train data :: Finished
3640
MY DUMMY TASK :: Evaluate on test data :: Start
37-
MY DUMMY TASK :: Evaluate on test data 50%
38-
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on full train set: {F1_class_1=0.8235}
39-
MY DUMMY TASK :: Evaluate on test data 75%
40-
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on test set: {F1_class_1=0.7499}
4141
MY DUMMY TASK :: Evaluate on test data 100%
42+
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on test set: {F1_class_1=0.7499}
4243
MY DUMMY TASK :: Evaluate on test data :: Finished
4344
MY DUMMY TASK :: Retrain best model :: Start
4445
MY DUMMY TASK :: Retrain best model :: Initial loss 0.6931

pipeline/src/test/resources/expectedLogs/node-classification-with-range-log

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ MY DUMMY TASK :: Train best model :: Epoch 5 with loss 0.6128
3333
MY DUMMY TASK :: Train best model :: converged after 5 out of 100 epochs. Initial loss: 0.6931, Last loss: 0.6128.
3434
MY DUMMY TASK :: Train best model 100%
3535
MY DUMMY TASK :: Train best model :: Finished
36+
MY DUMMY TASK :: Evaluate on train data :: Start
37+
MY DUMMY TASK :: Evaluate on train data 100%
38+
MY DUMMY TASK :: Evaluate on train data :: Final model metrics on full train set: {F1_class_1=0.8235}
39+
MY DUMMY TASK :: Evaluate on train data :: Finished
3640
MY DUMMY TASK :: Evaluate on test data :: Start
37-
MY DUMMY TASK :: Evaluate on test data 50%
38-
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on full train set: {F1_class_1=0.8235}
39-
MY DUMMY TASK :: Evaluate on test data 75%
40-
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on test set: {F1_class_1=0.7499}
4141
MY DUMMY TASK :: Evaluate on test data 100%
42+
MY DUMMY TASK :: Evaluate on test data :: Final model metrics on test set: {F1_class_1=0.7499}
4243
MY DUMMY TASK :: Evaluate on test data :: Finished
4344
MY DUMMY TASK :: Retrain best model :: Start
4445
MY DUMMY TASK :: Retrain best model :: Initial loss 0.6931

0 commit comments

Comments
 (0)