diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7801cd347f7d..4a60456984f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -520,9 +520,41 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] def subqueries: Seq[PlanType] = _subqueries() private val _subqueries = new TransientBestEffortLazyVal(() => - expressions.filter(_.containsPattern(PLAN_EXPRESSION)).flatMap(_.collect { - case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType] - }) + { + val queryPlanBaseClass = classOf[QueryPlan[_]] + + def directChildQueryPlanClass(clazz: Class[_]): Option[Class[_]] = { + var current = clazz + while (current != null && queryPlanBaseClass.isAssignableFrom(current)) { + val parent = current.getSuperclass + if (parent == queryPlanBaseClass) { + return Some(current) + } + current = parent + } + None + } + + val baseQueryPlanClass = directChildQueryPlanClass(getClass) + + val rawSubqueries = expressions + .filter(_.containsPattern(PLAN_EXPRESSION)) + .flatMap(_.collect { + case planExpression: PlanExpression[_] => + planExpression.plan + }) + + baseQueryPlanClass match { + case Some(baseClass) => + rawSubqueries.collect { + case subquery + if directChildQueryPlanClass(subquery.getClass).contains(baseClass) => + subquery.asInstanceOf[PlanType] + } + case None => + Seq.empty + } + } ) /**