Skip to content

Commit 261ac0e

Browse files
committed
Limit work stealing of dynamic children to current entry
1 parent e0bc61b commit 261ac0e

File tree

1 file changed

+34
-62
lines changed

1 file changed

+34
-62
lines changed

junit-platform-engine/src/main/java/org/junit/platform/engine/support/hierarchical/ConcurrentHierarchicalTestExecutorService.java

Lines changed: 34 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919
import static org.junit.platform.engine.support.hierarchical.ExclusiveResource.GLOBAL_READ_WRITE;
2020
import static org.junit.platform.engine.support.hierarchical.Node.ExecutionMode.SAME_THREAD;
2121

22-
import java.util.ArrayDeque;
2322
import java.util.ArrayList;
2423
import java.util.Collection;
25-
import java.util.Deque;
2624
import java.util.List;
2725
import java.util.PriorityQueue;
2826
import java.util.Queue;
@@ -99,8 +97,7 @@ public void close() {
9997
}
10098

10199
var entry = enqueue(testTask);
102-
workerThread.addForkedChild(entry);
103-
return new BlockingAwareFuture<@Nullable Void>(entry.future(), WorkerThread.BlockHandler.INSTANCE);
100+
return new BlockingAwareFuture<@Nullable Void>(entry.future(), new WorkerThread.BlockHandler(entry));
104101
}
105102

106103
@Override
@@ -175,8 +172,6 @@ public Thread newThread(Runnable runnable) {
175172

176173
private class WorkerThread extends Thread {
177174

178-
private final Deque<State> state = new ArrayDeque<>();
179-
180175
@Nullable
181176
WorkerLease workerLease;
182177

@@ -250,19 +245,15 @@ void invokeAll(List<? extends TestTask> testTasks) {
250245

251246
List<TestTask> isolatedTasks = new ArrayList<>(testTasks.size());
252247
List<TestTask> sameThreadTasks = new ArrayList<>(testTasks.size());
253-
forkConcurrentChildren(testTasks, isolatedTasks::add, sameThreadTasks);
248+
var concurrentTasks = forkConcurrentChildren(testTasks, isolatedTasks::add, sameThreadTasks);
254249
executeAll(sameThreadTasks);
255-
var remainingForkedChildren = stealWork();
250+
var remainingForkedChildren = stealWork(concurrentTasks);
256251
waitFor(remainingForkedChildren);
257252
executeAll(isolatedTasks);
258253
}
259254

260-
void addForkedChild(WorkQueue.Entry entry) {
261-
getForkedChildren().add(entry);
262-
}
263-
264-
private void forkConcurrentChildren(List<? extends TestTask> children, Consumer<TestTask> isolatedTaskCollector,
265-
List<TestTask> sameThreadTasks) {
255+
private Queue<WorkQueue.Entry> forkConcurrentChildren(List<? extends TestTask> children,
256+
Consumer<TestTask> isolatedTaskCollector, List<TestTask> sameThreadTasks) {
266257

267258
Queue<WorkQueue.Entry> queueEntries = new PriorityQueue<>(children.size(), reverseOrder());
268259
for (TestTask child : children) {
@@ -276,48 +267,44 @@ else if (child.getExecutionMode() == SAME_THREAD) {
276267
queueEntries.add(WorkQueue.Entry.create(child));
277268
}
278269
}
270+
279271
if (!queueEntries.isEmpty()) {
280272
if (sameThreadTasks.isEmpty()) {
281273
// hold back one task for this thread
282274
sameThreadTasks.add(queueEntries.poll().task);
283275
}
284276
forkAll(queueEntries);
285-
getForkedChildren().addAll(queueEntries);
286277
}
278+
279+
return queueEntries;
287280
}
288281

289-
private List<WorkQueue.Entry> stealWork() {
290-
var forkedChildren = getForkedChildren();
291-
List<WorkQueue.Entry> concurrentlyExecutingChildren = new ArrayList<>(forkedChildren.size());
282+
private List<WorkQueue.Entry> stealWork(Queue<WorkQueue.Entry> concurrentTasks) {
283+
List<WorkQueue.Entry> concurrentlyExecutingChildren = new ArrayList<>(concurrentTasks.size());
292284
WorkQueue.Entry entry;
293-
while ((entry = forkedChildren.poll()) != null) {
294-
if (entry.future.isDone()) {
285+
while ((entry = concurrentTasks.poll()) != null) {
286+
var executed = tryToStealWork(entry);
287+
if (!executed) {
295288
concurrentlyExecutingChildren.add(entry);
296289
}
297-
else {
298-
var claimed = workQueue.remove(entry);
299-
if (claimed) {
300-
var executed = tryExecuteStolenEntry(entry);
301-
if (!executed) {
302-
workQueue.reAdd(entry);
303-
concurrentlyExecutingChildren.add(entry);
304-
}
305-
}
306-
else {
307-
concurrentlyExecutingChildren.add(entry);
308-
}
309-
}
310290
}
311291
return concurrentlyExecutingChildren;
312292
}
313293

314-
private Queue<WorkQueue.Entry> getForkedChildren() {
315-
return currentState().forkedChildren;
316-
}
317-
318-
private boolean tryExecuteStolenEntry(WorkQueue.Entry entry) {
319-
LOGGER.trace(() -> "stole work: " + entry);
320-
return tryExecute(entry);
294+
private boolean tryToStealWork(WorkQueue.Entry entry) {
295+
if (entry.future.isDone()) {
296+
return false;
297+
}
298+
var claimed = workQueue.remove(entry);
299+
if (claimed) {
300+
LOGGER.trace(() -> "stole work: " + entry);
301+
var executed = tryExecute(entry);
302+
if (!executed) {
303+
workQueue.reAdd(entry);
304+
}
305+
return executed;
306+
}
307+
return false;
321308
}
322309

323310
private void waitFor(List<WorkQueue.Entry> children) {
@@ -332,8 +319,7 @@ private void waitFor(List<WorkQueue.Entry> children) {
332319
}
333320
else {
334321
runBlocking(() -> {
335-
LOGGER.trace(() -> "blocking for forked children of %s: %s".formatted(
336-
currentState().executingTask, children));
322+
LOGGER.trace(() -> "blocking for forked children : %s".formatted(children));
337323
return future.join();
338324
});
339325
}
@@ -351,8 +337,7 @@ private void executeAll(List<? extends TestTask> children) {
351337
if (children.isEmpty()) {
352338
return;
353339
}
354-
LOGGER.trace(
355-
() -> "running %d children of %s directly".formatted(children.size(), currentState().executingTask));
340+
LOGGER.trace(() -> "running %d children directly".formatted(children.size()));
356341
if (children.size() == 1) {
357342
executeTask(children.get(0));
358343
return;
@@ -429,20 +414,14 @@ private boolean tryExecuteTask(TestTask testTask) {
429414

430415
private void doExecute(TestTask testTask) {
431416
LOGGER.trace(() -> "executing: " + testTask);
432-
this.state.push(new State(testTask));
433417
try {
434418
testTask.execute();
435419
}
436420
finally {
437-
this.state.pop();
438421
LOGGER.trace(() -> "finished executing: " + testTask);
439422
}
440423
}
441424

442-
private State currentState() {
443-
return state.element();
444-
}
445-
446425
private static CompletableFuture<?> toCombinedFuture(List<WorkQueue.Entry> entries) {
447426
if (entries.size() == 1) {
448427
return entries.get(0).future();
@@ -451,28 +430,20 @@ private static CompletableFuture<?> toCombinedFuture(List<WorkQueue.Entry> entri
451430
return CompletableFuture.allOf(futures);
452431
}
453432

454-
private record State(TestTask executingTask, Queue<WorkQueue.Entry> forkedChildren) {
455-
State(TestTask executingTask) {
456-
this(executingTask, new PriorityQueue<>(reverseOrder()));
457-
}
458-
}
459-
460433
private interface BlockingAction<T> {
461434
T run() throws InterruptedException;
462435
}
463436

464-
private static class BlockHandler implements BlockingAwareFuture.BlockHandler {
465-
466-
private static final BlockHandler INSTANCE = new BlockHandler();
437+
private record BlockHandler(WorkQueue.Entry entry) implements BlockingAwareFuture.BlockHandler {
467438

468439
@Override
469440
public <T> T handle(Supplier<Boolean> blockingUnnecessary, Callable<T> callable) throws Exception {
470441
var workerThread = get();
471-
if (workerThread == null || blockingUnnecessary.get()) {
442+
if (workerThread == null || entry.future.isDone()) {
472443
return callable.call();
473444
}
474-
workerThread.stealWork();
475-
if (blockingUnnecessary.get()) {
445+
workerThread.tryToStealWork(entry);
446+
if (entry.future.isDone()) {
476447
return callable.call();
477448
}
478449
LOGGER.trace(() -> "blocking for child task");
@@ -486,6 +457,7 @@ public <T> T handle(Supplier<Boolean> blockingUnnecessary, Callable<T> callable)
486457
});
487458
}
488459
}
460+
489461
}
490462

491463
private static class WorkQueue {

0 commit comments

Comments
 (0)