1919import static org .junit .platform .engine .support .hierarchical .ExclusiveResource .GLOBAL_READ_WRITE ;
2020import static org .junit .platform .engine .support .hierarchical .Node .ExecutionMode .SAME_THREAD ;
2121
22- import java .util .ArrayDeque ;
2322import java .util .ArrayList ;
2423import java .util .Collection ;
25- import java .util .Deque ;
2624import java .util .List ;
2725import java .util .PriorityQueue ;
2826import java .util .Queue ;
@@ -99,8 +97,7 @@ public void close() {
9997 }
10098
10199 var entry = enqueue (testTask );
102- workerThread .addForkedChild (entry );
103- return new BlockingAwareFuture <@ Nullable Void >(entry .future (), WorkerThread .BlockHandler .INSTANCE );
100+ return new BlockingAwareFuture <@ Nullable Void >(entry .future (), new WorkerThread .BlockHandler (entry ));
104101 }
105102
106103 @ Override
@@ -175,8 +172,6 @@ public Thread newThread(Runnable runnable) {
175172
176173 private class WorkerThread extends Thread {
177174
178- private final Deque <State > state = new ArrayDeque <>();
179-
180175 @ Nullable
181176 WorkerLease workerLease ;
182177
@@ -250,19 +245,15 @@ void invokeAll(List<? extends TestTask> testTasks) {
250245
251246 List <TestTask > isolatedTasks = new ArrayList <>(testTasks .size ());
252247 List <TestTask > sameThreadTasks = new ArrayList <>(testTasks .size ());
253- forkConcurrentChildren (testTasks , isolatedTasks ::add , sameThreadTasks );
248+ var concurrentTasks = forkConcurrentChildren (testTasks , isolatedTasks ::add , sameThreadTasks );
254249 executeAll (sameThreadTasks );
255- var remainingForkedChildren = stealWork ();
250+ var remainingForkedChildren = stealWork (concurrentTasks );
256251 waitFor (remainingForkedChildren );
257252 executeAll (isolatedTasks );
258253 }
259254
260- void addForkedChild (WorkQueue .Entry entry ) {
261- getForkedChildren ().add (entry );
262- }
263-
264- private void forkConcurrentChildren (List <? extends TestTask > children , Consumer <TestTask > isolatedTaskCollector ,
265- List <TestTask > sameThreadTasks ) {
255+ private Queue <WorkQueue .Entry > forkConcurrentChildren (List <? extends TestTask > children ,
256+ Consumer <TestTask > isolatedTaskCollector , List <TestTask > sameThreadTasks ) {
266257
267258 Queue <WorkQueue .Entry > queueEntries = new PriorityQueue <>(children .size (), reverseOrder ());
268259 for (TestTask child : children ) {
@@ -276,48 +267,44 @@ else if (child.getExecutionMode() == SAME_THREAD) {
276267 queueEntries .add (WorkQueue .Entry .create (child ));
277268 }
278269 }
270+
279271 if (!queueEntries .isEmpty ()) {
280272 if (sameThreadTasks .isEmpty ()) {
281273 // hold back one task for this thread
282274 sameThreadTasks .add (queueEntries .poll ().task );
283275 }
284276 forkAll (queueEntries );
285- getForkedChildren ().addAll (queueEntries );
286277 }
278+
279+ return queueEntries ;
287280 }
288281
289- private List <WorkQueue .Entry > stealWork () {
290- var forkedChildren = getForkedChildren ();
291- List <WorkQueue .Entry > concurrentlyExecutingChildren = new ArrayList <>(forkedChildren .size ());
282+ private List <WorkQueue .Entry > stealWork (Queue <WorkQueue .Entry > concurrentTasks ) {
283+ List <WorkQueue .Entry > concurrentlyExecutingChildren = new ArrayList <>(concurrentTasks .size ());
292284 WorkQueue .Entry entry ;
293- while ((entry = forkedChildren .poll ()) != null ) {
294- if (entry .future .isDone ()) {
285+ while ((entry = concurrentTasks .poll ()) != null ) {
286+ var executed = tryToStealWork (entry );
287+ if (!executed ) {
295288 concurrentlyExecutingChildren .add (entry );
296289 }
297- else {
298- var claimed = workQueue .remove (entry );
299- if (claimed ) {
300- var executed = tryExecuteStolenEntry (entry );
301- if (!executed ) {
302- workQueue .reAdd (entry );
303- concurrentlyExecutingChildren .add (entry );
304- }
305- }
306- else {
307- concurrentlyExecutingChildren .add (entry );
308- }
309- }
310290 }
311291 return concurrentlyExecutingChildren ;
312292 }
313293
314- private Queue <WorkQueue .Entry > getForkedChildren () {
315- return currentState ().forkedChildren ;
316- }
317-
318- private boolean tryExecuteStolenEntry (WorkQueue .Entry entry ) {
319- LOGGER .trace (() -> "stole work: " + entry );
320- return tryExecute (entry );
294+ private boolean tryToStealWork (WorkQueue .Entry entry ) {
295+ if (entry .future .isDone ()) {
296+ return false ;
297+ }
298+ var claimed = workQueue .remove (entry );
299+ if (claimed ) {
300+ LOGGER .trace (() -> "stole work: " + entry );
301+ var executed = tryExecute (entry );
302+ if (!executed ) {
303+ workQueue .reAdd (entry );
304+ }
305+ return executed ;
306+ }
307+ return false ;
321308 }
322309
323310 private void waitFor (List <WorkQueue .Entry > children ) {
@@ -332,8 +319,7 @@ private void waitFor(List<WorkQueue.Entry> children) {
332319 }
333320 else {
334321 runBlocking (() -> {
335- LOGGER .trace (() -> "blocking for forked children of %s: %s" .formatted (
336- currentState ().executingTask , children ));
322+ LOGGER .trace (() -> "blocking for forked children : %s" .formatted (children ));
337323 return future .join ();
338324 });
339325 }
@@ -351,8 +337,7 @@ private void executeAll(List<? extends TestTask> children) {
351337 if (children .isEmpty ()) {
352338 return ;
353339 }
354- LOGGER .trace (
355- () -> "running %d children of %s directly" .formatted (children .size (), currentState ().executingTask ));
340+ LOGGER .trace (() -> "running %d children directly" .formatted (children .size ()));
356341 if (children .size () == 1 ) {
357342 executeTask (children .get (0 ));
358343 return ;
@@ -429,20 +414,14 @@ private boolean tryExecuteTask(TestTask testTask) {
429414
430415 private void doExecute (TestTask testTask ) {
431416 LOGGER .trace (() -> "executing: " + testTask );
432- this .state .push (new State (testTask ));
433417 try {
434418 testTask .execute ();
435419 }
436420 finally {
437- this .state .pop ();
438421 LOGGER .trace (() -> "finished executing: " + testTask );
439422 }
440423 }
441424
442- private State currentState () {
443- return state .element ();
444- }
445-
446425 private static CompletableFuture <?> toCombinedFuture (List <WorkQueue .Entry > entries ) {
447426 if (entries .size () == 1 ) {
448427 return entries .get (0 ).future ();
@@ -451,28 +430,20 @@ private static CompletableFuture<?> toCombinedFuture(List<WorkQueue.Entry> entri
451430 return CompletableFuture .allOf (futures );
452431 }
453432
454- private record State (TestTask executingTask , Queue <WorkQueue .Entry > forkedChildren ) {
455- State (TestTask executingTask ) {
456- this (executingTask , new PriorityQueue <>(reverseOrder ()));
457- }
458- }
459-
460433 private interface BlockingAction <T > {
461434 T run () throws InterruptedException ;
462435 }
463436
464- private static class BlockHandler implements BlockingAwareFuture .BlockHandler {
465-
466- private static final BlockHandler INSTANCE = new BlockHandler ();
437+ private record BlockHandler (WorkQueue .Entry entry ) implements BlockingAwareFuture .BlockHandler {
467438
468439 @ Override
469440 public <T > T handle (Supplier <Boolean > blockingUnnecessary , Callable <T > callable ) throws Exception {
470441 var workerThread = get ();
471- if (workerThread == null || blockingUnnecessary . get ()) {
442+ if (workerThread == null || entry . future . isDone ()) {
472443 return callable .call ();
473444 }
474- workerThread .stealWork ( );
475- if (blockingUnnecessary . get ()) {
445+ workerThread .tryToStealWork ( entry );
446+ if (entry . future . isDone ()) {
476447 return callable .call ();
477448 }
478449 LOGGER .trace (() -> "blocking for child task" );
@@ -486,6 +457,7 @@ public <T> T handle(Supplier<Boolean> blockingUnnecessary, Callable<T> callable)
486457 });
487458 }
488459 }
460+
489461 }
490462
491463 private static class WorkQueue {
0 commit comments