diff --git a/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java b/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java index 8394fc2ebc27..5973b183d3fb 100644 --- a/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java +++ b/junit-platform-commons/src/main/java/org/junit/platform/commons/util/ReflectionUtils.java @@ -49,6 +49,8 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; import java.util.function.Predicate; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -1119,10 +1121,7 @@ public static List> findNestedClasses(Class clazz, Predicate> candidates = new LinkedHashSet<>(); - visitNestedClasses(clazz, predicate, nestedClass -> { - candidates.add(nestedClass); - return true; - }); + visitAllNestedClasses(clazz, predicate, candidates::add); return List.copyOf(candidates); } @@ -1144,8 +1143,9 @@ public static boolean isNestedClassPresent(Class clazz, Predicate> p Preconditions.notNull(clazz, "Class must not be null"); Preconditions.notNull(predicate, "Predicate must not be null"); - boolean visitorWasNotCalled = visitNestedClasses(clazz, predicate, __ -> false); - return !visitorWasNotCalled; + AtomicBoolean foundNestedClass = new AtomicBoolean(false); + visitAllNestedClasses(clazz, predicate, __ -> foundNestedClass.setPlain(true)); + return foundNestedClass.getPlain(); } /** @@ -1156,10 +1156,15 @@ public static Stream> streamNestedClasses(Class clazz, Predicate clazz, Predicate> predicate, - Visitor> visitor) { + /** + * Visit all nested classes without support for short-circuiting + * in order to ensure all of them are checked for class cycles. + */ + private static void visitAllNestedClasses(Class clazz, Predicate> predicate, + Consumer> consumer) { + if (!isSearchable(clazz)) { - return true; + return; } if (isInnerClass(clazz) && predicate.test(clazz)) { @@ -1171,10 +1176,7 @@ private static boolean visitNestedClasses(Class clazz, Predicate> pr for (Class nestedClass : clazz.getDeclaredClasses()) { if (predicate.test(nestedClass)) { detectInnerClassCycle(nestedClass); - boolean shouldContinue = visitor.accept(nestedClass); - if (!shouldContinue) { - return false; - } + consumer.accept(nestedClass); } } } @@ -1183,20 +1185,12 @@ private static boolean visitNestedClasses(Class clazz, Predicate> pr } // Search class hierarchy - boolean shouldContinue = visitNestedClasses(clazz.getSuperclass(), predicate, visitor); - if (!shouldContinue) { - return false; - } + visitAllNestedClasses(clazz.getSuperclass(), predicate, consumer); // Search interface hierarchy for (Class ifc : clazz.getInterfaces()) { - shouldContinue = visitNestedClasses(ifc, predicate, visitor); - if (!shouldContinue) { - return false; - } + visitAllNestedClasses(ifc, predicate, consumer); } - - return true; } /** @@ -1936,14 +1930,4 @@ static Throwable getUnderlyingCause(Throwable t) { return t; } - private interface Visitor { - - /** - * @return {@code true} if the visitor should continue searching; - * {@code false} if the visitor should stop - */ - boolean accept(T value); - - } - }