Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -1119,10 +1121,7 @@ public static List<Class<?>> findNestedClasses(Class<?> clazz, Predicate<Class<?
Preconditions.notNull(predicate, "Predicate must not be null");

Set<Class<?>> candidates = new LinkedHashSet<>();
visitNestedClasses(clazz, predicate, nestedClass -> {
candidates.add(nestedClass);
return true;
});
visitAllNestedClasses(clazz, predicate, candidates::add);
return List.copyOf(candidates);
}

Expand All @@ -1144,8 +1143,9 @@ public static boolean isNestedClassPresent(Class<?> clazz, Predicate<Class<?>> p
Preconditions.notNull(clazz, "Class must not be null");
Preconditions.notNull(predicate, "Predicate must not be null");

boolean visitorWasNotCalled = visitNestedClasses(clazz, predicate, __ -> false);
return !visitorWasNotCalled;
AtomicBoolean foundNestedClass = new AtomicBoolean(false);
visitAllNestedClasses(clazz, predicate, __ -> foundNestedClass.setPlain(true));
return foundNestedClass.getPlain();
}

/**
Expand All @@ -1156,10 +1156,15 @@ public static Stream<Class<?>> streamNestedClasses(Class<?> clazz, Predicate<Cla
return findNestedClasses(clazz, predicate).stream();
}

private static boolean visitNestedClasses(Class<?> clazz, Predicate<Class<?>> predicate,
Visitor<Class<?>> visitor) {
/**
* Visit <em>all</em> nested classes without support for short-circuiting
* in order to ensure all of them are checked for class cycles.
*/
private static void visitAllNestedClasses(Class<?> clazz, Predicate<Class<?>> predicate,
Consumer<Class<?>> consumer) {

if (!isSearchable(clazz)) {
return true;
return;
}

if (isInnerClass(clazz) && predicate.test(clazz)) {
Expand All @@ -1171,10 +1176,7 @@ private static boolean visitNestedClasses(Class<?> clazz, Predicate<Class<?>> pr
for (Class<?> nestedClass : clazz.getDeclaredClasses()) {
if (predicate.test(nestedClass)) {
detectInnerClassCycle(nestedClass);
boolean shouldContinue = visitor.accept(nestedClass);
if (!shouldContinue) {
return false;
}
consumer.accept(nestedClass);
}
}
}
Expand All @@ -1183,20 +1185,12 @@ private static boolean visitNestedClasses(Class<?> clazz, Predicate<Class<?>> pr
}

// Search class hierarchy
boolean shouldContinue = visitNestedClasses(clazz.getSuperclass(), predicate, visitor);
if (!shouldContinue) {
return false;
}
visitAllNestedClasses(clazz.getSuperclass(), predicate, consumer);

// Search interface hierarchy
for (Class<?> ifc : clazz.getInterfaces()) {
shouldContinue = visitNestedClasses(ifc, predicate, visitor);
if (!shouldContinue) {
return false;
}
visitAllNestedClasses(ifc, predicate, consumer);
}

return true;
}

/**
Expand Down Expand Up @@ -1936,14 +1930,4 @@ static Throwable getUnderlyingCause(Throwable t) {
return t;
}

private interface Visitor<T> {

/**
* @return {@code true} if the visitor should continue searching;
* {@code false} if the visitor should stop
*/
boolean accept(T value);

}

}