Skip to content

Commit c2639bc

Browse files
Merge pull request #5367 from breakanalysis/initialize-volumes-earlier
Introduce logSteps for apriori unkown volumes
2 parents 0ca3ba3 + 679dcf0 commit c2639bc

File tree

17 files changed

+140
-32
lines changed

17 files changed

+140
-32
lines changed

core/src/main/java/org/neo4j/gds/core/utils/progress/tasks/LeafTask.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,12 @@ public class LeafTask extends Task {
3636
@Override
3737
public void finish() {
3838
super.finish();
39-
setVolume(currentProgress.longValue());
39+
40+
// This task should now be considered to have 100% progress.
41+
if (volume == UNKNOWN_VOLUME) {
42+
volume = currentProgress.longValue();
43+
}
44+
currentProgress.add(volume - currentProgress.longValue());
4045
}
4146

4247
@Override

core/src/main/java/org/neo4j/gds/core/utils/progress/tasks/ProgressTracker.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ default void logProgress() {
5151

5252
void logProgress(long value, String messageTemplate);
5353

54+
// prefer setting volume via factory method for leaves
55+
// to make root progress available from the start
56+
@Deprecated
5457
void setVolume(long volume);
5558

5659
void logDebug(String message);
@@ -61,6 +64,10 @@ default void logProgress() {
6164

6265
void release();
6366

67+
void setSteps(long steps);
68+
69+
void logSteps(long steps);
70+
6471
class EmptyProgressTracker implements ProgressTracker {
6572

6673
@Override
@@ -108,6 +115,16 @@ public void logProgress(long value, String messageTemplate) {
108115
public void setVolume(long volume) {
109116
}
110117

118+
@Override
119+
public void setSteps(long steps) {
120+
121+
}
122+
123+
@Override
124+
public void logSteps(long steps) {
125+
126+
}
127+
111128
@Override
112129
public void logDebug(String message) {
113130

core/src/main/java/org/neo4j/gds/core/utils/progress/tasks/TaskProgressTracker.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,16 @@
3737

3838
public class TaskProgressTracker implements ProgressTracker {
3939

40+
private static final long UNKNOWN_STEPS = -1;
41+
4042
private final Task baseTask;
4143
private final TaskRegistry taskRegistry;
4244
private final UserLogRegistry userLogRegistry;
4345
private final TaskProgressLogger taskProgressLogger;
4446
private final Stack<Task> nestedTasks;
4547
protected Optional<Task> currentTask;
48+
private long currentTotalSteps;
49+
private double progressLeftOvers;
4650

4751
public TaskProgressTracker(Task baseTask, Log log, int concurrency, TaskRegistryFactory taskRegistryFactory) {
4852
this(baseTask, log, concurrency, new JobId(), taskRegistryFactory, EmptyUserLogRegistryFactory.INSTANCE);
@@ -56,6 +60,8 @@ public TaskProgressTracker(
5660
this.taskRegistry = taskRegistryFactory.newInstance(jobId);
5761
this.taskProgressLogger = new TaskProgressLogger(log, baseTask, concurrency);
5862
this.currentTask = Optional.empty();
63+
this.currentTotalSteps = UNKNOWN_STEPS;
64+
this.progressLeftOvers = 0;
5965
this.nestedTasks = new Stack<>();
6066
this.userLogRegistry = userLogRegistryFactory.newInstance();
6167
}
@@ -76,6 +82,8 @@ public void beginSubTask() {
7682
nextTask.start();
7783
taskProgressLogger.logBeginSubTask(nextTask, parentTask());
7884
currentTask = Optional.of(nextTask);
85+
currentTotalSteps = UNKNOWN_STEPS;
86+
progressLeftOvers = 0;
7987
}
8088

8189
@Override
@@ -90,6 +98,26 @@ public void beginSubTask(long taskVolume) {
9098
setVolume(taskVolume);
9199
}
92100

101+
@Override
102+
public void setSteps(long steps) {
103+
if (steps <= 0) {
104+
throw new IllegalStateException(formatWithLocale(
105+
"Total steps for task must be at least 1 but was %d",
106+
steps
107+
));
108+
}
109+
currentTotalSteps = steps;
110+
}
111+
112+
@Override
113+
public void logSteps(long steps) {
114+
long volume = requireCurrentTask().getProgress().volume();
115+
double progress = steps * volume / (double) currentTotalSteps + progressLeftOvers;
116+
long longProgress = (long) progress;
117+
progressLeftOvers = progress - longProgress;
118+
logProgress(longProgress);
119+
}
120+
93121
@Override
94122
public void beginSubTask(String expectedTaskDescription, long taskVolume) {
95123
beginSubTask();

core/src/test/java/org/neo4j/gds/core/utils/progress/tasks/TaskProgressTrackerTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,24 @@ void shouldRegisterBaseTaskOnBaseTaskStart() {
208208
assertThat(taskStore.query("")).containsValue(task);
209209
}
210210

211+
@Test
212+
void stepsShouldGiveProgress() {
213+
var leafTask = Tasks.leaf("leaf", 100);
214+
var progressTracker = progressTracker(leafTask);
215+
216+
progressTracker.beginSubTask();
217+
progressTracker.setSteps(13);
218+
progressTracker.logProgress(3);
219+
progressTracker.logSteps(1);
220+
double expectedDoubleProgressFromFirstStep = 100.0 * 1.0 / 13.0;
221+
long progressAfterFirstStep = leafTask.getProgress().progress();
222+
assertThat(progressAfterFirstStep).isEqualTo((long) expectedDoubleProgressFromFirstStep + 3);
223+
224+
progressTracker.logProgress(1);
225+
progressTracker.logSteps(4);
226+
assertThat(leafTask.getProgress().progress()).isEqualTo(3 + 1 + (long) (100.0 * 5.0 / 13));
227+
}
228+
211229
private TaskProgressTracker progressTracker(Task task, Log log) {
212230
return new TaskProgressTracker(task, log, 1, EmptyTaskRegistryFactory.INSTANCE);
213231
}

core/src/test/java/org/neo4j/gds/core/utils/progress/tasks/TaskTest.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,22 +216,24 @@ void shouldSetVolumeLate() {
216216
}
217217

218218
@Test
219-
void shouldSetVolumeWhenFinishingTask() {
219+
void shouldSetProgressWhenFinishingTask() {
220220
var task = Tasks.iterativeOpen("root", () -> List.of(Tasks.leaf("leaf")));
221221
task.start();
222222
var leaf1 = task.nextSubtask();
223223
leaf1.start();
224224
leaf1.logProgress(22L);
225225
leaf1.finish();
226226

227+
assertThat(leaf1.getProgress().progress()).isEqualTo(22L);
227228
assertThat(leaf1.getProgress().volume()).isEqualTo(22L);
228229
assertThat(task.getProgress().volume()).isEqualTo(Task.UNKNOWN_VOLUME);
229230

230231
var leaf2 = task.nextSubtask();
231232
leaf2.start();
232-
leaf2.logProgress(20L);
233+
leaf2.setVolume(20L);
233234
leaf2.finish();
234235

236+
assertThat(leaf2.getProgress().progress()).isEqualTo(20L);
235237
assertThat(leaf2.getProgress().volume()).isEqualTo(20L);
236238
assertThat(task.getProgress().volume()).isEqualTo(Task.UNKNOWN_VOLUME);
237239

ml/ml-algo/src/main/java/org/neo4j/gds/ml/metrics/SignedProbabilities.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public static SignedProbabilities computeFromLabeledData(
119119
TerminationFlag terminationFlag,
120120
ProgressTracker progressTracker
121121
) {
122-
progressTracker.setVolume(features.size());
122+
progressTracker.setSteps(features.size());
123123

124124
var signedProbabilities = SignedProbabilities.create(evaluationQueue.totalSize());
125125

@@ -134,7 +134,7 @@ public static SignedProbabilities computeFromLabeledData(
134134

135135
signedProbabilities.add(probabilityOfPositiveEdge, isEdge);
136136
}
137-
progressTracker.logProgress(batch.size());
137+
progressTracker.logSteps(batch.size());
138138
},
139139
terminationFlag
140140
);

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/ClassifierTrainer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ public interface ClassifierTrainer {
2828

2929
Classifier train(Features features, HugeLongArray labels, ReadOnlyHugeLongArray trainSet);
3030

31+
static Task progressTask(String taskName, long volume) {
32+
return Tasks.leaf(taskName, volume);
33+
}
34+
3135
static Task progressTask(String taskName) {
3236
return Tasks.leaf(taskName);
3337
}

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/BatchLinkFeatureExtractor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public void run() {
6161
}));
6262
});
6363

64-
progressTracker.logProgress(partition.totalDegree());
64+
progressTracker.logSteps(partition.totalDegree());
6565
}
6666
}
6767

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkFeaturesAndLabelsExtractor.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ static FeaturesAndLabels extractFeaturesAndLabels(
7979
ProgressTracker progressTracker,
8080
TerminationFlag terminationFlag
8181
) {
82-
progressTracker.setVolume(graph.relationshipCount() * 2);
82+
progressTracker.setSteps(graph.relationshipCount() * 2);
8383
var features = LinkFeatureExtractor.extractFeatures(
8484
graph,
8585
featureSteps,
@@ -125,7 +125,7 @@ private static HugeLongArray extractLabels(
125125
}
126126
return true;
127127
}));
128-
progressTracker.logProgress(partition.totalDegree());
128+
progressTracker.logSteps(partition.totalDegree());
129129
}
130130
);
131131
relationshipOffset.add(partition.totalDegree());

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,20 +95,25 @@ public LinkPredictionTrain(
9595
this.classIdMap = makeClassIdMap();
9696
}
9797

98-
public static List<Task> progressTasks(int validationFolds, int numberOfModelSelectionTrials) {
98+
public static List<Task> progressTasks(
99+
long relationshipCount,
100+
LinkPredictionSplitConfig splitConfig,
101+
int numberOfModelSelectionTrials
102+
) {
103+
var sizes = splitConfig.expectedSetSizes(relationshipCount);
99104
return List.of(
100-
Tasks.leaf("Extract train features"),
105+
Tasks.leaf("Extract train features", sizes.trainSize() * 3),
101106
Tasks.iterativeFixed(
102107
"Select best model",
103-
() -> List.of(Tasks.leaf("Trial", validationFolds)),
108+
() -> List.of(Tasks.leaf("Trial", splitConfig.validationFolds() * sizes.trainSize() * 5)),
104109
numberOfModelSelectionTrials
105110
),
106-
ClassifierTrainer.progressTask("Train best model"),
107-
Tasks.leaf("Compute train metrics"),
111+
ClassifierTrainer.progressTask("Train best model", sizes.trainSize() * 5),
112+
Tasks.leaf("Compute train metrics", sizes.trainSize()),
108113
Tasks.task(
109114
"Evaluate on test data",
110-
Tasks.leaf("Extract test features"),
111-
Tasks.leaf("Compute test metrics")
115+
Tasks.leaf("Extract test features", sizes.testSize() * 3),
116+
Tasks.leaf("Compute test metrics", sizes.testSize())
112117
)
113118
);
114119
}
@@ -208,6 +213,7 @@ private void modelSelect(
208213
int trial = 0;
209214
while (hyperParameterOptimizer.hasNext()) {
210215
progressTracker.beginSubTask();
216+
progressTracker.setSteps(pipeline.splitConfig().validationFolds());
211217
var modelParams = hyperParameterOptimizer.next();
212218
progressTracker.logMessage(formatWithLocale("Method: %s, Parameters: %s", modelParams.method(), modelParams.toMap()));
213219
var trainStatsBuilder = new ModelStatsBuilder(pipeline.splitConfig().validationFolds());
@@ -242,7 +248,7 @@ private void modelSelect(
242248
ProgressTracker.NULL_TRACKER
243249
);
244250

245-
progressTracker.logProgress();
251+
progressTracker.logSteps(1);
246252
}
247253

248254
// insert the candidates' metrics into trainStats and validationStats

0 commit comments

Comments
 (0)