diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 964a9d2ef0b47..532bf51a517ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -31,9 +31,7 @@ import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.resolver.{ AnalyzerBridgeState, HybridAnalyzer, - Resolver => OperatorResolver, - ResolverExtension, - ResolverGuard + ResolverExtension } import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes @@ -297,17 +295,17 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor def getRelationResolution: RelationResolution = relationResolution def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { - if (plan.analyzed) return plan - AnalysisHelper.markInAnalyzer { - new HybridAnalyzer( - this, - new ResolverGuard(catalogManager), - new OperatorResolver( - catalogManager, - singlePassResolverExtensions, - singlePassMetadataResolverExtensions - ) - ).apply(plan, tracker) + if (plan.analyzed) { + plan + } else { + AnalysisContext.reset() + try { + AnalysisHelper.markInAnalyzer { + HybridAnalyzer.fromLegacyAnalyzer(legacyAnalyzer = this).apply(plan, tracker) + } + } finally { + AnalysisContext.reset() + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala index 4a01bf14fe4bc..b194a4e44a9d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateExpressionResolver.scala @@ -18,11 +18,6 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{ - AnsiTypeCoercion, - CollationTypeCoercion, - TypeCoercion -} import org.apache.spark.sql.catalyst.expressions.{Expression, OuterReference, SubExprUtils} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ListAgg} import org.apache.spark.sql.catalyst.util.toPrettySQL @@ -41,11 +36,6 @@ class AggregateExpressionResolver( private val traversals = expressionResolver.getExpressionTreeTraversals - protected override val ansiTransformations: CoercesExpressionTypes.Transformations = - AggregateExpressionResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS - protected override val nonAnsiTransformations: CoercesExpressionTypes.Transformations = - AggregateExpressionResolver.TYPE_COERCION_TRANSFORMATIONS - private val expressionResolutionContextStack = expressionResolver.getExpressionResolutionContextStack private val subqueryRegistry = operatorResolver.getSubqueryRegistry @@ -58,6 +48,7 @@ class AggregateExpressionResolver( * resolving its children recursively and validating the resolved expression. 
*/ override def resolve(aggregateExpression: AggregateExpression): Expression = { + expressionResolutionContextStack.peek().resolvingTreeUnderAggregateExpression = true val aggregateExpressionWithChildrenResolved = withResolvedChildren(aggregateExpression, expressionResolver.resolve _) .asInstanceOf[AggregateExpression] @@ -132,15 +123,13 @@ class AggregateExpressionResolver( throwNestedAggregateFunction(aggregateExpression) } - val nonDeterministicChild = - aggregateExpression.aggregateFunction.children.collectFirst { - case child if !child.deterministic => child + aggregateExpression.aggregateFunction.children.foreach { child => + if (!child.deterministic) { + throwAggregateFunctionWithNondeterministicExpression( + aggregateExpression, + child + ) } - if (nonDeterministicChild.nonEmpty) { - throwAggregateFunctionWithNondeterministicExpression( - aggregateExpression, - nonDeterministicChild.get - ) } } @@ -249,23 +238,3 @@ class AggregateExpressionResolver( ) } } - -object AggregateExpressionResolver { - // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - CollationTypeCoercion.apply, - TypeCoercion.InTypeCoercion.apply, - TypeCoercion.FunctionArgumentTypeCoercion.apply, - TypeCoercion.IfTypeCoercion.apply, - TypeCoercion.ImplicitTypeCoercion.apply - ) - - // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. - private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - CollationTypeCoercion.apply, - AnsiTypeCoercion.InTypeCoercion.apply, - AnsiTypeCoercion.FunctionArgumentTypeCoercion.apply, - AnsiTypeCoercion.IfTypeCoercion.apply, - AnsiTypeCoercion.ImplicitTypeCoercion.apply - ) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolutionResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolutionResult.scala index d96185f642fd8..d4bb96e8d72f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolutionResult.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolutionResult.scala @@ -20,14 +20,15 @@ package org.apache.spark.sql.catalyst.analysis.resolver import java.util.HashSet import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId, NamedExpression} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} /** - * Stores the resulting operator, output list, grouping attributes and list of aliases from - * aggregate list, obtained by resolving an [[Aggregate]] operator. + * Stores the resulting operator, output list, grouping attributes, list of aliases from + * aggregate list and base [[Aggregate]], obtained by resolving an [[Aggregate]] operator. 
*/ case class AggregateResolutionResult( operator: LogicalPlan, outputList: Seq[NamedExpression], - groupingAttributeIds: Option[HashSet[ExprId]], - aggregateListAliases: Seq[Alias]) + groupingAttributeIds: HashSet[ExprId], + aggregateListAliases: Seq[Alias], + baseAggregate: Aggregate) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala index fffd55b5897bd..7591452b76d21 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateResolver.scala @@ -17,24 +17,18 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import java.util.{HashSet, LinkedHashMap} +import java.util.HashSet -import scala.jdk.CollectionConverters._ - -import org.apache.spark.sql.catalyst.analysis.{ - AnalysisErrorAt, - NondeterministicExpressionCollection, - UnresolvedAttribute -} +import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.{ Alias, + AliasHelper, AttributeReference, Expression, ExprId, - ExprUtils, - NamedExpression + ExprUtils } -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} /** * Resolves an [[Aggregate]] by resolving its child, aggregate expressions and grouping @@ -42,7 +36,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj * related to [[Aggregate]] resolution. */ class AggregateResolver(operatorResolver: Resolver, expressionResolver: ExpressionResolver) - extends TreeNodeResolver[Aggregate, LogicalPlan] { + extends TreeNodeResolver[Aggregate, LogicalPlan] + with AliasHelper { private val scopes = operatorResolver.getNameScopes private val lcaResolver = expressionResolver.getLcaResolver @@ -50,17 +45,27 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi * Resolve [[Aggregate]] operator. * * 1. Resolve the child (inline table). - * 2. Resolve aggregate expressions using [[ExpressionResolver.resolveAggregateExpressions]] and + * 2. Clear [[NameScope.availableAliases]]. Those are only relevant for the immediate aggregate + * expressions for output prioritization to work correctly in + * [[NameScope.tryResolveMultipartNameByOutput]]. + * 3. Resolve aggregate expressions using [[ExpressionResolver.resolveAggregateExpressions]] and * set [[NameScope.ordinalReplacementExpressions]] for grouping expressions resolution. - * 3. If there's just one [[UnresolvedAttribute]] with a single-part name "ALL", expand it using + * 4. If there's just one [[UnresolvedAttribute]] with a single-part name "ALL", expand it using * aggregate expressions which don't contain aggregate functions. There should not exist a * column with that name in the lower operator's output, otherwise it takes precedence. - * 4. Resolve grouping expressions using [[ExpressionResolver.resolveGroupingExpressions]]. This + * 5. Resolve grouping expressions using [[ExpressionResolver.resolveGroupingExpressions]]. This * includes alias references to aggregate expressions, which is done in * [[NameScope.resolveMultipartName]] and replacing [[UnresolvedOrdinals]] with corresponding * expressions from aggregate list, done in [[OrdinalResolver]]. - * 5. 
Substitute non-deterministic expressions with derived attribute references to an - * artificial [[Project]] list. + * 6. Remove all the unnecessary [[Alias]]es from the grouping (all the aliases) and aggregate + * (keep the outermost one) expressions. This is needed to stay compatible with the + * fixed-point implementation. For example: + * + * {{{ SELECT timestamp(col1:str) FROM VALUES('a') GROUP BY timestamp(col1:str); }}} + * + * Here we end up having inner [[Alias]]es in both the grouping and aggregate expressions + * lists which are uncomparable because they have different expression IDs (thus we have to + * strip them). * * If the resulting [[Aggregate]] contains lateral columns references, delegate the resolution of * these columns to [[LateralColumnAliasResolver.handleLcaInAggregate]]. Otherwise, validate the @@ -73,6 +78,8 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi val resolvedAggregate = try { val resolvedChild = operatorResolver.resolve(unresolvedAggregate.child) + scopes.current.availableAliases.clear() + val resolvedAggregateExpressions = expressionResolver.resolveAggregateExpressions( unresolvedAggregate.aggregateExpressions, unresolvedAggregate @@ -100,21 +107,25 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi ) } - val partiallyResolvedAggregate = unresolvedAggregate.copy( - groupingExpressions = resolvedGroupingExpressions, - aggregateExpressions = resolvedAggregateExpressions.expressions, + val resolvedGroupingExpressionsWithoutAliases = resolvedGroupingExpressions.map(trimAliases) + val resolvedAggregateExpressionsWithoutAliases = + resolvedAggregateExpressions.expressions.map(trimNonTopLevelAliases) + + val resolvedAggregate = unresolvedAggregate.copy( + groupingExpressions = resolvedGroupingExpressionsWithoutAliases, + aggregateExpressions = resolvedAggregateExpressionsWithoutAliases, child = resolvedChild ) - val resolvedAggregate = tryPullOutNondeterministic(partiallyResolvedAggregate) - if (resolvedAggregateExpressions.hasLateralColumnAlias) { val aggregateWithLcaResolutionResult = lcaResolver.handleLcaInAggregate(resolvedAggregate) AggregateResolutionResult( operator = aggregateWithLcaResolutionResult.resolvedOperator, outputList = aggregateWithLcaResolutionResult.outputList, - groupingAttributeIds = None, - aggregateListAliases = aggregateWithLcaResolutionResult.aggregateListAliases + groupingAttributeIds = + getGroupingAttributeIds(aggregateWithLcaResolutionResult.baseAggregate), + aggregateListAliases = aggregateWithLcaResolutionResult.aggregateListAliases, + baseAggregate = aggregateWithLcaResolutionResult.baseAggregate ) } else { // TODO: This validation function does a post-traversal. 
This is discouraged in single-pass @@ -124,8 +135,9 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi AggregateResolutionResult( operator = resolvedAggregate, outputList = resolvedAggregate.aggregateExpressions, - groupingAttributeIds = Some(getGroupingAttributeIds(resolvedAggregate)), - aggregateListAliases = scopes.current.getTopAggregateExpressionAliases + groupingAttributeIds = getGroupingAttributeIds(resolvedAggregate), + aggregateListAliases = scopes.current.getTopAggregateExpressionAliases, + baseAggregate = resolvedAggregate ) } } finally { @@ -134,8 +146,9 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi scopes.overwriteOutputAndExtendHiddenOutput( output = resolvedAggregate.outputList.map(_.toAttribute), - groupingAttributeIds = resolvedAggregate.groupingAttributeIds, - aggregateListAliases = resolvedAggregate.aggregateListAliases + groupingAttributeIds = Some(resolvedAggregate.groupingAttributeIds), + aggregateListAliases = resolvedAggregate.aggregateListAliases, + baseAggregate = Some(resolvedAggregate.baseAggregate) ) resolvedAggregate.operator @@ -208,53 +221,6 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi } } - /** - * In case there are non-deterministic expressions in either `groupingExpressions` or - * `aggregateExpressions` replace them with attributes created out of corresponding - * non-deterministic expression. Example: - * - * {{{ SELECT RAND() GROUP BY 1; }}} - * - * This query would have the following analyzed plan: - * Aggregate( - * groupingExpressions = [AttributeReference(_nonDeterministic)] - * aggregateExpressions = [Alias(AttributeReference(_nonDeterministic), `rand()`)] - * child = Project( - * projectList = [Alias(Rand(...), `_nondeterministic`)] - * child = OneRowRelation - * ) - * ) - */ - private def tryPullOutNondeterministic(aggregate: Aggregate): Aggregate = { - val nondeterministicToAttributes: LinkedHashMap[Expression, NamedExpression] = - NondeterministicExpressionCollection.getNondeterministicToAttributes( - aggregate.groupingExpressions - ) - - if (!nondeterministicToAttributes.isEmpty) { - val newChild = Project( - scopes.current.output ++ nondeterministicToAttributes.values.asScala.toSeq, - aggregate.child - ) - val resolvedAggregateExpressions = aggregate.aggregateExpressions.map { expression => - PullOutNondeterministicExpressionInExpressionTree(expression, nondeterministicToAttributes) - } - val resolvedGroupingExpressions = aggregate.groupingExpressions.map { expression => - PullOutNondeterministicExpressionInExpressionTree( - expression, - nondeterministicToAttributes - ) - } - aggregate.copy( - groupingExpressions = resolvedGroupingExpressions, - aggregateExpressions = resolvedAggregateExpressions, - child = newChild - ) - } else { - aggregate - } - } - private def canGroupByAll(expressions: Seq[Expression]): Boolean = { val isOrderByAll = expressions match { case Seq(unresolvedAttribute: UnresolvedAttribute) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateWithLcaResolutionResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateWithLcaResolutionResult.scala index 535c41cad39ec..15d4ac4f5ceb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateWithLcaResolutionResult.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AggregateWithLcaResolutionResult.scala @@ 
-18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} /** * Stores the result of resolution of lateral column aliases in an [[Aggregate]]. @@ -26,8 +26,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * @param outputList The output list of the resolved operator. * @param aggregateListAliases List of aliases from aggregate list and all artificially inserted * [[Project]] nodes. + * @param baseAggregate [[Aggregate]] node constructed by [[LateralColumnAliasResolver]] while + * resolving lateral column references in [[Aggregate]]. */ case class AggregateWithLcaResolutionResult( resolvedOperator: LogicalPlan, outputList: Seq[NamedExpression], - aggregateListAliases: Seq[Alias]) + aggregateListAliases: Seq[Alias], + baseAggregate: Aggregate) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala index 83329c0fe464d..b56281fa12ba8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AliasResolver.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import org.apache.spark.sql.catalyst.analysis.{AliasResolution, MultiAlias, UnresolvedAlias} +import org.apache.spark.sql.catalyst.analysis.{AliasResolution, UnresolvedAlias} import org.apache.spark.sql.catalyst.expressions.{ Alias, Expression, NamedExpression, OuterReference } +import org.apache.spark.sql.errors.QueryCompilationErrors /** * Resolver class that resolves unresolved aliases and handles user-specified aliases. @@ -45,25 +46,32 @@ class AliasResolver(expressionResolver: ExpressionResolver) * we create a new [[Alias]] using the [[AutoGeneratedAliasProvider]]. Here we allow inner * aliases to persist until the end of single-pass resolution, after which they will be removed * in the post-processing phase. + * + * Resulting [[Alias]] must be added to the list of `availableAliases` in the current + * [[NameScope]]. 
*/ override def resolve(unresolvedAlias: UnresolvedAlias): NamedExpression = - scopes.current.lcaRegistry.withNewLcaScope { + scopes.current.lcaRegistry.withNewLcaScope( + isTopLevelAlias = expressionResolutionContextStack.peek().isTopOfProjectList + ) { val aliasWithResolvedChildren = withResolvedChildren(unresolvedAlias, expressionResolver.resolve _) .asInstanceOf[UnresolvedAlias] - val resolvedAlias = + val resolvedNode = AliasResolution.resolve(aliasWithResolvedChildren).asInstanceOf[NamedExpression] - resolvedAlias match { - case multiAlias: MultiAlias => - throw new ExplicitlyUnsupportedResolverFeature( - s"unsupported expression: ${multiAlias.getClass.getName}" - ) + resolvedNode match { case alias: Alias => - expressionResolver.getExpressionIdAssigner.mapExpression(alias) + val resultAlias = expressionResolver.getExpressionIdAssigner.mapExpression(alias) + scopes.current.availableAliases.add(resultAlias.exprId) + resultAlias case outerReference: OuterReference => autoGeneratedAliasProvider.newAlias(outerReference) + case _ => + throw QueryCompilationErrors.unsupportedSinglePassAnalyzerFeature( + s"${resolvedNode.getClass} expression resolution" + ) } } @@ -77,18 +85,21 @@ class AliasResolver(expressionResolver: ExpressionResolver) * those aliases. See [[ExpressionIdAssigner.mapExpression]] doc for more details. */ def handleResolvedAlias(alias: Alias): Alias = { - val resolvedAlias = scopes.current.lcaRegistry.withNewLcaScope { + val resolvedAlias = scopes.current.lcaRegistry.withNewLcaScope( + isTopLevelAlias = expressionResolutionContextStack.peek().isTopOfProjectList + ) { val aliasWithResolvedChildren = withResolvedChildren(alias, expressionResolver.resolve _).asInstanceOf[Alias] - val mappedAlias = expressionResolver.getExpressionIdAssigner.mapExpression( + val resultAlias = expressionResolver.getExpressionIdAssigner.mapExpression( originalExpression = aliasWithResolvedChildren, prioritizeOldDuplicateAliasId = expressionResolutionContextStack.peek().resolvingGroupingExpressions ) - scopes.current.availableAliases.add(mappedAlias.exprId) - mappedAlias + scopes.current.availableAliases.add(resultAlias.exprId) + + resultAlias } collapseAlias(resolvedAlias) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AutoGeneratedAliasProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AutoGeneratedAliasProvider.scala index 2a49581b3499b..5fd5a5ff7870e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AutoGeneratedAliasProvider.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/AutoGeneratedAliasProvider.scala @@ -62,7 +62,7 @@ class AutoGeneratedAliasProvider(expressionIdAssigner: ExpressionIdAssigner) { name: Option[String] = None, explicitMetadata: Option[Metadata] = None, skipExpressionIdAssigner: Boolean = false): Alias = { - var alias = Alias( + val alias = Alias( child = child, name = name.getOrElse(toPrettySQL(child)) )( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala index 00e5d2347150a..523e497f0613b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/BinaryArithmeticResolver.scala @@ -17,19 +17,11 @@ 
package org.apache.spark.sql.catalyst.analysis.resolver -import org.apache.spark.sql.catalyst.analysis.{ - AnsiStringPromotionTypeCoercion, - AnsiTypeCoercion, - BinaryArithmeticWithDatetimeResolver, - DecimalPrecisionTypeCoercion, - DivisionTypeCoercion, - IntegralDivisionTypeCoercion, - StringPromotionTypeCoercion, - TypeCoercion -} +import org.apache.spark.sql.catalyst.analysis.BinaryArithmeticWithDatetimeResolver import org.apache.spark.sql.catalyst.expressions.{ Add, BinaryArithmetic, + Cast, DateAdd, Divide, Expression, @@ -37,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.{ Subtract, SubtractDates } -import org.apache.spark.sql.types.{DateType, StringType} +import org.apache.spark.sql.types._ /** * [[BinaryArithmeticResolver]] is invoked by [[ExpressionResolver]] in order to resolve @@ -90,11 +82,6 @@ class BinaryArithmeticResolver(expressionResolver: ExpressionResolver) private val traversals = expressionResolver.getExpressionTreeTraversals - protected override val ansiTransformations: CoercesExpressionTypes.Transformations = - BinaryArithmeticResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS - protected override val nonAnsiTransformations: CoercesExpressionTypes.Transformations = - BinaryArithmeticResolver.TYPE_COERCION_TRANSFORMATIONS - override def resolve(unresolvedBinaryArithmetic: BinaryArithmetic): Expression = { val binaryArithmeticWithResolvedChildren: BinaryArithmetic = withResolvedChildren(unresolvedBinaryArithmetic, expressionResolver.resolve _) @@ -117,8 +104,9 @@ class BinaryArithmeticResolver(expressionResolver: ExpressionResolver) * of nodes. */ private def transformBinaryArithmeticNode(binaryArithmetic: BinaryArithmetic): Expression = { + val binaryArithmeticWithNullReplaced: Expression = replaceNullType(binaryArithmetic) val binaryArithmeticWithDateTypeReplaced: Expression = - replaceDateType(binaryArithmetic) + replaceDateType(binaryArithmeticWithNullReplaced) val binaryArithmeticWithTypeCoercion: Expression = coerceExpressionTypes( expression = binaryArithmeticWithDateTypeReplaced, @@ -154,26 +142,29 @@ class BinaryArithmeticResolver(expressionResolver: ExpressionResolver) BinaryArithmeticWithDatetimeResolver.resolve(arithmetic) case other => other } -} -object BinaryArithmeticResolver { - // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - StringPromotionTypeCoercion.apply, - DecimalPrecisionTypeCoercion.apply, - DivisionTypeCoercion.apply, - IntegralDivisionTypeCoercion.apply, - TypeCoercion.ImplicitTypeCoercion.apply, - TypeCoercion.DateTimeOperationsTypeCoercion.apply - ) - - // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. - private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - AnsiStringPromotionTypeCoercion.apply, - DecimalPrecisionTypeCoercion.apply, - DivisionTypeCoercion.apply, - IntegralDivisionTypeCoercion.apply, - AnsiTypeCoercion.ImplicitTypeCoercion.apply, - AnsiTypeCoercion.AnsiDateTimeOperationsTypeCoercion.apply - ) + /** + * Replaces NullType by a compatible type in arithmetic expressions over Datetime operands. + * This avoids recursive calls of [[BinaryArithmeticWithDatetimeResolver]] which converts + * unacceptable nulls of `NullType` to an expected types of datetime expressions at the + * first step, and replacing arithmetic `Add` and `Subtract` by the same datetime expressions + * on the following steps. 
+ */ + private def replaceNullType(expression: Expression): Expression = expression match { + case a @ Add(l, r, _) => (l.dataType, r.dataType) match { + case (_: DatetimeType, _: NullType) => + a.copy(right = Cast(a.right, DayTimeIntervalType.DEFAULT)) + case (_: NullType, _: DatetimeType) => + a.copy(left = Cast(a.left, DayTimeIntervalType.DEFAULT)) + case _ => a + } + case s @ Subtract(l, r, _) => (l.dataType, r.dataType) match { + case (_: NullType, _: DatetimeType) => + s.copy(left = Cast(s.left, s.right.dataType)) + case (_: DatetimeType, _: NullType) => + s.copy(right = Cast(s.right, s.left.dataType)) + case _ => s + } + case other => other + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesOperatorChildren.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CandidatesForResolution.scala similarity index 59% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesOperatorChildren.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CandidatesForResolution.scala index 0f548c3c55858..3b8ec61bf5ee8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesOperatorChildren.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CandidatesForResolution.scala @@ -17,20 +17,10 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.expressions.Attribute /** - * A mixin trait for all operator resolvers that need to resolve their children. + * [[CandidatesForResolution]] is used by the [[NameScope]] during multipart name resolution to + * prioritize attributes from different types of operator output (main, hidden, metadata). */ -trait ResolvesOperatorChildren { - - /** - * Resolves generic [[LogicalPlan]] children and returns its copy with children resolved. - */ - protected def withResolvedChildren[OperatorType <: LogicalPlan]( - unresolvedOperator: OperatorType, - resolve: LogicalPlan => LogicalPlan): OperatorType = { - val newChildren = unresolvedOperator.children.map(resolve(_)) - unresolvedOperator.withNewChildren(newChildren).asInstanceOf[OperatorType] - } -} +case class CandidatesForResolution(attributes: Seq[Attribute], outputType: OutputType.OutputType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CoercesExpressionTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CoercesExpressionTypes.scala index 0fc6a6742edad..34be99a1abc0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CoercesExpressionTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CoercesExpressionTypes.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{ TypeCoercion } import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} +import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin /** * [[CoercesExpressionTypes]] is extended by resolvers that need to apply type coercion. @@ -57,27 +58,42 @@ trait CoercesExpressionTypes extends SQLConfHelper { * * In the end, we apply [[DefaultCollationTypeCoercion]]. * See [[DefaultCollationTypeCoercion]] doc for more info. 
+ * + * Additionally, we copy the tags and origin in case the call to this method didn't come from + * [[ExpressionResolver]], where they are copied generically. */ def coerceExpressionTypes( expression: Expression, expressionTreeTraversal: ExpressionTreeTraversal): Expression = { - val coercedExpressionOnce = applyTypeCoercion( - expression = expression, - expressionTreeTraversal = expressionTreeTraversal - ) - // This is a hack necessary because fixed-point analyzer sometimes requires multiple passes to - // resolve type coercion. Instead, in single pass, we apply type coercion twice on the same - // node in order to ensure that types are resolved. - val coercedExpressionTwice = applyTypeCoercion( - expression = coercedExpressionOnce, - expressionTreeTraversal = expressionTreeTraversal - ) + withOrigin(expression.origin) { + val coercedExpressionOnce = applyTypeCoercion( + expression = expression, + expressionTreeTraversal = expressionTreeTraversal + ) + + // If the expression isn't changed by the first iteration of type coercion, + // second iteration won't be effective either. + val expressionAfterTypeCoercion = if (coercedExpressionOnce.eq(expression)) { + coercedExpressionOnce + } else { + // This is a hack necessary because fixed-point analyzer sometimes requires multiple passes + // to resolve type coercion. Instead, in single pass, we apply type coercion twice on the + // same node in order to ensure that types are resolved. + applyTypeCoercion( + expression = coercedExpressionOnce, + expressionTreeTraversal = expressionTreeTraversal + ) + } + + val coercionResult = expressionTreeTraversal.defaultCollation match { + case Some(defaultCollation) => + DefaultCollationTypeCoercion(expressionAfterTypeCoercion, defaultCollation) + case None => + expressionAfterTypeCoercion + } - expressionTreeTraversal.defaultCollation match { - case Some(defaultCollation) => - DefaultCollationTypeCoercion(coercedExpressionTwice, defaultCollation) - case None => - coercedExpressionTwice + coercionResult.copyTagsFrom(expression) + coercionResult } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala deleted file mode 100644 index 0d847cf09adf0..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ConditionalExpressionResolver.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.analysis.resolver - -import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.{ConditionalExpression, Expression} - -/** - * Resolver for [[If]], [[CaseWhen]] and [[Coalesce]] expressions. - */ -class ConditionalExpressionResolver(expressionResolver: ExpressionResolver) - extends TreeNodeResolver[ConditionalExpression, Expression] - with ResolvesExpressionChildren - with CoercesExpressionTypes { - - private val traversals = expressionResolver.getExpressionTreeTraversals - - protected override val ansiTransformations: CoercesExpressionTypes.Transformations = - ConditionalExpressionResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS - protected override val nonAnsiTransformations: CoercesExpressionTypes.Transformations = - ConditionalExpressionResolver.TYPE_COERCION_TRANSFORMATIONS - - override def resolve(unresolvedConditionalExpression: ConditionalExpression): Expression = { - val conditionalExpressionWithResolvedChildren = - withResolvedChildren(unresolvedConditionalExpression, expressionResolver.resolve _) - - coerceExpressionTypes( - expression = conditionalExpressionWithResolvedChildren, - expressionTreeTraversal = traversals.current - ) - } -} - -object ConditionalExpressionResolver { - // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - TypeCoercion.CaseWhenTypeCoercion.apply, - TypeCoercion.FunctionArgumentTypeCoercion.apply, - TypeCoercion.IfTypeCoercion.apply - ) - - // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. - private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - AnsiTypeCoercion.CaseWhenTypeCoercion.apply, - AnsiTypeCoercion.FunctionArgumentTypeCoercion.apply, - AnsiTypeCoercion.IfTypeCoercion.apply - ) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CreateNamedStructResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CreateNamedStructResolver.scala deleted file mode 100644 index d0e4ecea25cb3..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/CreateNamedStructResolver.scala +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis.resolver - -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateNamedStruct, Expression} - -/** - * Resolves [[CreateNamedStruct]] nodes by recursively resolving children. If [[CreateNamedStruct]] - * is not directly under an [[Alias]], removes aliases from struct fields. 
Otherwise, let - * [[AliasResolver]] handle the removal. - */ -class CreateNamedStructResolver(expressionResolver: ExpressionResolver) - extends TreeNodeResolver[CreateNamedStruct, Expression] - with ResolvesExpressionChildren { - - override def resolve(createNamedStruct: CreateNamedStruct): Expression = { - val createNamedStructWithResolvedChildren = - withResolvedChildren(createNamedStruct, expressionResolver.resolve) - .asInstanceOf[CreateNamedStruct] - CreateNamedStructResolver.cleanupAliases(createNamedStructWithResolvedChildren) - } -} - -object CreateNamedStructResolver { - - /** - * For a query like: - * - * {{{ SELECT STRUCT(1 AS a, 2 AS b) }}} - * - * [[CreateNamedStruct]] will be: CreateNamedStruct(Seq("a", Alias(1, "a"), "b", Alias(2, "b"))) - * - * Because inner aliases are not expected in the analyzed logical plan, we need to remove them - * here. However, we only do so if [[CreateNamedStruct]] is not directly under an [[Alias]], in - * which case the removal will be handled by [[AliasResolver]]. This is because in single-pass, - * [[Alias]] is resolved after [[CreateNamedStruct]] and in order to compute the correct output - * name, it needs to know complete structure of the child. - */ - def cleanupAliases(createNamedStruct: CreateNamedStruct): CreateNamedStruct = { - createNamedStruct - .withNewChildren(createNamedStruct.children.map { - case a: Alias if a.metadata.isEmpty => a.child - case other => other - }) - .asInstanceOf[CreateNamedStruct] - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala index 71bc7f24d5f28..9cd4c525b48c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionIdAssigner.scala @@ -326,9 +326,17 @@ class ExpressionIdAssigner { * child mappings will have collisions during this merge operation. We need to decide which of * the new IDs get the priority for the old ID. This is done based on the IDs that are actually * outputted into the multi-child operator. This information is provided with `newOutputIds`. - * If the new ID is present in that set, we treat it as a P0 over the IDs that are hidden in the - * branch. Also, we iterate over child mappings from right to left, prioritizing IDs from the - * left, because that's how operators like [[Union]] propagate IDs upwards. + * + * The principles: + * 1. If the destination ID is present in `newOutputIds`, we treat it as a higher priority over + * the ID that is "hidden" in the logical plan branch. + * 2. If both destination IDs are present in `newOutputIds`, we prioritize the identity mapping - + * the new ID which is equal to the old ID, and not the "remapping". This is valid in SQL + * because we are dealing with a fully unresolved plan and the remapping is not needed. + * DataFrame queries that contain a self-join or a self-union and are referencing the same + * attribute from both branches will fail (which is expected). + * 3. We iterate over child mappings from right to left, prioritizing IDs from the left, because + * that's how multi-child operators like [[Join]] or [[Union]] propagate IDs upwards. 
* * Example 1: * {{{ @@ -360,8 +368,19 @@ class ExpressionIdAssigner { * df2.join(df1, df2("b") === df1("a")) * }}} * - * This is used by multi child operators like [[Join]] or [[Union]] to propagate mapped - * expression IDs upwards. + * Example 3: + * {{{ + * -- In this query CTE references a table which is also present in a JOIN. First, CTE definition + * -- is analyzed with `t1` inside. Let's say it outputs col1#0. Once we get to a left JOIN child, + * -- which is also `t1`, we know that expression IDs in `t1` have to be regenerated to col#1 + * -- because it's a duplicate relation. After resolving the JOIN, we are left with (#0 -> #0), + * -- (#1 -> #1) and (#0 -> #1) mappings. Also, JOIN outputs both #0 and #1. This is an example + * -- of principle 2. when identity (#0 -> #0) and (#1 -> #1) mappings have to be prioritized, + * -- because (#0 -> #1) is a remapping and not needed in SQL. + * SELECT * FROM ( + * WITH cte1 AS (SELECT * FROM t1) SELECT t1.col1 FROM t1 JOIN cte1 USING (col1) + * ); + * }}} * * When `mergeIntoExisting` is true, we merge child mappings into an existing mapping entry * instead of creating a new one. This setting is used when resolving [[LateralJoin]]s. @@ -380,15 +399,19 @@ class ExpressionIdAssigner { throw SparkException.internalError("No child mappings to create new current mapping") } - val priorityMapping = new ExpressionIdAssigner.PriorityMapping(newOutputIds.size) + val newMapping = if (mergeIntoExisting) { + currentStackEntry.mapping.get + } else { + new ExpressionIdAssigner.Mapping + } while (!currentStackEntry.childMappings.isEmpty) { val nextMapping = currentStackEntry.childMappings.pop() nextMapping.forEach { case (oldId, remappedId) => - updatePriorityMapping( - priorityMapping = priorityMapping, + updateNewMapping( + newMapping = newMapping, oldId = oldId, remappedId = remappedId, newOutputIds = newOutputIds @@ -396,17 +419,6 @@ class ExpressionIdAssigner { } } - val newMapping = if (mergeIntoExisting) { - currentStackEntry.mapping.get - } else { - new ExpressionIdAssigner.Mapping - } - - priorityMapping.forEach { - case (oldId, priority) => - newMapping.put(oldId, priority.pick()) - } - setCurrentMapping(newMapping) } @@ -606,27 +618,31 @@ class ExpressionIdAssigner { } /** - * Update the priority mapping for the given `oldId` and `remappedId`. If the `remappedId` is - * contained in the `newOutputIds`, we treat it as a P0 over the IDs that are not exposed from - * the operator branch. Otherwise, we treat it as a P1. + * Update `newMapping` with the `oldId -> remappedId` mapping, based on the principles described + * in [[createMappingFromChildMappings]]: + * 1. If no mapping from `oldId` exists, we create it + * 2. If the mapping from `oldId` already exists but is not present in `newOutputIds`, we + * deprioritize old mapping in favor of new one + * 3. If the mapping from `oldId` already exists and is present in `newOutputIds` and the new + * mapping is the identity one, we deprioritize old mapping in favor of new one + * 4. 
Otherwise we keep the existing mapping */ - private def updatePriorityMapping( - priorityMapping: ExpressionIdAssigner.PriorityMapping, + private def updateNewMapping( + newMapping: ExpressionIdAssigner.Mapping, oldId: ExprId, remappedId: ExprId, newOutputIds: Set[ExprId]): Unit = { - if (newOutputIds.contains(remappedId)) { - priorityMapping.merge( - oldId, - ExpressionIdPriority(p0 = Some(remappedId)), - (priority, _) => priority.copy(p0 = Some(remappedId)) - ) - } else { - priorityMapping.merge( - oldId, - ExpressionIdPriority(p1 = Some(remappedId)), - (priority, _) => priority.copy(p1 = Some(remappedId)) - ) + newMapping.get(oldId) match { + case null => + newMapping.put(oldId, remappedId) + + case knownRemappedId if !newOutputIds.contains(knownRemappedId) => + newMapping.put(oldId, remappedId) + + case knownRemappedId if newOutputIds.contains(remappedId) && remappedId == oldId => + newMapping.put(oldId, remappedId) + + case _ => } } } @@ -641,8 +657,6 @@ object ExpressionIdAssigner { type Stack = ArrayDeque[StackEntry] - type PriorityMapping = HashMap[ExprId, ExpressionIdPriority] - /** * Assert that `outputs` don't have conflicting expression IDs. */ @@ -696,15 +710,3 @@ object ExpressionIdAssigner { } } } - -/** - * [[ExpressionIdPriority]] is used by the [[ExpressionIdAssigner]] when merging child mappings - * of a multi-child operator to determine which new ID gets picked in case of an old ID collision. - */ -case class ExpressionIdPriority(p0: Option[ExprId] = None, p1: Option[ExprId] = None) { - def pick(): ExprId = p0.getOrElse { - p1.getOrElse { - throw SparkException.internalError("No expression ID to pick") - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala index 471405c7b0c62..b032beedc9706 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolver.scala @@ -39,8 +39,7 @@ import org.apache.spark.sql.catalyst.analysis.{ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Sort} -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNodeTag} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors @@ -126,7 +125,6 @@ class ExpressionResolver( private val aliasResolver = new AliasResolver(this) private val timezoneAwareExpressionResolver = new TimezoneAwareExpressionResolver(this) - private val conditionalExpressionResolver = new ConditionalExpressionResolver(this) private val binaryArithmeticResolver = new BinaryArithmeticResolver(this) private val limitLikeExpressionValidator = new LimitLikeExpressionValidator private val aggregateExpressionResolver = new AggregateExpressionResolver(resolver, this) @@ -136,11 +134,10 @@ class ExpressionResolver( aggregateExpressionResolver, binaryArithmeticResolver ) - private val timestampAddResolver = new TimestampAddResolver(this) - private val unaryMinusResolver = new UnaryMinusResolver(this) private val subqueryExpressionResolver = new 
SubqueryExpressionResolver(this, resolver) private val ordinalResolver = new OrdinalResolver(this) private val lcaResolver = new LateralColumnAliasResolver(this) + private val semiStructuredExtractResolver = new SemiStructuredExtractResolver(this) /** * Get the expression tree traversal stack. @@ -263,15 +260,15 @@ class ExpressionResolver( case unresolvedListQuery: ListQuery => subqueryExpressionResolver.resolveListQuery(unresolvedListQuery) case unresolvedTimestampAdd: TimestampAddInterval => - timestampAddResolver.resolve(unresolvedTimestampAdd) + resolveExpressionGenericallyWithTimezoneWithTypeCoercion(unresolvedTimestampAdd) case unresolvedUnaryMinus: UnaryMinus => - unaryMinusResolver.resolve(unresolvedUnaryMinus) + resolveExpressionGenericallyWithTypeCoercion(unresolvedUnaryMinus) case createNamedStruct: CreateNamedStruct => resolveExpressionGenerically(createNamedStruct) case sortOrder: SortOrder => resolveExpressionGenerically(sortOrder) case unresolvedConditionalExpression: ConditionalExpression => - conditionalExpressionResolver.resolve(unresolvedConditionalExpression) + resolveExpressionGenericallyWithTypeCoercion(unresolvedConditionalExpression) case getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal => resolveGetViewColumnByNameAndOrdinal(getViewColumnByNameAndOrdinal) case getTimeField: GetTimeField => @@ -286,6 +283,8 @@ class ExpressionResolver( resolveUpCast(unresolvedUpCast) case unresolvedCollation: UnresolvedCollation => resolveCollation(unresolvedCollation) + case semiStructuredExtract: SemiStructuredExtract => + semiStructuredExtractResolver.resolve(semiStructuredExtract) case expression: Expression => resolveExpressionGenericallyWithTypeCoercion(expression) } @@ -584,15 +583,10 @@ class ExpressionResolver( aliasResolver.resolve(unresolvedAlias) case unresolvedAttribute: UnresolvedAttribute => resolveAttribute(unresolvedAttribute) - case unresolvedStar: UnresolvedStar => - // We don't support edge cases of star usage, e.g. `WHERE col1 IN (*)` - throw new ExplicitlyUnsupportedResolverFeature("Star outside of Project list") case attributeReference: AttributeReference => handleResolvedAttributeReference(attributeReference) case outerReference: OuterReference => handleResolvedOuterReference(outerReference) - case _: UnresolvedNamedLambdaVariable => - throw new ExplicitlyUnsupportedResolverFeature("Lambda variables") case _ => withPosition(unresolvedNamedExpression) { throwUnsupportedSinglePassAnalyzerFeature(unresolvedNamedExpression) @@ -646,6 +640,9 @@ class ExpressionResolver( * In case that attribute is resolved as a literal function (i.e. result is [[CurrentDate]]), * perform additional resolution on it. * + * In case result of the previous step is a recursive data type, we coerce it to stay compatible + * with the fixed-point analyzer. + * * If the attribute is at the top of the project list (which is indicated by * [[ExpressionResolutionContext.isTopOfProjectList]]), we preserve the [[Alias]] or remove it * otherwise. 
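A minimal sketch of the literal-function and top-of-project-list aliasing behaviour described in the comment above (assuming an active SparkSession named `spark`; the exact output column names are indicative only, not taken from this patch):

{{{
// Sketch only; `spark` is an assumed SparkSession, queries are illustrative.
// A bare name that matches no column can resolve to a literal function:
spark.sql("SELECT current_date FROM VALUES (1) AS t(col1)")

// A nested field reference at the top of the project list is wrapped in an
// Alias carrying the field name, while the same reference nested inside a
// larger expression stays un-aliased:
spark.sql(
  "SELECT col1.field, col1.field + 1 FROM VALUES (named_struct('field', 1)) AS t(col1)")
}}}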
@@ -667,7 +664,6 @@ class ExpressionResolver( .resolvingGroupingExpressions && traversals.current.groupByAliases ), canResolveNameByHiddenOutput = canResolveNameByHiddenOutput, - shouldPreferTableColumnsOverAliases = shouldPreferTableColumnsOverAliases, shouldPreferHiddenOutput = traversals.current.isFilterOnTopOfAggregate, canResolveNameByHiddenOutputInSubquery = subqueryRegistry.currentScope.aggregateExpressionsExtractor.isDefined, @@ -703,11 +699,16 @@ class ExpressionResolver( case other => other } + val coercedCandidate = candidateOrLiteralFunction match { + case extractValue: ExtractValue => coerceRecursiveDataTypes(extractValue) + case other => other + } + val properlyAliasedExpressionTree = if (expressionResolutionContext.isTopOfProjectList && nameTarget.aliasName.isDefined) { - Alias(candidateOrLiteralFunction, nameTarget.aliasName.get)() + Alias(coercedCandidate, nameTarget.aliasName.get)() } else { - candidateOrLiteralFunction + coercedCandidate } properlyAliasedExpressionTree match { @@ -718,17 +719,41 @@ class ExpressionResolver( } } + /** + * Coerces recursive types ([[ExtractValue]] expressions) in a bottom up manner. For example: + * + * {{{ + * CREATE OR REPLACE TABLE t(col MAP); + * SELECT col.field FROM t; + * }}} + * + * In this example we need to cast inner field from `String` to `BIGINT`, thus analyzed plan + * should look like: + * + * {{{ + * Project [col#x[cast(field as bigint)] AS field#x] + * +- SubqueryAlias spark_catalog.default.t + * +- Relation spark_catalog.default.t[col#x] parquet + * }}} + * + * This is needed to stay compatible with the fixed-point implementation. + */ + private def coerceRecursiveDataTypes(extractValue: ExtractValue): Expression = { + extractValue.transformUp { + case field => coerceExpressionTypes(field, traversals.current) + } + } + private def isFilterOnTopOfAggregate(parentOperator: LogicalPlan): Boolean = { parentOperator match { - case _ @Filter(_, _: Aggregate) => true + case _: Filter if scopes.current.baseAggregate.isDefined => true case _ => false } } private def isSortOnTopOfAggregate(parentOperator: LogicalPlan): Boolean = { parentOperator match { - case _ @Sort(_, _, _: Aggregate, _) => true - case _ @Sort(_, _, _ @Filter(_, _: Aggregate), _) => true + case _: Sort if scopes.current.baseAggregate.isDefined => true case _ => false } } @@ -738,11 +763,6 @@ class ExpressionResolver( case other => false } - private def shouldPreferTableColumnsOverAliases = traversals.current.parentOperator match { - case _: Sort => true - case _ => false - } - /** * [[AttributeReference]] is already resolved if it's passed to us from DataFrame `col(...)` * function, for example. 
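For context, a small sketch of the DataFrame case mentioned above (assuming an active SparkSession named `spark`): column expressions created through a Dataset reach the resolver already bound to attribute references, unlike bare `col(...)` names.

{{{
// Sketch only; assumes an active SparkSession `spark`.
import org.apache.spark.sql.functions.col

val df = spark.range(3).toDF("id")
// `df("id")` is resolved against df's analyzed plan, so the resolver receives
// an already-resolved AttributeReference with a fixed expression ID; the bare
// col("id") below is still an UnresolvedAttribute when resolution starts.
df.select(df("id") + 1, col("id"))
}}}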
@@ -1027,8 +1047,12 @@ class ExpressionResolver( TypeCoercionValidation.failOnTypeCheckResult(resolvedExpression) } - if (!resolvedExpression.resolved) { - throwSinglePassFailedToResolveExpression(resolvedExpression) + resolvedExpression match { + case runtimeReplaceable: RuntimeReplaceable if !runtimeReplaceable.replacement.resolved => + throwFailedToResolveRuntimeReplaceableExpression(runtimeReplaceable) + case expression if !expression.resolved => + throwSinglePassFailedToResolveExpression(resolvedExpression) + case _ => } validateExpressionUnderSupportedOperator(resolvedExpression) @@ -1066,10 +1090,17 @@ class ExpressionResolver( context = expression.origin.getQueryContext, summary = expression.origin.context.summary() ) + + private def throwFailedToResolveRuntimeReplaceableExpression( + runtimeReplaceable: RuntimeReplaceable) = { + throw SparkException.internalError( + s"Cannot resolve the runtime replaceable expression ${toSQLExpr(runtimeReplaceable)}. " + + s"The replacement is unresolved: ${toSQLExpr(runtimeReplaceable.replacement)}." + ) + } } object ExpressionResolver { - private val AMBIGUOUS_SELF_JOIN_METADATA = Seq("__dataset_id", "__col_position") val SINGLE_PASS_SUBTREE_BOUNDARY = TreeNodeTag[Unit]("single_pass_subtree_boundary") val SINGLE_PASS_IS_LCA = TreeNodeTag[Unit]("single_pass_is_lca") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FilterResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FilterResolver.scala index 599270b28b1d4..4489278f0b238 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FilterResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/FilterResolver.scala @@ -65,8 +65,7 @@ class FilterResolver(resolver: Resolver, expressionResolver: ExpressionResolver) retainOriginalOutput( operator = finalFilter, missingExpressions = missingAttributes, - output = scopes.current.output, - hiddenOutput = scopes.current.hiddenOutput + scopes = scopes ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAndAggregateExpressionsExtractor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAndAggregateExpressionsExtractor.scala index e28526959f851..4929ffc57b6da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAndAggregateExpressionsExtractor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/GroupingAndAggregateExpressionsExtractor.scala @@ -33,13 +33,23 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate class GroupingAndAggregateExpressionsExtractor( aggregate: Aggregate, autoGeneratedAliasProvider: AutoGeneratedAliasProvider) { - private val aliasChildToAliasInAggregateExpressions = new IdentityHashMap[Expression, Alias] + + /** + * Maps children of aliases from aggregate list to their parents or to `None` if the expression + * doesn't have an alias. This map only accounts for the first appearance of the expression. For + * example, for `SELECT col1, col1 AS a` map entry should be (col1 -> None), but for + * `SELECT col1 AS a, col1` map entry should be (col1 -> Some(a)). 
+ */ + private val aliasChildToAliasInAggregateExpressions = + new IdentityHashMap[Expression, Option[Alias]] private val aggregateExpressionsSemanticComparator = new SemanticComparator( aggregate.aggregateExpressions.map { case alias: Alias => - aliasChildToAliasInAggregateExpressions.put(alias.child, alias) + aliasChildToAliasInAggregateExpressions.putIfAbsent(alias.child, Some(alias)) alias.child - case other => other + case other => + aliasChildToAliasInAggregateExpressions.put(other, None) + other } ) @@ -53,43 +63,48 @@ class GroupingAndAggregateExpressionsExtractor( * (grouping expressions used for extraction) and `extractedAggregateExpressionAliases` (aliases * of [[AggregateExpression]]s that are transformed to attributes during extraction) in order to * insert missing attributes to below operators. + * + * When an expression exists in both grouping and aggregate expressions (for example, when there + * are lateral column references in [[Aggregate]], LCA algorithm will copy grouping expressions + * to aggregate list prior to entering this method), we still extract grouping expression but + * don't add it later if it is not necessary. */ def extractReferencedGroupingAndAggregateExpressions( expression: Expression, referencedGroupingExpressions: mutable.ArrayBuffer[NamedExpression], extractedAggregateExpressionAliases: mutable.ArrayBuffer[Alias]): Expression = { - collectFirstAggregateExpression(expression) match { - case (Some(attribute: Attribute), _) - if !aliasChildToAliasInAggregateExpressions.containsKey(attribute) => - attribute - case (Some(expression), alias) => - alias match { - case None => - throw SparkException.internalError( - s"No parent alias for expression $expression while extracting aggregate" + - s"expressions in Sort operator." - ) - case Some(alias) => - alias.toAttribute - } - case (None, _) if groupingExpressionsSemanticComparator.exists(expression) => - expression match { - case attribute: Attribute => - referencedGroupingExpressions += attribute - attribute - case other => - val alias = autoGeneratedAliasProvider.newAlias(child = other) - referencedGroupingExpressions += alias - alias.toAttribute - } - case _ => - expression match { - case aggregateExpression: AggregateExpression => - val alias = autoGeneratedAliasProvider.newAlias(child = aggregateExpression) - extractedAggregateExpressionAliases += alias - alias.toAttribute - case other => other - } + val aggregateExpressionWithAlias = collectFirstAggregateExpression(expression) + val isGroupingExpression = groupingExpressionsSemanticComparator.exists(expression) + if (isGroupingExpression) { + val groupingExpressionReference = aggregateExpressionWithAlias match { + case (Some(attribute: Attribute), None) => attribute + case (Some(_), Some(alias)) => alias.toAttribute + case _ => + expression match { + case attribute: Attribute => attribute + case other => autoGeneratedAliasProvider.newAlias(child = other) + } + } + referencedGroupingExpressions += groupingExpressionReference + groupingExpressionReference.toAttribute + } else { + aggregateExpressionWithAlias match { + case (Some(attribute: Attribute), None) => attribute + case (Some(_), Some(alias)) => alias.toAttribute + case (Some(expression), None) => + throw SparkException.internalError( + s"No parent alias for expression $expression while extracting aggregate" + + s"expressions in an operator." 
+ ) + case _ => + expression match { + case aggregateExpression: AggregateExpression => + val alias = autoGeneratedAliasProvider.newAlias(child = aggregateExpression) + extractedAggregateExpressionAliases += alias + alias.toAttribute + case other => other + } + } } } @@ -103,7 +118,7 @@ class GroupingAndAggregateExpressionsExtractor( aggregateExpressionsSemanticComparator.collectFirst(expression) referencedAggregateExpression match { case Some(expression) => - (Some(expression), Option(aliasChildToAliasInAggregateExpressions.get(expression))) + (Some(expression), aliasChildToAliasInAggregateExpressions.get(expression)) case None => (None, None) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HavingResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HavingResolver.scala index 4bbcd43a0377a..e84eb9a0bc291 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HavingResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HavingResolver.scala @@ -22,13 +22,14 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.analysis.UnresolvedHaving import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project} /** * Resolves [[UnresolvedHaving]] node and its condition. */ class HavingResolver(resolver: Resolver, expressionResolver: ExpressionResolver) extends TreeNodeResolver[UnresolvedHaving, LogicalPlan] + with RewritesAliasesInTopLcaProject with ResolvesNameByHiddenOutput with ValidatesFilter { private val scopes: NameScopeStack = resolver.getNameScopes @@ -53,29 +54,46 @@ class HavingResolver(resolver: Resolver, expressionResolver: ExpressionResolver) val partiallyResolvedHaving = Filter(condition = unresolvedHaving.havingCondition, child = resolvedChild) - val resolvedCondition = expressionResolver.resolveExpressionTreeInOperator( + val partiallyResolvedCondition = expressionResolver.resolveExpressionTreeInOperator( partiallyResolvedHaving.condition, partiallyResolvedHaving ) - val (finalCondition, missingExpressions) = resolvedChild match { - case _ if scopes.current.hasLcaInAggregate => - throw new ExplicitlyUnsupportedResolverFeature( - "Lateral column alias in Aggregate below HAVING" + val (resolvedCondition, missingExpressions) = resolvedChild match { + case _ @(_: Project | _: Aggregate) if scopes.current.baseAggregate.isDefined => + handleAggregateBelowHaving( + scopes.current.baseAggregate.get, + partiallyResolvedCondition ) - case aggregate: Aggregate => - handleAggregateBelowHaving(aggregate, resolvedCondition) case other => throw SparkException.internalError( s"Unexpected operator ${other.getClass.getSimpleName} under HAVING" ) } + val (resolvedConditionWithAliasReplacement, filteredMissingExpressions) = + tryReplaceSortOrderOrHavingConditionWithAlias(resolvedCondition, scopes, missingExpressions) + val resolvedChildWithMissingAttributes = - insertMissingExpressions(resolvedChild, missingExpressions.toSeq) + insertMissingExpressions(resolvedChild, filteredMissingExpressions) + + val isChildChangedByMissingExpressions = !resolvedChildWithMissingAttributes.eq(resolvedChild) + + val (finalChild, finalCondition) = resolvedChildWithMissingAttributes match { + case project: Project if 
scopes.current.baseAggregate.isDefined => + val (newProject, newExpressions) = rewriteNamedExpressionsInTopLcaProject( + projectToRewrite = project, + baseAggregate = scopes.current.baseAggregate.get, + expressionsToRewrite = Seq(resolvedConditionWithAliasReplacement), + rewriteCandidates = missingExpressions, + autoGeneratedAliasProvider = autoGeneratedAliasProvider + ) + (newProject, newExpressions.head) + case other => (other, resolvedCondition) + } val resolvedHaving = partiallyResolvedHaving.copy( - child = resolvedChildWithMissingAttributes, + child = finalChild, condition = finalCondition ) @@ -85,12 +103,15 @@ class HavingResolver(resolver: Resolver, expressionResolver: ExpressionResolver) resolvedFilter = resolvedHaving ) - retainOriginalOutput( - operator = resolvedHaving, - missingExpressions = missingExpressions.toSeq, - output = scopes.current.output, - hiddenOutput = scopes.current.hiddenOutput - ) + if (isChildChangedByMissingExpressions) { + retainOriginalOutput( + operator = resolvedHaving, + missingExpressions = missingExpressions.toSeq, + scopes = scopes + ) + } else { + resolvedHaving + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala index c21d1aacadeaf..0117b3fc2fb55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala @@ -75,21 +75,17 @@ class HybridAnalyzer( private val sampleRateGenerator = new Random() def apply(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { - val passedResolvedGuard = resolverGuard.apply(plan) - val dualRun = conf.getConf(SQLConf.ANALYZER_DUAL_RUN_LEGACY_AND_SINGLE_PASS_RESOLVER) && - passedResolvedGuard && - checkDualRunSampleRate() + checkDualRunSampleRate() && + checkResolverGuard(plan) withTrackedAnalyzerBridgeState(dualRun) { if (dualRun) { resolveInDualRun(plan, tracker) } else if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED)) { resolveInSinglePass(plan, tracker) - } else if (passedResolvedGuard && conf.getConf( - SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED_TENTATIVELY - )) { + } else if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_ENABLED_TENTATIVELY)) { resolveInSinglePassTentatively(plan, tracker) } else { resolveInFixedPoint(plan, tracker) @@ -97,10 +93,6 @@ class HybridAnalyzer( } } - def getSinglePassResolutionDuration: Option[Long] = singlePassResolutionDuration - - def getFixedPointResolutionDuration: Option[Long] = fixedPointResolutionDuration - /** * Call `body` in the context of tracked [[AnalyzerBridgeState]]. 
Set the new bridge state * depending on whether we are in dual-run mode or not: @@ -206,10 +198,21 @@ class HybridAnalyzer( private def resolveInSinglePassTentatively( plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { - try { - resolveInSinglePass(plan, tracker) - } catch { - case _: ExplicitlyUnsupportedResolverFeature => + val singlePassResult = if (checkResolverGuard(plan)) { + try { + Some(resolveInSinglePass(plan, tracker)) + } catch { + case _: ExplicitlyUnsupportedResolverFeature => + None + } + } else { + None + } + + singlePassResult match { + case Some(result) => + result + case None => resolveInFixedPoint(plan, tracker) } } @@ -256,6 +259,16 @@ class HybridAnalyzer( } } + private def checkResolverGuard(plan: LogicalPlan): Boolean = { + try { + resolverGuard.apply(plan) + } catch { + case e: Throwable + if !conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_EXPOSE_RESOLVER_GUARD_FAILURE) => + false + } + } + /** * Normalizes the logical plan using [[NormalizePlan]]. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/KeyTransformingMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/KeyTransformingMap.scala index ae56c6ed04193..7815b3b0c79cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/KeyTransformingMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/KeyTransformingMap.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import java.util.{Collection, HashMap, Iterator} +import java.util.{Collection, Iterator, LinkedHashMap} import java.util.Map.Entry import java.util.function.Function /** * The [[KeyTransformingMap]] is a partial implementation of [[mutable.Map]] that transforms input * keys with a custom [[mapKey]] method. + * It preserves the order of insertion by using the [[LinkedHashMap]] as an underlying map. */ private abstract class KeyTransformingMap[K, V] { - private val impl = new HashMap[K, V] + private val impl = new LinkedHashMap[K, V] def get(key: K): Option[V] = Option(impl.get(mapKey(key))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasProhibitedRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasProhibitedRegistry.scala index bc0f11f5bd6de..e6e084115adec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasProhibitedRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasProhibitedRegistry.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} * idempotent. 
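 *
 * A minimal call-site sketch of the curried `withNewLcaScope` signature (hypothetical caller;
 * `lcaRegistry` and `resolvedChild` are illustrative names, not part of this change):
 * {{{
 *   // The caller decides whether the alias being resolved is a top-level one.
 *   val alias = lcaRegistry.withNewLcaScope(isTopLevelAlias = true) {
 *     Alias(resolvedChild, name = "a")()
 *   }
 * }}}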
*/ class LateralColumnAliasProhibitedRegistry extends LateralColumnAliasRegistry { - def withNewLcaScope(body: => Alias): Alias = body + def withNewLcaScope(isTopLevelAlias: Boolean)(body: => Alias): Alias = body def getAttribute(attributeName: String): Option[Attribute] = throwLcaResolutionNotEnabled() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistry.scala index 45a38417a8eed..fa539a4d9110f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistry.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} * LCA resolution is disabled by [[SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED]]. */ abstract class LateralColumnAliasRegistry { - def withNewLcaScope(body: => Alias): Alias + def withNewLcaScope(isTopLevelAlias: Boolean)(body: => Alias): Alias def getAttribute(attributeName: String): Option[Attribute] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistryImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistryImpl.scala index c685b098db2d2..94520455ee8d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistryImpl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasRegistryImpl.scala @@ -81,13 +81,15 @@ class LateralColumnAliasRegistryImpl(attributes: Seq[Attribute]) /** * Creates a new LCA resolution scope for each [[Alias]] resolution. Executes the lambda and - * registers the resolved alias for later LCA resolution. + * registers top-level resolved aliases for later LCA resolution. */ - def withNewLcaScope(body: => Alias): Alias = { + def withNewLcaScope(isTopLevelAlias: Boolean)(body: => Alias): Alias = { currentAttributeDependencyLevelStack.push(0) try { val resolvedAlias = body - registerAlias(resolvedAlias) + if (isTopLevelAlias) { + registerAlias(resolvedAlias) + } resolvedAlias } finally { currentAttributeDependencyLevelStack.pop() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasResolver.scala index 1ffaf73fc356a..5d268f087aa92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LateralColumnAliasResolver.scala @@ -57,9 +57,13 @@ class LateralColumnAliasResolver(expressionResolver: ExpressionResolver) extends * - In order to be able to resolve [[Sort]] on top of an [[Aggregate]] that has LCAs, we need * to collect all aliases from [[Aggregate]], as well as any aliases from artificially inserted * [[Project]] nodes. 
+ * - Collects all aliases from newly created [[Aggregate]] and [[Project]] nodes and adds them + * to `aliasesToCollect` */ def handleLcaInAggregate(resolvedAggregate: Aggregate): AggregateWithLcaResolutionResult = { - extractLcaAndReplaceAggWithProject(resolvedAggregate) match { + val aliasesToCollect = new ArrayBuffer[Alias] + + extractLcaAndReplaceAggWithProject(resolvedAggregate, aliasesToCollect) match { case _ @Project(projectList: Seq[_], aggregate: Aggregate) => // TODO: This validation function does a post-traversal. This is discouraged in single-pass // Analyzer. @@ -76,21 +80,17 @@ class LateralColumnAliasResolver(expressionResolver: ExpressionResolver) extends scope = scopes.current, originalProjectList = projectList, firstIterationProjectList = aggregate.aggregateExpressions.map(_.toAttribute), - remappedAliases = Some(remappedAliases) + remappedAliases = Some(remappedAliases), + aliasesToCollect = aliasesToCollect ) - val aggregateListAliases = - scopes.current.lcaRegistry.getAliasDependencyLevels().asScala.flatMap(_.asScala).toSeq - - scopes.overwriteCurrent( - output = Some(finalProject.projectList.map(_.toAttribute)), - hasLcaInAggregate = true - ) + scopes.overwriteCurrent(output = Some(finalProject.projectList.map(_.toAttribute))) AggregateWithLcaResolutionResult( resolvedOperator = finalProject, outputList = finalProject.projectList, - aggregateListAliases = aggregateListAliases + aggregateListAliases = aliasesToCollect.toSeq, + baseAggregate = aggregate ) case _ => throw SparkException.internalError( @@ -128,6 +128,8 @@ class LateralColumnAliasResolver(expressionResolver: ExpressionResolver) extends * full definitions ( `attr` as `name` ) have already been defined on lower levels. * - If an attribute is never referenced, it does not show up in multi-level project lists, but * instead only in the top-most [[Project]]. + * - Additionally, collect all aliases from newly created [[Project]] nodes and add them to + * `aliasesToCollect`. * * For previously given query, following above rules, resolved [[Project]] would look like: * @@ -142,7 +144,8 @@ class LateralColumnAliasResolver(expressionResolver: ExpressionResolver) extends scope: NameScope, originalProjectList: Seq[NamedExpression], firstIterationProjectList: Seq[NamedExpression], - remappedAliases: Option[HashMap[ExprId, Alias]] = None): Project = { + remappedAliases: Option[HashMap[ExprId, Alias]] = None, + aliasesToCollect: ArrayBuffer[Alias] = ArrayBuffer.empty): Project = { val aliasDependencyMap = scope.lcaRegistry.getAliasDependencyLevels() val (finalChildPlan, _) = aliasDependencyMap.asScala.foldLeft( (resolvedChild, firstIterationProjectList) @@ -159,6 +162,12 @@ class LateralColumnAliasResolver(expressionResolver: ExpressionResolver) extends if (referencedAliases.nonEmpty) { val newProjectList = currentProjectList.map(_.toAttribute) ++ referencedAliases + + newProjectList.foreach { + case alias: Alias => aliasesToCollect += alias + case _ => + } + (Project(newProjectList, currentPlan), newProjectList) } else { (currentPlan, currentProjectList) @@ -173,6 +182,11 @@ class LateralColumnAliasResolver(expressionResolver: ExpressionResolver) extends } } + finalProjectList.foreach { + case alias: Alias => aliasesToCollect += alias + case _ => + } + Project(finalProjectList, finalChildPlan) } @@ -184,6 +198,8 @@ class LateralColumnAliasResolver(expressionResolver: ExpressionResolver) extends * [[NamedExpression]] we don't need to alias it again. 
* - Places a [[Project]] on top of the new [[Aggregate]] operator, where the project list will * be created from [[Alias]] references to original aggregate expressions. + * - Additionally, collect aliases from newly created aggregate expressions and add them to + * `aliasesToCollect`. * * For example, for a query like: * @@ -198,7 +214,9 @@ class LateralColumnAliasResolver(expressionResolver: ExpressionResolver) extends * The [[Project]] is unresolved, which is fine, because it will later be resolved as if we only * had a lateral alias reference in [[Project]] and not [[Aggregate]]. */ - private def extractLcaAndReplaceAggWithProject(aggregate: Aggregate): Project = { + private def extractLcaAndReplaceAggWithProject( + aggregate: Aggregate, + aliasesToCollect: ArrayBuffer[Alias]): Project = { val newAggregateExpressions = new LinkedHashSet[NamedExpression] val extractedExpressionAliases = new HashMap[Expression, NamedExpression]() val groupingExpressionSemanticComparator = new SemanticComparator(aggregate.groupingExpressions) @@ -212,9 +230,16 @@ class LateralColumnAliasResolver(expressionResolver: ExpressionResolver) extends newAggregateExpressions = newAggregateExpressions ).asInstanceOf[NamedExpression] ) + + val newAggregateExpressionsSeq = newAggregateExpressions.asScala.toSeq + newAggregateExpressionsSeq.foreach { + case alias: Alias => aliasesToCollect += alias + case _ => + } + val result = Project( projectList = extractedExpressions, - child = aggregate.copy(aggregateExpressions = newAggregateExpressions.asScala.toSeq) + child = aggregate.copy(aggregateExpressions = newAggregateExpressionsSeq) ) result } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolver.scala deleted file mode 100644 index d25112d78c6e7..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolver.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis.resolver - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.types.IntegerType - -/** - * The [[LimitExpressionResolver]] is a resolver that resolves a [[LocalLimit]] or [[GlobalLimit]] - * expression and performs all the necessary validation. - */ -class LimitExpressionResolver extends TreeNodeResolver[Expression, Expression] { - - /** - * Resolve a limit expression of [[GlobalLimit]] or [[LocalLimit]] and perform validation. 
- */ - override def resolve(unresolvedLimitExpression: Expression): Expression = { - validateLimitExpression(unresolvedLimitExpression, expressionName = "limit") - unresolvedLimitExpression - } - - /** - * Validate a resolved limit expression of [[GlobalLimit]] or [[LocalLimit]]: - * - The expression has to be foldable - * - The result data type has to be [[IntegerType]] - * - The evaluated expression has to be non-null - * - The evaluated expression has to be positive - * - * The `foldable` check is implemented in some expressions - * as a recursive expression tree traversal. - * It is not an ideal approach for the single-pass [[ExpressionResolver]], - * but __is__ practical, since: - * - We have to call `eval` here anyway, and it's recursive - * - In practice `LIMIT` expression trees are very small - */ - private def validateLimitExpression(expression: Expression, expressionName: String): Unit = { - if (!expression.foldable) { - throwInvalidLimitLikeExpressionIsUnfoldable(expressionName, expression) - } - if (expression.dataType != IntegerType) { - throwInvalidLimitLikeExpressionDataType(expressionName, expression) - } - expression.eval() match { - case null => - throwInvalidLimitLikeExpressionIsNull(expressionName, expression) - case value: Int if value < 0 => - throwInvalidLimitLikeExpressionIsNegative(expressionName, expression, value) - case _ => - } - } - - private def throwInvalidLimitLikeExpressionIsUnfoldable( - name: String, - expression: Expression): Nothing = - throw new AnalysisException( - errorClass = "INVALID_LIMIT_LIKE_EXPRESSION.IS_UNFOLDABLE", - messageParameters = Map( - "name" -> name, - "expr" -> toSQLExpr(expression) - ), - origin = expression.origin - ) - - private def throwInvalidLimitLikeExpressionDataType( - name: String, - expression: Expression): Nothing = - throw new AnalysisException( - errorClass = "INVALID_LIMIT_LIKE_EXPRESSION.DATA_TYPE", - messageParameters = Map( - "name" -> name, - "expr" -> toSQLExpr(expression), - "dataType" -> toSQLType(expression.dataType) - ), - origin = expression.origin - ) - - private def throwInvalidLimitLikeExpressionIsNull(name: String, expression: Expression): Nothing = - throw new AnalysisException( - errorClass = "INVALID_LIMIT_LIKE_EXPRESSION.IS_NULL", - messageParameters = Map("name" -> name, "expr" -> toSQLExpr(expression)), - origin = expression.origin - ) - - private def throwInvalidLimitLikeExpressionIsNegative( - name: String, - expression: Expression, - value: Int): Nothing = - throw new AnalysisException( - errorClass = "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE", - messageParameters = - Map("name" -> name, "expr" -> toSQLExpr(expression), "v" -> toSQLValue(value, IntegerType)), - origin = expression.origin - ) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala index bccc038f87eff..12ca6b4b333d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/MetadataResolver.scala @@ -17,17 +17,9 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import org.apache.spark.sql.catalyst.analysis.{ - FunctionResolution, - RelationResolution, - UnresolvedRelation -} -import org.apache.spark.sql.catalyst.plans.logical.{ - AnalysisHelper, - LogicalPlan, - SubqueryAlias, - UnresolvedWith -} +import 
org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.{RelationResolution, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreePattern.{UNRESOLVED_RELATION, UNRESOLVED_WITH} import org.apache.spark.sql.connector.catalog.CatalogManager @@ -45,9 +37,9 @@ import org.apache.spark.sql.connector.catalog.CatalogManager class MetadataResolver( override val catalogManager: CatalogManager, override val relationResolution: RelationResolution, - functionResolution: FunctionResolution, override val extensions: Seq[ResolverExtension] = Seq.empty) - extends RelationMetadataProvider + extends SQLConfHelper + with RelationMetadataProvider with DelegatesResolutionToExtensions { override val relationsWithResolvedMetadata = new RelationsWithResolvedMetadata diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala index aed5b767b2066..3ccae116cb187 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/NameScope.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.{ NamedExpression, OuterReference } +import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.Metadata @@ -141,8 +142,9 @@ import org.apache.spark.sql.types.Metadata * to this [[NameScope]]. If the [[Aggregate]] has lateral column references, this list contains * both the aliases from [[Aggregate]] as well as all aliases from artificially inserted * [[Project]] nodes. - * @param hasLcaInAggregate Flag that indicates whether there is a lateral column alias reference - * in the [[Aggregate]] corresponding to this [[NameScope]]. + * @param baseAggregate [[Aggregate]] node that is either a resolved [[Aggregate]] corresponding to + * this node or base [[Aggregate]] constructed when resolving lateral column references in + * [[Aggregate]]. */ class NameScope( val output: Seq[Attribute] = Seq.empty, @@ -150,7 +152,7 @@ class NameScope( val isSubqueryRoot: Boolean = false, val availableAliases: HashSet[ExprId] = new HashSet[ExprId], val aggregateListAliases: Seq[Alias] = Seq.empty, - val hasLcaInAggregate: Boolean = false, + val baseAggregate: Option[Aggregate] = None, planLogger: PlanLogger = new PlanLogger ) extends SQLConfHelper { @@ -223,21 +225,21 @@ class NameScope( /** * Returns new [[NameScope]] which preserves all the immutable [[NameScope]] properties but * overwrites `output`, `hiddenOutput`, `availableAliases`, `aggregateListAliases` and - * `hasLcaInAggregate` if provided. Mutable state like `lcaRegistry` is not preserved. + * `baseAggregate` if provided. Mutable state like `lcaRegistry` is not preserved. 
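 *
 * A minimal usage sketch (hypothetical caller; `scope` and `resolvedAggregate` are illustrative
 * names): `hiddenOutput` and `availableAliases` are not provided and are therefore preserved,
 * while `output` and `baseAggregate` are overwritten:
 * {{{
 *   // `resolvedAggregate` stands for a freshly resolved Aggregate node.
 *   val newScope = scope.overwrite(
 *     output = Some(resolvedAggregate.aggregateExpressions.map(_.toAttribute)),
 *     baseAggregate = Some(resolvedAggregate)
 *   )
 * }}}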
*/ def overwrite( output: Option[Seq[Attribute]] = None, hiddenOutput: Option[Seq[Attribute]] = None, availableAliases: Option[HashSet[ExprId]] = None, aggregateListAliases: Seq[Alias] = Seq.empty, - hasLcaInAggregate: Boolean = false): NameScope = { + baseAggregate: Option[Aggregate] = None): NameScope = { new NameScope( output = output.getOrElse(this.output), hiddenOutput = hiddenOutput.getOrElse(this.hiddenOutput), isSubqueryRoot = isSubqueryRoot, availableAliases = availableAliases.getOrElse(this.availableAliases), aggregateListAliases = aggregateListAliases, - hasLcaInAggregate = hasLcaInAggregate || this.hasLcaInAggregate, + baseAggregate = baseAggregate, planLogger = planLogger ) } @@ -300,6 +302,26 @@ class NameScope( def getOrdinalReplacementExpressions: Option[OrdinalReplacementExpressions] = ordinalReplacementExpressions + /** + * Returns attribute with `expressionId` if `output` contains it. This is used to preserve + * nullability for resolved [[AttributeReference]]. + */ + def getAttributeById(expressionId: ExprId): Option[Attribute] = + Option(outputById.get(expressionId)) + + /** + * Returns attribute with `expressionId` if `hiddenOutput` contains it. + */ + def getHiddenAttributeById(expressionId: ExprId): Option[Attribute] = + Option(hiddenAttributesById.get(expressionId)) + + /** + * Return all the explicitly outputted expression IDs. Hidden or metadata output are not included. + */ + def getOutputIds: Set[ExprId] = { + output.map(_.exprId).toSet + } + /** * Expand the [[UnresolvedStar]]. The expected use case for this method is star expansion inside * [[Project]]. @@ -351,6 +373,28 @@ class NameScope( ) } + /** + * Find attributes in this [[NameScope]] that match a provided one-part `name`. + * + * This method is simpler and more lightweight than [[resolveMultipartName]], because here we + * just return all the attributes matched by the one-part `name`. This is only suitable + * for situations where name _resolution_ is not required (e.g. accessing struct fields + * from the lower operator's output). + * + * For example, this method is used to look up attributes to match a specific [[View]] schema. + * See [[ExpressionResolver.resolveGetViewColumnByNameAndOrdinal]] for more info on view column + * lookup. + * + * We are relying on a simple [[IdentifierMap]] to perform that work, since we just need to match + * one-part name from the lower operator's output here. + */ + def findAttributesByName(name: String): Seq[Attribute] = { + attributesByName.get(name) match { + case Some(attributes) => attributes.toSeq + case None => Seq.empty + } + } + /** * Resolve multipart name into a [[NameTarget]]. [[NameTarget]]'s `candidates` may contain * simple [[AttributeReference]]s if it's a column or alias, or [[ExtractValue]] expressions if @@ -445,19 +489,6 @@ class NameScope( * Even though there is ambiguity with the name `col1`, the [[SortOrder]] expression should be * resolved as a table column from the project list and not throw [[AMBIGUOUS_REFERENCE]]. * - * On the other hand, in the following example: - * - * {{{ - * val df = sql("SELECT 1 AS col1, col1 FROM VALUES(1)") - * df.select("col1") - * }}} - * - * Resolution of name `col1` in the second [[Project]] produce [[AMBIGUOUS_REFERENCE]] error. - * - * In order to achieve this we are using [[shouldPreferTableColumnsOverAliases]] flag which - * should be set to true when the parent operator is [[Sort]] and only when we are resolving by - * `output` (we don't consider this flag for `metadataOutput` or `hiddenOutput`). 
- * * The names in [[Aggregate.groupingExpressions]] can reference * [[Aggregate.aggregateExpressions]] aliases. `canReferenceAggregateExpressionAliases` will be * true when we are resolving the grouping expressions. @@ -484,17 +515,35 @@ class NameScope( * SELECT COUNT(col1) FROM t1 GROUP BY col1 ORDER BY MAX(col2); * }}} * - * We are relying on the [[AttributeSeq]] to perform that work, since it requires complex - * resolution logic involving nested field extraction and multipart name matching. + * Spark is being smart about name resolution and prioritizes candidates from output levels that + * can actually be resolved, even though that output level might not be the first choice. + * For example, ORDER BY clause prefers attributes from SELECT list (namely, aliases) over table + * columns from below. However, if attributes on the SELECT level have name ambiguity or other + * issues, Spark will try to resolve the name using the table columns from below. Examples: + * + * {{{ + * CREATE TABLE t1 (col1 INT); + * CREATE TABLE t2 (col1 STRUCT); + * + * -- Main output is ambiguous, so col1 from t1 is used for sorting. + * SELECT 1 AS col1, 2 AS col1 FROM t1 ORDER BY col1; + * + * -- col1 from main output does not have `field`, so struct field of col1 from t2 is used for + * -- sorting. + * SELECT 1 AS col1 FROM t2 ORDER BY col1.field; + * }}} + * + * This is achieved using candidate prioritization mechanism in [[pickSuitableCandidates]]. * - * Also, see [[AttributeSeq.resolve]] for more details. + * We are relying on the [[AttributeSeq]] to perform name resolution, since it requires complex + * resolution logic involving nested field extraction and multipart name matching. See + * [[AttributeSeq.resolve]] for more details. */ def resolveMultipartName( multipartName: Seq[String], canLaterallyReferenceColumn: Boolean = false, canReferenceAggregateExpressionAliases: Boolean = false, canResolveNameByHiddenOutput: Boolean = false, - shouldPreferTableColumnsOverAliases: Boolean = false, shouldPreferHiddenOutput: Boolean = false, canReferenceAggregatedAccessOnlyAttributes: Boolean = false): NameTarget = { val resolvedMultipartName: ResolvedMultipartName = @@ -502,7 +551,6 @@ class NameScope( multipartName = multipartName, nameComparator = nameComparator, canResolveNameByHiddenOutput = canResolveNameByHiddenOutput, - shouldPreferTableColumnsOverAliases = shouldPreferTableColumnsOverAliases, shouldPreferHiddenOutput = shouldPreferHiddenOutput, canReferenceAggregatedAccessOnlyAttributes = canReferenceAggregatedAccessOnlyAttributes ).orElse(tryResolveMultipartNameAsLiteralFunction(multipartName)) @@ -536,149 +584,104 @@ class NameScope( } /** - * Find attributes in this [[NameScope]] that match a provided one-part `name`. - * - * This method is simpler and more lightweight than [[resolveMultipartName]], because here we - * just return all the attributes matched by the one-part `name`. This is only suitable - * for situations where name _resolution_ is not required (e.g. accessing struct fields - * from the lower operator's output). - * - * For example, this method is used to look up attributes to match a specific [[View]] schema. - * See [[ExpressionResolver.resolveGetViewColumnByNameAndOrdinal]] for more info on view column - * lookup. - * - * We are relying on a simple [[IdentifierMap]] to perform that work, since we just need to match - * one-part name from the lower operator's output here. 
- */ - def findAttributesByName(name: String): Seq[Attribute] = { - attributesByName.get(name) match { - case Some(attributes) => attributes.toSeq - case None => Seq.empty - } - } - - /** - * Returns attribute with `expressionId` if `output` contains it. This is used to preserve - * nullability for resolved [[AttributeReference]]. - */ - def getAttributeById(expressionId: ExprId): Option[Attribute] = - Option(outputById.get(expressionId)) - - /** - * Returns attribute with `expressionId` if `hiddenOutput` contains it. - */ - def getHiddenAttributeById(expressionId: ExprId): Option[Attribute] = - Option(hiddenAttributesById.get(expressionId)) - - /** - * Return all the explicitly outputted expression IDs. Hidden or metadata output are not included. - */ - def getOutputIds: Set[ExprId] = { - output.map(_.exprId).toSet - } - - /** - * Resolution by attributes available in the current [[NameScope]] is done in the following way: - * - First, we resolve the name using all the available attributes in the current scope - * - For all the candidates that are found, we lookup the expression IDs in the maps created - * when [[NameScope]] is updated to distinguish attributes resolved using the main output, - * hidden output and metadata output (for hidden output, we use - * `canReferenceAggregatedAccessOnlyAttributes` flag to determine if all the attributes can be - * used or only the ones that are not tagged as `aggregatedAccessOnly`). - * - We prioritize the hidden output over the other ones if `shouldPreferHiddenOutput` is set to - * true. This is done in case of HAVING where attributes from grouping expressions of the - * underlying [[Aggregate]] are preferred over aliases from operator below. Example: - * - * {{{ SELECT 1 AS col1 FROM VALUES(1, 2) GROUP BY col1 HAVING col1 = 1; }}} - * - * Plan would be: - * Project [col1#2] - * +- Filter (col1#1 = 1) - * +- Aggregate [col1#1], [a AS col1#2, col1#1] - * +- LocalRelation [col1#1, col2#3] - * - * Otherwise, we prioritize main output over the metadata output and metadata output - * over the hidden output. - * - If `shouldPreferTableColumnsOverAliases` is set to true, we prefer the table columns over - * the aliases which can be used for name resolution. - * - If we didn't find any candidates this way we fallback to other ways of resolution described - * in `resolveMultipartName` doc. + * Try resolve [[multipartName]] using attributes from a relevant operator output. This algorithm + * splits candidates from [[attributesForResolution]] into several groups and picks the best match + * ensuring that there's no choice ambiguity. + * + * Detailed flow: + * 1. Match the given [[multipartName]] using + * [[attributesForResolution.getCandidatesForResolution]] and get a subset of candidates for + * that name. + * 2. If nested fields were inferred during the name matching process, we are dealing with + * struct/map/array field/element extraction. Further narrow down those attributes that are + * suitable for field extraction using [[ExtractValue.isExtractable]]. We can safely do this + * right away, because nested fields cannot be applied to non-recursive data types. + * 3. Triage the candidates into several groups: main output, metadata output and hidden output. + * Main output is the topmost output of a relevant operator (actual SELECT list). Metadata + * output is a special qualified-access only output which originates from [[NaturalJoin]] or + * [[UsingJoin]] and can only be accessed by a qualified multipart name. 
If we have it, it + * means that [[attributesForResolution.getCandidatesForResolution]] inferred a qualified + * attribute name. Hidden output is only used if [[canResolveNameByHiddenOutput]] is specified + * (in ORDER BY and HAVING clauses). These attributes can sometimes be accessed from below in + * relation to the relevant operator - the attributes are not explicitly mentioned in a SELECT + * clause, but SQL language rules still allow referencing them. Not all hidden attributes can + * be referenced if we are dealing with an [[Aggregate]] - only those that are part of grouping + * expressions, or if we are resolving a name under an aggregate function (if + * [[canReferenceAggregatedAccessOnlyAttributes]] is specified). + * 4. Infer the right resolution priority depending on [[canResolveNameByHiddenOutput]] and + * [[shouldPreferHiddenOutput]] flag values. These flags are set depending on the operator + * in which context we are currently resolving the [[multipartName]]. For example, ORDER BY + * clause prefers attributes from SELECT list over lower attributes from the table, but HAVING + * clause has the opposite rules. + * 5. Pick the best suitable candidates using [[pickSuitableCandidates]]. We prioritize candidates + * that have exactly 1 match for the [[multipartName]], because other options would fail. + * If there was a single match, we return [[ResolvedMultipartName]] with that attribute, and + * multipart name resolution process succeeds. If none of the options are suitable, we fall + * back to the main output and either return [[ResolvedMultipartName]] with multiple candidates + * from that main output to throw a descriptive [[AMBIGUOUS_REFERENCE]] error later or return + * [[None]] to continue the name resolution process using other sources. + * + * This algorithm is incomplete and completely covers just the SQL scenarios. DataFrame + * programs can prioritize several layers of [[Project]] outputs if several nested + * `.select(...)` calls have conflicting attributes. 
*/ private def tryResolveMultipartNameByOutput( multipartName: Seq[String], nameComparator: NameComparator, canResolveNameByHiddenOutput: Boolean, - shouldPreferTableColumnsOverAliases: Boolean, shouldPreferHiddenOutput: Boolean, canReferenceAggregatedAccessOnlyAttributes: Boolean): Option[ResolvedMultipartName] = { - val (candidates, nestedFields) = - attributesForResolution.getCandidatesForResolution(multipartName, nameComparator) - - val hiddenOutputCandidates = candidates.filter { element => - !outputById.containsKey(element.exprId) && - (canReferenceAggregatedAccessOnlyAttributes || !element.aggregatedAccessOnly) - } + val (candidates, nestedFields) = getCandidatesForResolution(multipartName) - val (currentCandidates: Seq[Attribute], resolutionType: String) = - if (shouldPreferHiddenOutput && hiddenOutputCandidates.nonEmpty) { - (hiddenOutputCandidates, "hidden") - } else { - val outputCandidates = candidates.filter { element => - outputById.containsKey(element.exprId) - } + val mainOutputCandidates = getMainOutputCandidates(candidates) + val metadataOutputCandidates = getMetadataOutputCandidates(candidates) - if (outputCandidates.nonEmpty) { - (outputCandidates, "normal") - } else { - val metadataOutputCandidates = - candidates.filter { element => - !outputById.containsKey(element.exprId) && element.qualifiedAccessOnly - } + val resolutionOrder = if (canResolveNameByHiddenOutput) { + val hiddenOutputCandidates = + getHiddenOutputCandidates(candidates, canReferenceAggregatedAccessOnlyAttributes) - if (metadataOutputCandidates.nonEmpty) { - (metadataOutputCandidates, "metadata") - } else { - if (canResolveNameByHiddenOutput && - !shouldPreferHiddenOutput && - hiddenOutputCandidates.nonEmpty) { - (hiddenOutputCandidates, "hidden") - } else { - (Seq.empty, "") - } - } - } + if (shouldPreferHiddenOutput) { + Seq( + CandidatesForResolution(hiddenOutputCandidates, OutputType.Hidden), + CandidatesForResolution(mainOutputCandidates, OutputType.Main), + CandidatesForResolution(metadataOutputCandidates, OutputType.Metadata) + ) + } else { + Seq( + CandidatesForResolution(mainOutputCandidates, OutputType.Main), + CandidatesForResolution(metadataOutputCandidates, OutputType.Metadata), + CandidatesForResolution(hiddenOutputCandidates, OutputType.Hidden) + ) } + } else { + Seq( + CandidatesForResolution(mainOutputCandidates, OutputType.Main), + CandidatesForResolution(metadataOutputCandidates, OutputType.Metadata) + ) + } + + val suitableCandidates = pickSuitableCandidates( + resolutionOrder = resolutionOrder, + fallbackCandidates = CandidatesForResolution(mainOutputCandidates, OutputType.Main) + ) val resolvedCandidates = attributesForResolution.resolveCandidates( multipartName, nameComparator, - currentCandidates, + suitableCandidates.attributes, nestedFields ) if (resolvedCandidates.nonEmpty) { - val candidatesWithPreferredColumnsOverAliases = if (shouldPreferTableColumnsOverAliases) { - val (aliasCandidates, nonAliasCandidates) = - resolvedCandidates.partition(candidate => availableAliases.contains(candidate.exprId)) - - if (nonAliasCandidates.nonEmpty) { - nonAliasCandidates - } else { - aliasCandidates - } - } else { - resolvedCandidates - } - planLogger.logNameResolutionEvent( multipartName, - candidatesWithPreferredColumnsOverAliases, - s"From $resolutionType output" + resolvedCandidates, + s"From ${suitableCandidates.outputType} output" ) Some( ResolvedMultipartName( - candidates = candidatesWithPreferredColumnsOverAliases, + candidates = resolvedCandidates, referencedAttribute = None ) ) 
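
A condensed sketch of the prioritization step described above (standalone illustration, not the actual [[NameScope]] code; `Group` and the attribute names are illustrative): candidate groups are tried in resolution order, the first group with exactly one match wins, and otherwise resolution falls back to the main output so that a descriptive ambiguity error can be raised from it later.

{{{
  // Standalone model of the pickSuitableCandidates idea; names and types are illustrative only.
  case class Group(attributes: Seq[String], outputType: String)

  def pickSuitable(resolutionOrder: Seq[Group], fallback: Group): Group =
    resolutionOrder
      .collectFirst { case group if group.attributes.size == 1 => group }
      .getOrElse(fallback)

  // ORDER BY over an ambiguous SELECT list: the main output has two `col1` candidates, the
  // metadata output has none, and the hidden output (the table column) has exactly one.
  val main = Group(Seq("col1#1", "col1#2"), "main")
  val metadata = Group(Seq.empty, "metadata")
  val hidden = Group(Seq("col1#0"), "hidden")

  pickSuitable(Seq(main, metadata, hidden), fallback = main).outputType // "hidden"
}}}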
@@ -687,6 +690,54 @@ class NameScope( } } + private def getCandidatesForResolution( + multipartName: Seq[String]): (Seq[Attribute], Seq[String]) = { + val (candidates, nestedFields) = + attributesForResolution.getCandidatesForResolution(multipartName, nameComparator) + + val filteredCandidates = if (nestedFields.nonEmpty) { + candidates.filter { attribute => + ExtractValue.isExtractable(attribute, nestedFields, nameComparator) + } + } else { + candidates + } + + (filteredCandidates, nestedFields) + } + + private def getMainOutputCandidates(candidates: Seq[Attribute]): Seq[Attribute] = { + candidates.filter { attribute => + outputById.containsKey(attribute.exprId) + } + } + + private def getMetadataOutputCandidates(candidates: Seq[Attribute]): Seq[Attribute] = { + candidates.filter { element => + !outputById.containsKey(element.exprId) && element.qualifiedAccessOnly + } + } + + private def getHiddenOutputCandidates( + candidates: Seq[Attribute], + canReferenceAggregatedAccessOnlyAttributes: Boolean): Seq[Attribute] = { + candidates.filter { attribute => + !availableAliases.contains(attribute.exprId) && + (canReferenceAggregatedAccessOnlyAttributes || !attribute.aggregatedAccessOnly) + } + } + + private def pickSuitableCandidates( + resolutionOrder: Seq[CandidatesForResolution], + fallbackCandidates: CandidatesForResolution): CandidatesForResolution = { + resolutionOrder + .collectFirst { + case candidates if candidates.attributes.size == 1 => + candidates + } + .getOrElse(fallbackCandidates) + } + private def tryResolveMultipartNameAsLiteralFunction( multipartName: Seq[String]): Option[ResolvedMultipartName] = { val literalFunction = LiteralFunctionResolution.resolve(multipartName).toSeq @@ -841,8 +892,8 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel /** * Completely overwrite the current scope state with operator `output`, `hiddenOutput`, - * `availableAliases`, `aggregateListAliases` and `hasLcaInAggregate`. If `hiddenOutput`, - * `availableAliases` or `hasLcaInAggregate` are not provided, preserve the previous values. + * `availableAliases`, `aggregateListAliases` and `baseAggregate`. If `hiddenOutput`, + * `availableAliases` or `baseAggregate` are not provided, preserve the previous values. * Additionally, update nullabilities of attributes in hidden output from new output, so that if * attribute was nullable in either old hidden output or new output, it must stay nullable in new * hidden output as well. 
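
The nullability update described above can be pictured with a small standalone sketch (hypothetical helper, not the actual [[NameScopeStack]] code): an attribute stays nullable in the new hidden output if it was nullable in either the old hidden output or the new output.

{{{
  import org.apache.spark.sql.catalyst.expressions.Attribute

  // Hypothetical helper illustrating the nullability merge by expression ID.
  def mergeNullability(oldHiddenOutput: Seq[Attribute], newOutput: Seq[Attribute]): Seq[Attribute] = {
    val nullableIds = newOutput.filter(_.nullable).map(_.exprId).toSet
    oldHiddenOutput.map { attribute =>
      if (!attribute.nullable && nullableIds.contains(attribute.exprId)) {
        attribute.withNullability(true)
      } else {
        attribute
      }
    }
  }
}}}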
@@ -869,7 +920,7 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel hiddenOutput: Option[Seq[Attribute]] = None, availableAliases: Option[HashSet[ExprId]] = None, aggregateListAliases: Seq[Alias] = Seq.empty, - hasLcaInAggregate: Boolean = false): Unit = { + baseAggregate: Option[Aggregate] = None): Unit = { val hiddenOutputWithUpdatedNullabilities = updateHiddenOutputProperties( output.getOrElse(stack.peek().output), hiddenOutput.getOrElse(stack.peek().hiddenOutput) @@ -880,15 +931,15 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel hiddenOutput = Some(hiddenOutputWithUpdatedNullabilities), availableAliases = availableAliases, aggregateListAliases = aggregateListAliases, - hasLcaInAggregate = hasLcaInAggregate + baseAggregate = baseAggregate ) stack.push(newScope) } /** - * Overwrites `output`, `groupingAttributeIds` and `aggregateListAliases` of the current - * [[NameScope]] entry and: + * Overwrites `output`, `groupingAttributeIds`, `aggregateListAliases` and `baseAggregate` of the + * current [[NameScope]] entry and: * 1. extends hidden output with the provided output (only attributes that are not in the hidden * output are added). This is done because resolution of arguments can be done through certain * operators by hidden output. This use case is specific to Dataframe programs. Example: @@ -916,7 +967,8 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel def overwriteOutputAndExtendHiddenOutput( output: Seq[Attribute], groupingAttributeIds: Option[HashSet[ExprId]] = None, - aggregateListAliases: Seq[Alias] = Seq.empty): Unit = { + aggregateListAliases: Seq[Alias] = Seq.empty, + baseAggregate: Option[Aggregate] = None): Unit = { val prevScope = stack.pop val hiddenOutputWithUpdatedProperties: Seq[Attribute] = updateHiddenOutputProperties( @@ -932,7 +984,8 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel val newScope = prevScope.overwrite( output = Some(output), hiddenOutput = Some(hiddenOutput), - aggregateListAliases = aggregateListAliases + aggregateListAliases = aggregateListAliases, + baseAggregate = baseAggregate ) stack.push(newScope) @@ -989,11 +1042,11 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel /** * After finishing the resolution after [[pushScope]], the caller needs to call [[popScope]] to - * clear the stack. We propagate `hiddenOutput`, `availableAliases` and `hasLcaInAggregate` - * upwards because of name resolution by overwriting their current values with the popped ones. - * This is not done in case [[pushScope]] and [[popScope]] were called in the context of subquery - * resolution (which is indicated by `isSubqueryRoot` flag), because we don't want to overwrite - * the existing `hiddenOutput` of the main plan. + * clear the stack. We propagate `hiddenOutput`, `availableAliases` upwards because of name + * resolution by overwriting their current values with the popped ones. This is not done in case + * [[pushScope]] and [[popScope]] were called in the context of subquery resolution (which is + * indicated by `isSubqueryRoot` flag), because we don't want to overwrite the existing + * `hiddenOutput` of the main plan. 
*/ def popScope(): Unit = { val childScope = stack.pop() @@ -1002,8 +1055,7 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel stack.push( currentScope.overwrite( hiddenOutput = Some(childScope.hiddenOutput), - availableAliases = Some(childScope.availableAliases), - hasLcaInAggregate = childScope.hasLcaInAggregate + availableAliases = Some(childScope.availableAliases) ) ) } @@ -1082,7 +1134,6 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel canLaterallyReferenceColumn = canLaterallyReferenceColumn, canReferenceAggregateExpressionAliases = canReferenceAggregateExpressionAliases, canResolveNameByHiddenOutput = canResolveNameByHiddenOutput, - shouldPreferTableColumnsOverAliases = shouldPreferTableColumnsOverAliases, shouldPreferHiddenOutput = shouldPreferHiddenOutput, canReferenceAggregatedAccessOnlyAttributes = canReferenceAggregatedAccessOnlyAttributes ) @@ -1103,7 +1154,9 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel if (nameTarget.candidates.nonEmpty) { nameTarget.copy( isOuterReference = true, - candidates = nameTarget.candidates.map(wrapCandidateInOuterReference) + candidates = nameTarget.candidates.map { candidate => + wrapCandidateInOuterReference(candidate, outer) + } ) } else { nameTargetFromCurrentScope @@ -1136,18 +1189,60 @@ class NameScopeStack(planLogger: PlanLogger = new PlanLogger) extends SQLConfHel /** * Wrap candidate in [[OuterReference]]. If the root is not an [[Attribute]], but an * [[ExtractValue]] (struct/map/array field reference) we find the actual [[Attribute]] and wrap - * it in [[OuterReference]]. + * it in [[OuterReference]]. In case found [[Attribute]] is aliased in the outer scope, we + * replace it with an [[Attribute]] created from the [[Alias]]. */ - private def wrapCandidateInOuterReference(candidate: Expression): Expression = candidate match { - case candidate: Attribute => - OuterReference(candidate) - case extractValue: ExtractValue => - extractValue.transformUp { - case attribute: Attribute => OuterReference(attribute) - case other => other + private def wrapCandidateInOuterReference( + candidate: Expression, + outerScope: NameScope): Expression = { + candidate match { + case extractValue: ExtractValue => + extractValue.transformUp { + case attribute: Attribute => + tryReplaceOuterReferenceAttributeWithAlias(attribute, outerScope) + case other => other + } + case attribute: Attribute => + tryReplaceOuterReferenceAttributeWithAlias(attribute, outerScope) + case other => other + } + } + + /** + * Try to replace an [[Attribute]] with an [[Attribute]] created out of the [[Alias]] from the + * outer scope. For example: + * + * {{{ SELECT col1 AS alias FROM VALUES('a') GROUP BY col1 HAVING (SELECT col1 = 'a'); }}} + * + * Plan should be: + * + * {{{ + * Filter cast(scalar-subquery#2 [alias#1] as boolean) + * +- Project [(outer(alias#1) = a) AS (outer(col1) = a)#3] + * +- OneRowRelation + * +- Aggregate [col1#0], [col1#0 AS alias#1] + * +- LocalRelation [col1#0] + * }}} + * + * As it can be seen, we replace `outer(col1)` with `outer(alias)` but keep the original + * [[Attribute]] in the name (to be compatible with the fixed-point implementation). 
+ */ + private def tryReplaceOuterReferenceAttributeWithAlias( + attribute: Attribute, + outerScope: NameScope): OuterReference = { + val replacedAttribute = outerScope.aggregateListAliases + .collectFirst { + case alias if alias.child.semanticEquals(attribute) => alias.toAttribute } - case _ => - candidate + .getOrElse(attribute) + + val outerReference = OuterReference(replacedAttribute) + outerReference.setTagValue( + OuterReference.SINGLE_PASS_SQL_STRING_OVERRIDE, + toPrettySQL(OuterReference(attribute)) + ) + + outerReference } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/OperatorWithUncomparableTypeValidator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/OperatorWithUncomparableTypeValidator.scala new file mode 100644 index 0000000000000..6afa9e379297d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/OperatorWithUncomparableTypeValidator.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{Distinct, LogicalPlan, SetOperation} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.{DataType, MapType, VariantType} + +/** + * [[OperatorWithUncomparableTypeValidator]] performs the validation of a logical plan to ensure + * that it (if it is [[Distinct]] or [[SetOperation]]) does not contain any uncomparable types: + * [[VariantType]], [[MapType]], [[GeometryType]] or [[GeographyType]]. + */ +object OperatorWithUncomparableTypeValidator { + + /** + * Validates that the provided logical plan does not contain any uncomparable types: + * [[VariantType]], [[MapType]], [[GeometryType]] or [[GeographyType]] (throws a specific + * user-facing error if it does). Operators that are not supported are [[Distinct]] and + * [[SetOperation]] ([[Union]], [[Except]], [[Intersect]]). 
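 *
 * A minimal sketch of the underlying recursive type check (standalone illustration, assuming the
 * Catalyst [[DataType]] API):
 * {{{
 *   import org.apache.spark.sql.types._
 *
 *   val schema = StructType(Seq(StructField("m", MapType(StringType, IntegerType))))
 *   // Nested occurrences are detected as well, so a struct wrapping a map is also rejected.
 *   schema.existsRecursively(_.isInstanceOf[MapType]) // true
 * }}}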
+ */ + def validate(operator: LogicalPlan, output: Seq[Attribute]): Unit = { + operator match { + case unsupportedOperator @ (_: SetOperation | _: Distinct) => + + output.foreach { element => + if (hasMapType(element.dataType)) { + throwUnsupportedSetOperationOnMapType(element, unsupportedOperator) + } + + if (hasVariantType(element.dataType)) { + throwUnsupportedSetOperationOnVariantType(element, unsupportedOperator) + } + } + case _ => + } + } + + private def hasMapType(dt: DataType): Boolean = { + dt.existsRecursively(_.isInstanceOf[MapType]) + } + + private def hasVariantType(dt: DataType): Boolean = { + dt.existsRecursively(_.isInstanceOf[VariantType]) + } + + private def throwUnsupportedSetOperationOnMapType( + mapCol: Attribute, + unresolvedPlan: LogicalPlan): Unit = { + throw QueryCompilationErrors.unsupportedSetOperationOnMapType( + mapCol = mapCol, + origin = unresolvedPlan.origin + ) + } + + private def throwUnsupportedSetOperationOnVariantType( + variantCol: Attribute, + unresolvedPlan: LogicalPlan): Unit = { + throw QueryCompilationErrors.unsupportedSetOperationOnVariantType( + variantCol = variantCol, + origin = unresolvedPlan.origin + ) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TryExtractOrdinal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/OutputType.scala similarity index 67% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TryExtractOrdinal.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/OutputType.scala index 42766a78e248f..04152870e0014 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TryExtractOrdinal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/OutputType.scala @@ -17,18 +17,12 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import org.apache.spark.sql.catalyst.expressions.{Expression, IntegerLiteral} - /** - * Try to extract ordinal from an expression. Return `Some(ordinal)` if the type of the expression - * is [[IntegerLitera]], `None` otherwise. + * [[OutputType]] represents different types of output used during multipart name resolution in the + * [[NameScope]]. 
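 *
 * A minimal usage sketch (illustrative only):
 * {{{
 *   import org.apache.spark.sql.catalyst.analysis.resolver.OutputType
 *
 *   def describe(outputType: OutputType.OutputType): String = outputType match {
 *     case OutputType.Main => "attributes from the operator's main output (SELECT list)"
 *     case OutputType.Metadata => "qualified-access-only metadata attributes"
 *     case OutputType.Hidden => "attributes reachable only through hidden output"
 *   }
 * }}}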
*/ -object TryExtractOrdinal { - def apply(expression: Expression): Option[Int] = { - expression match { - case IntegerLiteral(literal) => - Some(literal) - case other => None - } - } +object OutputType extends Enumeration { + type OutputType = Value + + val Main, Hidden, Metadata = Value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanRewriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanRewriter.scala index 73a83fd8c3eec..544d8180fa1fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanRewriter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PlanRewriter.scala @@ -59,7 +59,7 @@ class PlanRewriter( val planWithRewrittenSubqueries = plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { case subqueryExpression: SubqueryExpression => - val rewrittenSubqueryPlan = rewrite(subqueryExpression.plan) + val rewrittenSubqueryPlan = doRewriteWithSubqueries(subqueryExpression.plan) subqueryExpression.withNewPlan(rewrittenSubqueryPlan) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala index 076ba6019d786..9e6522968b5b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ProjectResolver.scala @@ -60,12 +60,19 @@ class ProjectResolver(operatorResolver: Resolver, expressionResolver: Expression * * After the subtree and project-list expressions are resolved in the child scope we overwrite * current scope with resolved operators output to expose new names to the parent operators. + * + * We need to clear [[NameScope.availableAliases]]. Those are only relevant for the immediate + * project list for output prioritization to work correctly in + * [[NameScope.tryResolveMultipartNameByOutput]]. */ override def resolve(unresolvedProject: Project): LogicalPlan = { scopes.pushScope() val (resolvedOperator, resolvedProjectList) = try { val resolvedChild = operatorResolver.resolve(unresolvedProject.child) + + scopes.current.availableAliases.clear() + val childReferencedAttributes = expressionResolver.getLastReferencedAttributes val resolvedProjectList = expressionResolver.resolveProjectList(unresolvedProject.projectList, unresolvedProject) @@ -87,7 +94,8 @@ class ProjectResolver(operatorResolver: Resolver, expressionResolver: Expression expressions = aggregateWithLcaResolutionResult.outputList, hasAggregateExpressions = false, hasLateralColumnAlias = false, - aggregateListAliases = aggregateWithLcaResolutionResult.aggregateListAliases + aggregateListAliases = aggregateWithLcaResolutionResult.aggregateListAliases, + baseAggregate = Some(aggregateWithLcaResolutionResult.baseAggregate) ) (aggregateWithLcaResolutionResult.resolvedOperator, projectList) } else { @@ -95,8 +103,10 @@ class ProjectResolver(operatorResolver: Resolver, expressionResolver: Expression // single-pass Analyzer. 
ExprUtils.assertValidAggregation(aggregate) - val resolvedAggregateList = - resolvedProjectList.copy(aggregateListAliases = scopes.current.aggregateListAliases) + val resolvedAggregateList = resolvedProjectList.copy( + aggregateListAliases = scopes.current.aggregateListAliases, + baseAggregate = Some(aggregate) + ) (aggregate, resolvedAggregateList) } @@ -119,7 +129,8 @@ class ProjectResolver(operatorResolver: Resolver, expressionResolver: Expression scopes.overwriteOutputAndExtendHiddenOutput( output = resolvedProjectList.expressions.map(namedExpression => namedExpression.toAttribute), - aggregateListAliases = resolvedProjectList.aggregateListAliases + aggregateListAliases = resolvedProjectList.aggregateListAliases, + baseAggregate = resolvedProjectList.baseAggregate ) resolvedOperator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PullOutNondeterministicExpressionInExpressionTree.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PullOutNondeterministicExpressionInExpressionTree.scala deleted file mode 100644 index 3272c6975075c..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/PullOutNondeterministicExpressionInExpressionTree.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis.resolver - -import java.util.LinkedHashMap - -import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} - -/** - * Pull out nondeterministic expressions in an expression tree and replace them with the - * corresponding attributes in the `nondeterministicToAttributes` map. - */ -object PullOutNondeterministicExpressionInExpressionTree { - def apply[ExpressionType <: Expression]( - expression: ExpressionType, - nondeterministicToAttributes: LinkedHashMap[Expression, NamedExpression]): ExpressionType = { - expression - .transform { - case childExpression => - nondeterministicToAttributes.get(childExpression) match { - case null => - childExpression - case namedExpression => - namedExpression.toAttribute - } - } - .asInstanceOf[ExpressionType] - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionCheckRunner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionCheckRunner.scala new file mode 100644 index 0000000000000..79a3f9cf1707c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionCheckRunner.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.logical.{AnalysisHelper, LogicalPlan} +import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION +import org.apache.spark.sql.internal.SQLConf + +/** + * The [[ResolutionCheckRunner]] is used to run `resolutionChecks` on the logical plan. + * + * Important note: these checks are not always idempotent, and sometimes perform heavy network + * operations. + */ +class ResolutionCheckRunner(resolutionChecks: Seq[LogicalPlan => Unit]) extends SQLConfHelper { + + /** + * Runs the resolution checks on `plan`. Invokes all the checks for every subquery plan, and + * eventually for the main query plan. + */ + def runWithSubqueries(plan: LogicalPlan): Unit = { + if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_RUN_EXTENDED_RESOLUTION_CHECKS)) { + AnalysisHelper.allowInvokingTransformsInAnalyzer { + doRunWithSubqueries(plan) + } + } + } + + private def doRunWithSubqueries(plan: LogicalPlan): Unit = { + val planWithRewrittenSubqueries = + plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { + case subqueryExpression: SubqueryExpression => + doRunWithSubqueries(subqueryExpression.plan) + + subqueryExpression + } + + run(planWithRewrittenSubqueries) + } + + private def run(plan: LogicalPlan): Unit = { + for (check <- resolutionChecks) { + check(plan) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala index b621e396a8839..326ee8463a79f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidator.scala @@ -101,6 +101,8 @@ class ResolutionValidator { validateJoin(join) case repartition: Repartition => validateRepartition(repartition) + case sample: Sample => + validateSample(sample) // [[LogicalRelation]], [[HiveTableRelation]] and other specific relations can't be imported // because of a potential circular dependency, so we match a generic Catalyst // [[MultiInstanceRelation]] instead. 
@@ -269,6 +271,10 @@ class ResolutionValidator { validate(repartition.child) } + private def validateSample(sample: Sample): Unit = { + validate(sample.child) + } + private def validateJoin(join: Join) = { attributeScopeStack.pushScope() try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala index 6c436d4176d22..9bbce99bb1130 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvedProjectList.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate /** * Structure used to return results of the resolved project list. @@ -28,9 +29,12 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression} * - hasLateralColumnAlias: True if the resolved project list contains any lateral column aliases. * - aggregateListAliases: List of aliases in aggregate list if there are aggregate expressions in * the [[Project]]. + * - baseAggregate: Base [[Aggregate]] node constructed by [[LateralColumnAliasResolver]] while + * resolving lateral column references in [[Aggregate]]. */ case class ResolvedProjectList( expressions: Seq[NamedExpression], hasAggregateExpressions: Boolean, hasLateralColumnAlias: Boolean, - aggregateListAliases: Seq[Alias]) + aggregateListAliases: Seq[Alias], + baseAggregate: Option[Aggregate] = None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala index d9d698b1fecac..29cd7f0d3e42d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala @@ -42,8 +42,7 @@ import org.apache.spark.sql.catalyst.expressions.{ ExprId } import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNodeTag} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors @@ -112,7 +111,6 @@ class Resolver( private var relationMetadataProvider: RelationMetadataProvider = new MetadataResolver( catalogManager, relationResolution, - functionResolution, metadataResolverExtensions ) @@ -257,6 +255,8 @@ class Resolver( resolveSupervisingCommand(supervisingCommand) case repartition: Repartition => resolveRepartition(repartition) + case sample: Sample => + resolveSample(sample) case _ => tryDelegateResolutionToExtension(unresolvedPlan).getOrElse { handleUnmatchedOperator(unresolvedPlan) @@ -476,12 +476,18 @@ class Resolver( /** * [[Distinct]] operator doesn't require any special resolution. + * We validate results of the resolution using the [[OperatorWithUncomparableTypeValidator]] + * ([[MapType]], [[VariantType]], [[GeometryType]] and [[GeographyType]] are not supported + * under [[Distinct]] operator). * * `hiddenOutput` and `availableAliases` are reset when [[Distinct]] is reached during tree * traversal. 
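+ *
+ * For example, with this validation in place, a query like
+ *
+ * {{{ SELECT DISTINCT map('k', col1) FROM VALUES (1) }}}
+ *
+ * would be rejected, because the output of the [[Distinct]] operator contains a [[MapType]]
+ * column.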
*/ private def resolveDistinct(unresolvedDistinct: Distinct): LogicalPlan = { val resolvedDistinct = unresolvedDistinct.copy(child = resolve(unresolvedDistinct.child)) + + OperatorWithUncomparableTypeValidator.validate(resolvedDistinct, scopes.current.output) + scopes.overwriteCurrent( hiddenOutput = Some(scopes.current.output), availableAliases = Some(new HashSet[ExprId]) @@ -660,6 +666,14 @@ class Resolver( repartition.copy(child = resolve(repartition.child)) } + /** + * Resolve [[Sample]] operator. Its resolution doesn't require any specific logic (besides + * child resolution). + */ + private def resolveSample(sample: Sample): LogicalPlan = { + sample.copy(child = resolve(sample.child)) + } + private def createCteRelationRef(name: String, cteRelationDef: CTERelationDef): LogicalPlan = { SubqueryAlias( identifier = name, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala index 75215524d2144..5b28d5369e387 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala @@ -19,7 +19,11 @@ package org.apache.spark.sql.catalyst.analysis.resolver import java.util.Locale -import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, SqlScriptingContextManager} +import org.apache.spark.sql.catalyst.{ + FunctionIdentifier, + SQLConfHelper, + SqlScriptingContextManager +} import org.apache.spark.sql.catalyst.analysis.{ FunctionRegistry, GetViewColumnByNameAndOrdinal, @@ -143,6 +147,8 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { checkRepartition(repartition) case having: UnresolvedHaving => checkHaving(having) + case sample: Sample => + checkSample(sample) case _ => false } @@ -168,8 +174,6 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { checkUnresolvedCast(unresolvedCast) case unresolvedUpCast: UpCast => checkUnresolvedUpCast(unresolvedUpCast) - case unresolvedStar: UnresolvedStar => - checkUnresolvedStar(unresolvedStar) case unresolvedAlias: UnresolvedAlias => checkUnresolvedAlias(unresolvedAlias) case unresolvedAttribute: UnresolvedAttribute => @@ -194,6 +198,8 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { checkUnresolvedFunction(unresolvedFunction) case getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal => checkGetViewColumnBynameAndOrdinal(getViewColumnByNameAndOrdinal) + case semiStructuredExtract: SemiStructuredExtract => + checkSemiStructuredExtract(semiStructuredExtract) case expression if isGenerallySupportedExpression(expression) => expression.children.forall(checkExpression) case _ => @@ -219,13 +225,23 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { } private def checkProject(project: Project) = { - checkOperator(project.child) && project.projectList.forall(checkExpression) + checkOperator(project.child) && project.projectList.forall { + case _: UnresolvedStar => + true + case other => + checkExpression(other) + } } private def checkAggregate(aggregate: Aggregate) = { checkOperator(aggregate.child) && aggregate.groupingExpressions.forall(checkExpression) && - aggregate.aggregateExpressions.forall(checkExpression) + aggregate.aggregateExpressions.forall { + case _: UnresolvedStar => + true + case other => + checkExpression(other) + } } 
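+
+  // An illustrative, self-contained sketch (hypothetical names, not part of this class) of the
+  // pattern used by `checkProject` and `checkAggregate` above: a star is accepted only when it
+  // is a direct entry of the project/aggregate list, while a star nested inside another checked
+  // expression reaches `checkExpression`, which no longer special-cases it, so such queries fall
+  // back to the fixed-point analyzer.
+  //
+  //   sealed trait Expr
+  //   case object Star extends Expr
+  //   final case class Func(children: Seq[Expr]) extends Expr
+  //
+  //   def checkExpr(e: Expr): Boolean = e match {
+  //     case Func(children) => children.forall(checkExpr)
+  //     case Star => false // nested stars are not supported
+  //   }
+  //
+  //   def checkList(list: Seq[Expr]): Boolean = list.forall {
+  //     case Star => true // a star directly in the list is fine
+  //     case other => checkExpr(other)
+  //   }
+  //
+  //   checkList(Seq(Star))            // true  -> stays in the single-pass resolver
+  //   checkList(Seq(Func(Seq(Star)))) // false -> falls back to fixed-point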
private def checkJoin(join: Join) = { @@ -267,7 +283,8 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { private def checkUnresolvedInlineTable(unresolvedInlineTable: UnresolvedInlineTable) = unresolvedInlineTable.rows.forall(_.forall(checkExpression)) - private def checkUnresolvedRelation(unresolvedRelation: UnresolvedRelation) = true + private def checkUnresolvedRelation(unresolvedRelation: UnresolvedRelation) = + !unresolvedRelation.isStreaming private def checkResolvedInlineTable(resolvedInlineTable: ResolvedInlineTable) = resolvedInlineTable.rows.forall(_.forall(checkExpression)) @@ -306,8 +323,6 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { private def checkUnresolvedUpCast(upCast: UpCast) = checkExpression(upCast.child) - private def checkUnresolvedStar(unresolvedStar: UnresolvedStar) = true - private def checkUnresolvedAlias(unresolvedAlias: UnresolvedAlias) = checkExpression(unresolvedAlias.child) @@ -331,6 +346,7 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { } private def checkUnresolvedFunction(unresolvedFunction: UnresolvedFunction) = + unresolvedFunction.nameParts.size == 1 && !ResolverGuard.UNSUPPORTED_FUNCTION_NAMES.contains(unresolvedFunction.nameParts.head) && // UDFs are not supported FunctionRegistry.functionSet.contains( @@ -358,6 +374,9 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { private def checkGetViewColumnBynameAndOrdinal( getViewColumnByNameAndOrdinal: GetViewColumnByNameAndOrdinal) = true + private def checkSemiStructuredExtract(semiStructuredExtract: SemiStructuredExtract) = + checkExpression(semiStructuredExtract.child) + private def checkRepartition(repartition: Repartition) = { checkOperator(repartition.child) } @@ -365,6 +384,10 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { private def checkHaving(having: UnresolvedHaving) = checkExpression(having.havingCondition) && checkOperator(having.child) + private def checkSample(sample: Sample) = { + checkOperator(sample.child) + } + /** * Most of the expressions come from resolving the [[UnresolvedFunction]], but here we have some * popular expressions allowlist for two reasons: @@ -417,8 +440,8 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { _: RegExpCount | _: RegExpSubStr | _: RegExpInStr => true // JSON - case _: JsonToStructs | _: StructsToJson | _: SchemaOfJson | _: JsonObjectKeys | - _: LengthOfJsonArray => + case _: GetJsonObject | _: JsonTuple | _: JsonToStructs | _: StructsToJson | + _: SchemaOfJson | _: JsonObjectKeys | _: LengthOfJsonArray => true // CSV case _: SchemaOfCsv | _: StructsToCsv | _: CsvToStructs => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverMetricTracker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverMetricTracker.scala index 680360836eb7d..0e92684dc2a45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverMetricTracker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverMetricTracker.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.rules.QueryExecutionMetering * Trait for tracking and logging timing metrics for single-pass resolver. */ trait ResolverMetricTracker { - private val profilerGroup: String = getClass.getSimpleName /** * Log top-level timing metrics for single-pass analyzer. 
In order to utilize existing logging diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala index 37d41919f1323..fa06d39f13ccb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverRunner.scala @@ -18,7 +18,11 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.catalyst.{QueryPlanningTracker, SQLConfHelper} -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, CleanupAliases} +import org.apache.spark.sql.catalyst.analysis.{ + AnalysisContext, + CleanupAliases, + PullOutNondeterministic +} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -41,7 +45,8 @@ class ResolverRunner( */ private val planRewriteRules: Seq[Rule[LogicalPlan]] = Seq( PruneMetadataColumns, - CleanupAliases + CleanupAliases, + PullOutNondeterministic ) /** @@ -50,6 +55,11 @@ class ResolverRunner( */ private val planRewriter = new PlanRewriter(planRewriteRules, extendedRewriteRules) + /** + * `resolutionCheckRunner` is used to run `extendedResolutionChecks` on the resolved plan. + */ + private val resolutionCheckRunner = new ResolutionCheckRunner(extendedResolutionChecks) + /** * Entry point for the resolver. This method performs following 4 steps: * - Resolves the plan in a bottom-up using [[Resolver]], single-pass manner. @@ -69,7 +79,7 @@ class ResolverRunner( runValidator(rewrittenPlan) - runExtendedResolutionChecks(rewrittenPlan) + resolutionCheckRunner.runWithSubqueries(rewrittenPlan) rewrittenPlan } @@ -82,12 +92,4 @@ class ResolverRunner( validator.validatePlan(plan) } } - - private def runExtendedResolutionChecks(plan: LogicalPlan): Unit = { - if (conf.getConf(SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_RUN_EXTENDED_RESOLUTION_CHECKS)) { - for (check <- extendedResolutionChecks) { - check(plan) - } - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesNameByHiddenOutput.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesNameByHiddenOutput.scala index 06a93910f2c59..cab55ea3b66af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesNameByHiddenOutput.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolvesNameByHiddenOutput.scala @@ -21,22 +21,11 @@ import java.util.HashSet import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{ - Attribute, - AttributeReference, - ExprId, - NamedExpression, - PipeOperator -} -import org.apache.spark.sql.catalyst.plans.logical.{ - Aggregate, - Distinct, - LogicalPlan, - Project, - SubqueryAlias, - UnaryNode -} +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.{ExprId, NamedExpression, PipeOperator} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf /** * [[ResolvesNameByHiddenOutput]] is used by resolvers for operators that are able to resolve @@ -175,7 +164,7 @@ import org.apache.spark.sql.catalyst.util._ * * In the plan you can see that `col2` is added to the lower [[Project.projectList]]. 
*/ -trait ResolvesNameByHiddenOutput { +trait ResolvesNameByHiddenOutput extends SQLConfHelper { /** * Insert the missing expressions in the output list of the operator. Recursively call @@ -242,24 +231,27 @@ trait ResolvesNameByHiddenOutput { missingExpressions = missingExpressions ) - val missingAttributes = filteredMissingExpressions.collect { - case attribute: AttributeReference => attribute - } + if (filteredMissingExpressions.nonEmpty) { + val (metadataCols, nonMetadataCols) = + operatorOutput.partition(_.toAttribute.qualifiedAccessOnly) - val expandedChild = insertMissingExpressions(operator.child, missingAttributes) + operator match { + case aggregate: Aggregate => + val newAggregateList = nonMetadataCols ++ filteredMissingExpressions ++ metadataCols + aggregate.copy(aggregateExpressions = newAggregateList) + case project: Project => + val expandedChild = insertMissingExpressions( + operator = operator.child, + missingExpressions = filteredMissingExpressions + ) + val newProjectList = + nonMetadataCols ++ filteredMissingExpressions.map(_.toAttribute) ++ metadataCols - val (metadataCols, nonMetadataCols) = - operatorOutput.partition(_.toAttribute.qualifiedAccessOnly) - - val newOutputList = nonMetadataCols ++ filteredMissingExpressions ++ metadataCols - val newOperator = operator match { - case aggregate: Aggregate => - aggregate.copy(aggregateExpressions = newOutputList, child = expandedChild) - case project: Project => - project.copy(projectList = newOutputList, child = expandedChild) + project.copy(projectList = newProjectList, child = expandedChild) + } + } else { + operator } - - newOperator } private def filterMissingExpressions( @@ -300,12 +292,43 @@ trait ResolvesNameByHiddenOutput { * because they may be needed in upper operators (if not, they will be pruned away in * [[PruneMetadataColumns]]). Other hidden attributes are thrown away, because we cannot * reference them from the new [[Project]] (they are not outputted from below). + * + * If [[SQLConf.SINGLE_PASS_RESOLVER_PREVENT_USING_ALIASES_FROM_NON_DIRECT_CHILDREN]] is set to + * true, we need to overwrite the current scope and clear `aggregateListAliases` and + * `baseAggregate`. This is needed in order to prevent later replacement of Sort/Having + * expressions using semantically equal aliased expressions from non-direct children. For + * example, in the following query: + * + * {{{ SELECT col1 AS a FROM VALUES(1,2) GROUP BY col1, col2 HAVING col2 > 1 ORDER BY col1; }}} + * + * With flag set to false, analyzed plan will be: + * + * Sort [a#3 ASC NULLS FIRST], true + * +- Project [a#3] + * +- Filter (col2#2 > 1) + * +- Aggregate [col1#1, col2#2], [col1#1 AS a#3, col2#2, col1#1] + * +- LocalRelation [col1#1, col2#2] + * + * Instead of using missing attribute `col1#1` we can use its alias `a#3` in the [[Sort]] and + * avoid adding an extra projection. This is because all of [[Sort]], [[Project]], [[Filter]] and + * [[Aggregate]] belong to the same [[NameScope]] since [[Project]] was artificially inserted. + * + * However, fixed-point can't handle this case properly and produces the following plan: + * + * Project [a#3] + * +- Sort [col1#1 ASC NULLS FIRST], true + * +- Project [a#3, col1#1] + * +- Filter (col2#2 > 1) + * +- Aggregate [col1#1, col2#2], [col1#1 AS a#3, col2#2, col1#1] + * +- LocalRelation [col1#1, col2#2] + * + * Therefore, we need to match this behavior of fixed-point in single-pass in order to avoid + * logical plan mismatches. 
*/ def retainOriginalOutput( operator: LogicalPlan, missingExpressions: Seq[NamedExpression], - output: Seq[Attribute], - hiddenOutput: Seq[Attribute]): LogicalPlan = { + scopes: NameScopeStack): LogicalPlan = { if (missingExpressions.isEmpty) { operator } else { @@ -314,17 +337,29 @@ trait ResolvesNameByHiddenOutput { missingExpressionIds.add(expression.exprId) } - val hiddenOutputToPreserve = hiddenOutput.filter { hiddenAttribute => + val hiddenOutputToPreserve = scopes.current.hiddenOutput.filter { hiddenAttribute => hiddenAttribute.qualifiedAccessOnly && missingExpressionIds.contains( hiddenAttribute.exprId ) } val project = Project( - projectList = output ++ hiddenOutputToPreserve, + projectList = scopes.current.output ++ hiddenOutputToPreserve, child = operator ) + if (conf.getConf( + SQLConf.ANALYZER_SINGLE_PASS_RESOLVER_PREVENT_USING_ALIASES_FROM_NON_DIRECT_CHILDREN + )) { + scopes.overwriteCurrent( + output = Some(scopes.current.output), + hiddenOutput = Some(scopes.current.hiddenOutput), + availableAliases = Some(scopes.current.availableAliases), + aggregateListAliases = Seq.empty, + baseAggregate = None + ) + } + project } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RewritesAliasesInTopLcaProject.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RewritesAliasesInTopLcaProject.scala new file mode 100644 index 0000000000000..1fc9d579f6324 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/RewritesAliasesInTopLcaProject.scala @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import java.util.{HashMap, HashSet} + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + AttributeReference, + Expression, + ExprId, + NamedExpression +} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, ATTRIBUTE_REFERENCE} + +/** + * During LCA resolution some aliases may be rewritten as new aliases with new [[ExprId]]s. This + * trait handles remapping of old aliases to new ones, when these attributes appear in + * [[SortOrder]] expressions and Having conditions. + */ +trait RewritesAliasesInTopLcaProject { + + /** + * When resolving lateral column references in [[Aggregate]] below [[Sort]] or HAVING operators, + * fixed-point first resolves [[SortOrder]] expressions and HAVING conditions using + * [[TempResolvedColumn]] and only after that resolves lateral column references. 
For example,
+ * consider the following query:
+ *
+ * {{{ SELECT avg(col1) AS a, a AS b FROM VALUES(1,2,3) GROUP BY col2 ORDER BY max(col3) }}}
+ *
+ * Fixed-point plan before resolving [[SortOrder]]:
+ *
+ * Sort [max(tempresolvedcolumn(col3#5, col3, false)) ASC NULLS FIRST], true
+ * +- Aggregate [col2#4], [avg(col1#3) AS a#6, lateralAliasReference(a) AS b#7]
+ *    +- LocalRelation [col1#3, col2#4, col3#5]
+ *
+ * After resolving [[TempResolvedColumn]]:
+ *
+ * Project [a#6, b#7]
+ * +- Sort [max(col3)#10 ASC NULLS FIRST], true
+ *    +- Aggregate [col2#4], [avg(col1#3) AS a#6, lca(a) AS b#7, max(col3#5) AS max(col3)#10]
+ *       +- LocalRelation [col1#3, col2#4, col3#5]
+ *
+ * In the above case, fixed-point first resolves [[SortOrder]] to `max(col3)#10` and only then
+ * resolves LCAs. However, while resolving LCAs in [[Aggregate]], fixed-point first constructs
+ * a base [[Aggregate]] by pushing down all aggregate expressions with new aliases. It then
+ * places a [[Project]] on top reinstating the original alias on top of a newly created one,
+ * in order to still match the attribute reference from [[SortOrder]]:
+ *
+ * Project [a#6, b#7]
+ * +- Sort [max(col3)#10 ASC NULLS FIRST], true
+ *    +- Project [avg(col1)#11 AS a#6, lca(a) AS b#7, max(col3)#12 AS max(col3)#10]
+ *       +- Aggregate [col2#4], [avg(col1#3) AS avg(col1)#11, max(col3#5) AS max(col3)#12]
+ *          +- LocalRelation [col1#3, col2#4, col3#5]
+ *
+ * In the example above, `max(col3#5)` gets pushed down and aliased as `max(col3)#12`, even
+ * though `max(col3)#10` attribute reference already exists. Because of that `max(col3)#12` needs
+ * to be remapped back to `max(col3)#10`.
+ *
+ * However, in single-pass analyzer, we will first resolve all lateral column references before
+ * starting the resolution of [[SortOrder]], resulting in the following plan:
+ *
+ * Project [a#6, b#7]
+ * +- Sort [max(col3)#16 ASC NULLS FIRST], true
+ *    +- Project [a#6, a#6 AS b#7, max(col3)#16]
+ *       +- Project [avg(col1)#14, avg(col1)#14 AS a#6, max(col3)#16]
+ *          +- Aggregate [col2#4], [avg(col1#3) AS avg(col1)#14, max(col3#5) AS max(col3)#16]
+ *             +- LocalRelation [col1#3, col2#4, col3#5]
+ *
+ * In the above case, rewriting `max(col3)#16` with an [[Alias]] is not necessary from a
+ * correctness perspective, but we need to do it in order to stay compatible with fixed-point
+ * analyzer. Because fixed-point only regenerates aliases from the original aggregate list, in
+ * single-pass we need to handle the following:
+ * 1. all aliases from the top-level [[Project]] (because they originate from the unresolved
+ * aggregate list);
+ * 2. all references to aliases from the base aggregate (because they became attribute
+ * references during LCA resolution);
+ *
+ * This same issue also applies to HAVING resolution.
+ */ + def rewriteNamedExpressionsInTopLcaProject[ExpressionType <: Expression]( + projectToRewrite: Project, + baseAggregate: Aggregate, + expressionsToRewrite: Seq[ExpressionType], + rewriteCandidates: Seq[NamedExpression], + autoGeneratedAliasProvider: AutoGeneratedAliasProvider): (Project, Seq[ExpressionType]) = { + val candidateExpressions = getCandidateExpressionsForRewrite( + baseAggregate = baseAggregate, + oldExpressions = rewriteCandidates, + autoGeneratedAliasProvider = autoGeneratedAliasProvider + ) + val newProject = rewriteNamedExpressionsInProject(projectToRewrite, candidateExpressions) + val newExpressions = updateAttributeReferencesInExpressions[ExpressionType]( + expressionsToRewrite, + candidateExpressions + ) + + (newProject, newExpressions) + } + + /** + * When resolving [[Sort]] or Having on top of an [[Aggregate]] that has lateral column + * references, aggregate and grouping expressions might not be correctly replaced in + * [[SortOrder]] and HAVING condition, because of [[Project]] nodes created when resolving + * lateral column references. Because of that, we need to additionally try and replace + * [[SortOrder]] expressions and HAVING conditions that don't appear in the child [[Project]], + * but the aliases of semantically equivalent expressions do. In case both the attribute and its + * alias exist in the output, don't replace the attribute in [[SortOrder]] / HAVING condition, + * because there is no missing input in that case. + * For example, consider the following query: + * + * {{{ SELECT col1 AS a, a FROM VALUES(1) GROUP BY col1 ORDER BY col1 }}} + * + * After resolving lateral column references and partially resolving [[SortOrder]] expression, we + * get the following plan: + * + * !Sort [col1#3 ASC NULLS FIRST], true + * +- Project [a#4, a#4] + * +- Project [col1#3, col1#3 AS a#4] + * +- Aggregate [col1#3], [col1#3] + * +- LocalRelation [col1#3] + * + * In the above plan, [[Sort]] has a missing input `col1#3`. Because of LCA resolution this + * attribute is pushed down into the [[Project]] stack and aliased as `a#4`. Instead of using + * `col1#3` we can reference its semantically equivalent alias `a#4` in the [[SortOrder]]. The + * resolved plan looks like: + * + * Sort [a#4 ASC NULLS FIRST], true + * +- Project [a#4, a#4] + * +- Project [col1#3, col1#3 AS a#4] + * +- Aggregate [col1#3], [col1#3] + * +- LocalRelation [col1#3] + * + * Because we used `a#4` alias instead of `col1#3`, we do not need to insert `col1#3` to the + * child [[Project]] as a missing expression. Therefore, `missingExpressions` need to be updated + * in order not to insert unnecessary attributes in + * [[ResolvesNameByHiddenOutput.insertMissingExpressions]] + * + * However, for a query like: + * + * {{{ SELECT col1, col1 AS a FROM VALUES(1) GROUP BY col1 ORDER BY col1 }}} + * + * The resolved plan will be: + * + * Sort [col1#4 ASC NULLS FIRST], true + * +- Aggregate [col1#4], [col1#4, col1#4 AS a#5] + * +- LocalRelation [col1#4] + * + * In the above example, we do not replace `col1#4` with `a#5` because `col1#4` is present in the + * output. 
+ */ + def tryReplaceSortOrderOrHavingConditionWithAlias( + sortOrderOrCondition: Expression, + scopes: NameScopeStack, + missingExpressions: Seq[NamedExpression]): (Expression, Seq[NamedExpression]) = { + val replacedAttributeReferences = new HashSet[ExprId] + val expressionWithReplacedAliases = sortOrderOrCondition.transformDownWithPruning( + _.containsAnyPattern(AGGREGATE_EXPRESSION, ATTRIBUTE_REFERENCE) + ) { + case attributeReference: AttributeReference => + scopes.current.aggregateListAliases + .collectFirst { + case alias + if alias.child.semanticEquals(attributeReference) && + scopes.current.getAttributeById(attributeReference.exprId).isEmpty => + replacedAttributeReferences.add(attributeReference.exprId) + alias.toAttribute + } + .getOrElse(attributeReference) + case aggregateExpression: AggregateExpression => + scopes.current.aggregateListAliases + .collectFirst { + case alias if alias.child.semanticEquals(aggregateExpression) => + alias.toAttribute + } + .getOrElse(aggregateExpression) + } + val filteredMissingExpressions = missingExpressions.filter( + expression => !replacedAttributeReferences.contains(expression.exprId) + ) + + (expressionWithReplacedAliases, filteredMissingExpressions) + } + + private def getCandidateExpressionsForRewrite( + baseAggregate: Aggregate, + oldExpressions: Seq[NamedExpression], + autoGeneratedAliasProvider: AutoGeneratedAliasProvider): HashMap[ExprId, NamedExpression] = { + val expressionsToRewrite = new HashMap[ExprId, NamedExpression](oldExpressions.size) + val baseAggregateOutputLookup = new HashSet[ExprId](baseAggregate.aggregateExpressions.size) + baseAggregate.aggregateExpressions.foreach { + case alias: Alias => baseAggregateOutputLookup.add(alias.exprId) + case _ => + } + oldExpressions.foreach { + case oldAlias: Alias => + expressionsToRewrite.put( + oldAlias.exprId, + autoGeneratedAliasProvider.newAlias(oldAlias.toAttribute) + ) + case oldAttributeReference: AttributeReference + if baseAggregateOutputLookup.contains(oldAttributeReference.exprId) => + expressionsToRewrite.put( + oldAttributeReference.exprId, + autoGeneratedAliasProvider.newAlias(oldAttributeReference.toAttribute) + ) + case other => expressionsToRewrite.put(other.exprId, other) + } + + expressionsToRewrite + } + + private def rewriteNamedExpressionsInProject( + project: Project, + candiidateExpressions: HashMap[ExprId, NamedExpression]): Project = { + val newProjectList = project.projectList.map { + case namedExpression: NamedExpression => + candiidateExpressions.getOrDefault(namedExpression.exprId, namedExpression) + case other => other + } + project.copy(projectList = newProjectList) + } + + private def updateAttributeReferencesInExpressions[ExpressionType <: Expression]( + expressions: Seq[ExpressionType], + candidateAliases: HashMap[ExprId, NamedExpression] + ): Seq[ExpressionType] = { + expressions.map { expression => + expression + .transformDownWithPruning(_.containsPattern(ATTRIBUTE_REFERENCE)) { + case attributeReference: AttributeReference => + val newAliasOrOldAttribute = + candidateAliases.getOrDefault(attributeReference.exprId, attributeReference) + newAliasOrOldAttribute.toAttribute + } + .asInstanceOf[ExpressionType] + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SemiStructuredExtractResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SemiStructuredExtractResolver.scala new file mode 100644 index 0000000000000..a5f20ae4abc7a --- /dev/null +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SemiStructuredExtractResolver.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SemiStructuredExtract} +import org.apache.spark.sql.catalyst.expressions.variant.VariantGet +import org.apache.spark.sql.types.VariantType +import org.apache.spark.unsafe.types.UTF8String + +/** + * Resolver for [[SemiStructuredExtract]]. Resolves [[SemiStructuredExtract]] by resolving its + * children, replacing it with the proper semi-structured field extraction method and applying type + * coercion to the result. + */ +class SemiStructuredExtractResolver(expressionResolver: ExpressionResolver) + extends TreeNodeResolver[SemiStructuredExtract, Expression] + with ResolvesExpressionChildren + with CoercesExpressionTypes { + + private val timezoneAwareExpressionResolver = + expressionResolver.getTimezoneAwareExpressionResolver + + /** + * Resolves children and replaces [[SemiStructuredExtract]] expressions with the proper + * semi-structured field extraction method depending on column type. In case the column is of + * [[VariantType]], applies timezone to the result of the previous step. + * + * Currently only JSON is supported as an extraction method. An important distinction here with + * other JSON extraction methods is that the extraction fields provided here should be + * case-insensitive, unless explicitly stated through quoting. + * + * After replacing with proper extraction method, apply type coercion to the result. 
+ */ + override def resolve(semiStructuredExtract: SemiStructuredExtract): Expression = { + val semiStructuredExtractWithResolvedChildren = + withResolvedChildren(semiStructuredExtract, expressionResolver.resolve _) + .asInstanceOf[SemiStructuredExtract] + + val semiStructuredExtractWithProperExtractionMethod = + semiStructuredExtractWithResolvedChildren.child.dataType match { + case _: VariantType => + val extractResult = VariantGet( + child = semiStructuredExtractWithResolvedChildren.child, + path = Literal(UTF8String.fromString(semiStructuredExtractWithResolvedChildren.field)), + targetType = VariantType, + failOnError = true + ) + timezoneAwareExpressionResolver.resolve(extractResult) + case _ => + throw new AnalysisException( + errorClass = "COLUMN_IS_NOT_VARIANT_TYPE", + messageParameters = Map.empty + ) + } + + coerceExpressionTypes( + expression = semiStructuredExtractWithProperExtractionMethod, + expressionTreeTraversal = expressionResolver.getExpressionTreeTraversals.current + ) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SetOperationLikeResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SetOperationLikeResolver.scala index 695413eaf8434..35ec8efaf8301 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SetOperationLikeResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SetOperationLikeResolver.scala @@ -22,17 +22,10 @@ import java.util.HashSet import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion, TypeCoercionBase} import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, ExprId} -import org.apache.spark.sql.catalyst.plans.logical.{ - Except, - Intersect, - LogicalPlan, - Project, - SetOperation, - Union -} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{DataType, MapType, MetadataBuilder, VariantType} +import org.apache.spark.sql.types.{DataType, MetadataBuilder} /** * The [[SetOperationLikeResolver]] performs [[Union]], [[Intersect]] or [[Except]] operator @@ -53,16 +46,17 @@ class SetOperationLikeResolver(resolver: Resolver, expressionResolver: Expressio * - Create a new mapping in [[ExpressionIdAssigner]] for the current operator. We only need the * left child mapping, because that's the only child whose expression IDs get propagated * upwards for [[Union]], [[Intersect]] or [[Except]]. This is an optimization. - * - Perform individual output deduplication to handle the distinct union case described in - * [[performIndividualOutputExpressionIdDeduplication]] scaladoc. - * - Validate that child outputs have same length or throw "NUM_COLUMNS_MISMATCH" otherwise. * - Compute widened data types for child output attributes using * [[getTypeCoercion.findWiderTypeForTwo]] or throw "INCOMPATIBLE_COLUMN_TYPE" if coercion * fails. + * - Perform individual output deduplication to handle the distinct union case described in + * [[performIndividualOutputExpressionIdDeduplication]] scaladoc. + * - Validate that child outputs have same length or throw "NUM_COLUMNS_MISMATCH" otherwise. * - Add [[Project]] with [[Cast]] on children needing attribute data type widening. * - Assert that coerced outputs don't have conflicting expression IDs. 
* - Merge transformed outputs using a separate logic for each operator type. * - Store merged output in current [[NameScope]]. + * - Validate that the operator doesn't have unsupported data types in the output * - Create a new mapping in [[ExpressionIdAssigner]] using the coerced and validated outputs. * - Return the resolved operator with new children optionally wrapped in [[WithCTE]]. See * [[CteScope]] scaladoc for more info. @@ -74,30 +68,32 @@ class SetOperationLikeResolver(resolver: Resolver, expressionResolver: Expressio newOutputIds = childScopes.head.getOutputIds ) - val (deduplicatedChildren, deduplicatedChildOutputs) = - performIndividualOutputExpressionIdDeduplication( - resolvedChildren, - childScopes.map(_.output), - unresolvedOperator - ) + val childOutputs = childScopes.map(_.output) - val (newChildren, newChildOutputs) = - if (needToCoerceChildOutputs(deduplicatedChildOutputs, unresolvedOperator)) { + val (coercedChildren, coercedChildOutputs) = + if (needToCoerceChildOutputs(childOutputs, unresolvedOperator)) { coerceChildOutputs( - deduplicatedChildren, - deduplicatedChildOutputs, - validateAndDeduceTypes(unresolvedOperator, deduplicatedChildOutputs) + resolvedChildren, + childOutputs, + validateAndDeduceTypes(unresolvedOperator, childOutputs) ) } else { - (deduplicatedChildren, deduplicatedChildOutputs) + (resolvedChildren, childOutputs) } + val (newChildren, newChildOutputs) = + performIndividualOutputExpressionIdDeduplication( + coercedChildren, + coercedChildOutputs, + unresolvedOperator + ) + ExpressionIdAssigner.assertOutputsHaveNoConflictingExpressionIds(newChildOutputs) val output = mergeChildOutputs(unresolvedOperator, newChildOutputs) scopes.overwriteCurrent(output = Some(output), hiddenOutput = Some(output)) - validateOutputs(unresolvedOperator, output) + OperatorWithUncomparableTypeValidator.validate(unresolvedOperator, output) val resolvedOperator = unresolvedOperator.withNewChildren(newChildren) @@ -362,24 +358,6 @@ class SetOperationLikeResolver(resolver: Resolver, expressionResolver: Expressio } } - /** - * Validate outputs of [[SetOperation]]. - * - [[MapType]] and [[VariantType]] are currently not supported for [[SetOperations]] and we need - * to throw a relevant user-facing error. 
- */ - private def validateOutputs(unresolvedPlan: LogicalPlan, output: Seq[Attribute]): Unit = { - unresolvedPlan match { - case _: SetOperation => - output.find(a => hasMapType(a.dataType)).foreach { mapCol => - throwUnsupportedSetOperationOnMapType(mapCol, unresolvedPlan) - } - output.find(a => hasVariantType(a.dataType)).foreach { variantCol => - throwUnsupportedSetOperationOnVariantType(variantCol, unresolvedPlan) - } - case _ => - } - } - private def getTypeCoercion: TypeCoercionBase = { if (conf.ansiEnabled) { AnsiTypeCoercion @@ -388,24 +366,6 @@ class SetOperationLikeResolver(resolver: Resolver, expressionResolver: Expressio } } - private def throwUnsupportedSetOperationOnMapType( - mapCol: Attribute, - unresolvedPlan: LogicalPlan): Unit = { - throw QueryCompilationErrors.unsupportedSetOperationOnMapType( - mapCol = mapCol, - origin = unresolvedPlan.origin - ) - } - - private def throwUnsupportedSetOperationOnVariantType( - variantCol: Attribute, - unresolvedPlan: LogicalPlan): Unit = { - throw QueryCompilationErrors.unsupportedSetOperationOnVariantType( - variantCol = variantCol, - origin = unresolvedPlan.origin - ) - } - private def throwNumColumnsMismatch( expectedNumColumns: Int, childColumnTypes: Seq[DataType], @@ -436,12 +396,4 @@ class SetOperationLikeResolver(resolver: Resolver, expressionResolver: Expressio origin = unresolvedOperator.origin ) } - - private def hasMapType(dt: DataType): Boolean = { - dt.existsRecursively(_.isInstanceOf[MapType]) - } - - private def hasVariantType(dt: DataType): Boolean = { - dt.existsRecursively(_.isInstanceOf[VariantType]) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SortResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SortResolver.scala index 3e271a324209e..0c7432af71936 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SortResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/SortResolver.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql.catalyst.analysis.resolver -import java.util.{HashMap, LinkedHashMap} +import java.util.HashMap import scala.collection.mutable -import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.analysis.{ - NondeterministicExpressionCollection, - UnresolvedAttribute -} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{ Alias, Attribute, @@ -34,13 +30,14 @@ import org.apache.spark.sql.catalyst.expressions.{ NamedExpression, SortOrder } -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort} +import org.apache.spark.sql.catalyst.plans.logical._ /** * Resolves a [[Sort]] by resolving its child and order expressions. */ class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionResolver) extends TreeNodeResolver[Sort, LogicalPlan] + with RewritesAliasesInTopLcaProject with ResolvesNameByHiddenOutput { private val scopes: NameScopeStack = operatorResolver.getNameScopes private val autoGeneratedAliasProvider = new AutoGeneratedAliasProvider( @@ -102,8 +99,6 @@ class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionRes * +- Aggregate [col1, (col2 + 1)], * [col1, sum(col1) AS sum(col1)#..., (col2 + 1) AS (col2 + 1)#...] * +- LocalRelation [col1, col2] - * 5. 
In case there are non-deterministic expressions in the order expressions, substitute them - * with derived attribute references to an artificial [[Project]] list. */ override def resolve(unresolvedSort: Sort): LogicalPlan = { val resolvedChild = operatorResolver.resolve(unresolvedSort.child) @@ -123,50 +118,52 @@ class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionRes } else { val partiallyResolvedSort = unresolvedSort.copy(child = resolvedChild) - val (resolvedOrderExpressions, missingAttributes) = + val (partiallyResolvedOrderExpressions, missingAttributes) = resolveOrderExpressions(partiallyResolvedSort) - val (finalOrderExpressions, missingExpressions) = resolvedChild match { - case _ if scopes.current.hasLcaInAggregate => - throw new ExplicitlyUnsupportedResolverFeature( - "Lateral column alias in Aggregate below a Sort" - ) - case aggregate: Aggregate => - val (cleanedOrderExpressions, extractedExpressions) = - extractReferencedGroupingAndAggregateExpressions(aggregate, resolvedOrderExpressions) - (cleanedOrderExpressions, extractedExpressions) - case filter @ Filter(_, aggregate: Aggregate) => - val (cleanedOrderExpressions, extractedExpressions) = - extractReferencedGroupingAndAggregateExpressions(aggregate, resolvedOrderExpressions) - (cleanedOrderExpressions, extractedExpressions) - case project @ Project(_, Filter(_, aggregate: Aggregate)) => - throw new ExplicitlyUnsupportedResolverFeature( - "Project on top of HAVING below a Sort" + val (resolvedOrderExpressions, missingExpressions) = resolvedChild match { + case _ @(_: Aggregate | _: Filter | _: Project) if scopes.current.baseAggregate.isDefined => + extractReferencedGroupingAndAggregateExpressions( + scopes.current.baseAggregate.get, + partiallyResolvedOrderExpressions ) case other => - (resolvedOrderExpressions, missingAttributes) + (partiallyResolvedOrderExpressions, missingAttributes) } + val (resolvedOrderExpressionsWithAliasesReplaced, filteredMissingExpressions) = + tryReplaceSortOrderWithAlias(resolvedOrderExpressions, missingExpressions) + val resolvedChildWithMissingAttributes = - insertMissingExpressions(resolvedChild, missingExpressions) + insertMissingExpressions(resolvedChild, filteredMissingExpressions) + + val isChildChangedByMissingExpressions = !resolvedChildWithMissingAttributes.eq(resolvedChild) + + val (finalChild, finalOrderExpressions) = resolvedChildWithMissingAttributes match { + case project: Project if scopes.current.baseAggregate.isDefined => + rewriteNamedExpressionsInTopLcaProject[SortOrder]( + projectToRewrite = project, + baseAggregate = scopes.current.baseAggregate.get, + expressionsToRewrite = resolvedOrderExpressionsWithAliasesReplaced, + rewriteCandidates = missingExpressions, + autoGeneratedAliasProvider = autoGeneratedAliasProvider + ) + case other => (other, resolvedOrderExpressionsWithAliasesReplaced) + } val resolvedSort = unresolvedSort.copy( - child = resolvedChildWithMissingAttributes, + child = finalChild, order = finalOrderExpressions ) - val sortWithOriginalOutput = retainOriginalOutput( - operator = resolvedSort, - missingExpressions = missingExpressions, - output = scopes.current.output, - hiddenOutput = scopes.current.hiddenOutput - ) - - sortWithOriginalOutput match { - case project @ Project(_, sort: Sort) => - project.copy(child = tryPullOutNondeterministic(sort, childOutput = sort.child.output)) - case sort: Sort => - tryPullOutNondeterministic(sort, childOutput = scopes.current.output) + if (isChildChangedByMissingExpressions) { + 
retainOriginalOutput( + operator = resolvedSort, + missingExpressions = missingExpressions, + scopes = scopes + ) + } else { + resolvedSort } } } @@ -197,10 +194,6 @@ class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionRes * SELECT col1 FROM VALUES(1,2) GROUP BY col1 HAVING col1 > 1 ORDER BY col2; * SELECT col1 FROM VALUES(1) ORDER BY col2; * }}} - * - * If the order expression is not present in the current scope, but an alias of this expression - * is, replace the order expression with its alias (see - * [[tryReplaceSortOrderExpressionWithAlias]]). */ private def resolveOrderExpressions( partiallyResolvedSort: Sort): (Seq[SortOrder], Seq[Attribute]) = { @@ -211,11 +204,9 @@ class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionRes .resolveExpressionTreeInOperator(sortOrder, partiallyResolvedSort) .asInstanceOf[SortOrder] - tryReplaceSortOrderExpressionWithAlias(resolvedSortOrder).getOrElse { - referencedAttributes.putAll(expressionResolver.getLastReferencedAttributes) + referencedAttributes.putAll(expressionResolver.getLastReferencedAttributes) - resolvedSortOrder - } + resolvedSortOrder } val missingAttributes = scopes.current.resolveMissingAttributesByHiddenOutput( @@ -225,37 +216,6 @@ class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionRes (resolvedSortOrder, missingAttributes) } - /** - * When resolving [[SortOrder]] on top of an [[Aggregate]], if there is an attribute that is - * present in `hiddenOutput` and there is an [[Alias]] of this attribute in the `output`, - * [[SortOrder]] should be resolved by the [[Alias]] instead of an attribute. This is done as - * optimization in order to avoid a [[Project]] node being added when resolving the attribute via - * missing input (because attribute is not present in direct output, only its alias is). - * - * For example, for a query like: - * - * {{{ - * SELECT col1 + 1 AS a FROM VALUES(1) GROUP BY a ORDER BY col1 + 1; - * }}} - * - * The resolved plan should be: - * - * Sort [a#2 ASC NULLS FIRST], true - * +- Aggregate [(col1#1 + 1)], [(col1#1 + 1) AS a#2] - * +- LocalRelation [col1#1] - * - * [[SortOrder]] expression is resolved to alias of `col1 + 1` instead of `col1 + 1` itself. - */ - private def tryReplaceSortOrderExpressionWithAlias(sortOrder: SortOrder): Option[SortOrder] = { - scopes.current.aggregateListAliases - .collectFirst { - case alias if alias.child.semanticEquals(sortOrder.child) => alias.toAttribute - } - .map { aliasCandidate => - sortOrder.withNewChildren(newChildren = Seq(aliasCandidate)).asInstanceOf[SortOrder] - } - } - /** * Extracts the referenced grouping and aggregate expressions from the order expressions. This is * used to update the output of the child operator and add a [[Project]] as a parent of [[Sort]] @@ -313,52 +273,33 @@ class SortResolver(operatorResolver: Resolver, expressionResolver: ExpressionRes } /** - * In case there are non-deterministic expressions in `order` expressions replace them with - * attributes created out of corresponding non-deterministic expression. Example: - * - * {{{ SELECT 1 ORDER BY RAND(); }}} - * - * This query would have the following analyzed plan: - * - * Project [1] - * +- Sort [_nondeterministic ASC NULLS FIRST], true - * +- Project [1, rand(...) AS _nondeterministic#...] - * +- Project [1 AS 1#...] - * +- OneRowRelation - * - * We use `childOutput` instead of directly calling `scopes.current.output`, because - * [[insertMissingExpressions]] could have changed the output of the child operator. 
- * We could just call `sort.child.output`, but this is suboptimal for the simple case when - * [[Sort]] child is left unchanged, and in that case we actually call `scopes.current.output`. - * See the call site in [[resolve]]. + * When resolving [[Sort]] on top of an [[Aggregate]] that has lateral column aliases, + * [[extractReferencedGroupingAndAggregateExpressions]] may not correctly replace all + * [[SortOrder]] expressions because of newly construct [[Project]] nodes coming from LCA + * resolution. This method replaces all [[SortOrder]] with their aliases if those expressions + * don't exist in child [[Project]] but the aliases do. + * For more details see [[tryReplaceSortOrderOrHavingConditionWithAlias]]. */ - private def tryPullOutNondeterministic(sort: Sort, childOutput: Seq[Attribute]): LogicalPlan = { - val nondeterministicToAttributes: LinkedHashMap[Expression, NamedExpression] = - NondeterministicExpressionCollection.getNondeterministicToAttributes( - sort.order.map(_.child) - ) + private def tryReplaceSortOrderWithAlias( + orderExpressions: Seq[SortOrder], + missingExpressions: Seq[NamedExpression] + ): (Seq[SortOrder], Seq[NamedExpression]) = { + val replacedOrderExpressions = new mutable.ArrayBuffer[SortOrder] + var currentMissingExpressions = missingExpressions - if (!nondeterministicToAttributes.isEmpty) { - val newChild = Project( - childOutput ++ nondeterministicToAttributes.values.asScala.toSeq, - sort.child - ) - val resolvedOrder = sort.order.map { sortOrder => - sortOrder.copy( - child = PullOutNondeterministicExpressionInExpressionTree( - sortOrder.child, - nondeterministicToAttributes - ) + orderExpressions.map { orderExpression => + val (replacedOrder, updatedMissingExpressions) = + tryReplaceSortOrderOrHavingConditionWithAlias( + sortOrderOrCondition = orderExpression, + scopes = scopes, + missingExpressions = currentMissingExpressions ) - } - val resolvedSort = sort.copy( - order = resolvedOrder, - child = newChild - ) - Project(projectList = childOutput, child = resolvedSort) - } else { - sort + + replacedOrderExpressions += replacedOrder.asInstanceOf[SortOrder] + currentMissingExpressions = updatedMissingExpressions } + + (replacedOrderExpressions.toSeq, currentMissingExpressions) } private def canOrderByAll(expressions: Seq[SortOrder]): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimestampAddResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimestampAddResolver.scala deleted file mode 100644 index 3dc665a6d88b4..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimestampAddResolver.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.analysis.resolver - -import org.apache.spark.sql.catalyst.analysis.{ - AnsiStringPromotionTypeCoercion, - AnsiTypeCoercion, - StringPromotionTypeCoercion, - TypeCoercion -} -import org.apache.spark.sql.catalyst.expressions.{Expression, TimestampAddInterval} - -/** - * Helper resolver for [[TimestampAddInterval]] which is produced by resolving [[BinaryArithmetic]] - * nodes. - */ -class TimestampAddResolver(expressionResolver: ExpressionResolver) - extends TreeNodeResolver[TimestampAddInterval, Expression] - with ResolvesExpressionChildren - with CoercesExpressionTypes { - - private val traversals = expressionResolver.getExpressionTreeTraversals - - protected override val ansiTransformations: CoercesExpressionTypes.Transformations = - TimestampAddResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS - protected override val nonAnsiTransformations: CoercesExpressionTypes.Transformations = - TimestampAddResolver.TYPE_COERCION_TRANSFORMATIONS - - override def resolve(unresolvedTimestampAdd: TimestampAddInterval): Expression = { - val timestampAddWithResolvedChildren = - withResolvedChildren(unresolvedTimestampAdd, expressionResolver.resolve _) - val timestampAddWithTypeCoercion: Expression = coerceExpressionTypes( - expression = timestampAddWithResolvedChildren, - expressionTreeTraversal = traversals.current - ) - TimezoneAwareExpressionResolver.resolveTimezone( - timestampAddWithTypeCoercion, - traversals.current.sessionLocalTimeZone - ) - } -} - -object TimestampAddResolver { - // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. - private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - StringPromotionTypeCoercion.apply, - TypeCoercion.ImplicitTypeCoercion.apply, - TypeCoercion.DateTimeOperationsTypeCoercion.apply - ) - - // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. - private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - AnsiStringPromotionTypeCoercion.apply, - AnsiTypeCoercion.ImplicitTypeCoercion.apply, - AnsiTypeCoercion.AnsiDateTimeOperationsTypeCoercion.apply - ) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala index c084932813c29..712efcef5e7a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolver.scala @@ -123,6 +123,22 @@ object TimezoneAwareExpressionResolver { /** * Applies a timezone to a [[TimeZoneAwareExpression]] while preserving original tags. * + * The method is applied recursively to all nested [[TimeZoneAwareExpression]]s that lack a + * timezone until we find one that has it. This is because sometimes type coercion rules (or + * other code) can produce multiple [[Cast]]s on top of an expression. For example: + * + * {{{ SELECT NANVL(1, null); }}} + * + * Plan: + * + * {{{ + * Project [nanvl(cast(1 as double), cast(cast(null as int) as double)) AS nanvl(1, NULL)#0] + * +- OneRowRelation + * }}} + * + * As can be seen, there are multiple nested [[Cast]] nodes, and the timezone should be applied + * to all of them.
+ * * This method is particularly useful for cases like resolving [[Cast]] expressions where tags * such as [[USER_SPECIFIED_CAST]] need to be preserved. * @@ -133,7 +149,13 @@ object TimezoneAwareExpressionResolver { def resolveTimezone(expression: Expression, timeZoneId: String): Expression = { expression match { case timezoneExpression: TimeZoneAwareExpression if timezoneExpression.timeZoneId.isEmpty => - val withTimezone = timezoneExpression.withTimeZone(timeZoneId) + val childrenWithTimeZone = timezoneExpression.children.map { child => + resolveTimezone(child, timeZoneId) + } + val withNewChildren = timezoneExpression + .withNewChildren(childrenWithTimeZone) + .asInstanceOf[TimeZoneAwareExpression] + val withTimezone = withNewChildren.withTimeZone(timeZoneId) withTimezone.copyTagsFrom(timezoneExpression) withTimezone case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala deleted file mode 100644 index 48ceb7e10ebd5..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnaryMinusResolver.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis.resolver - -import org.apache.spark.sql.catalyst.analysis.{AnsiTypeCoercion, TypeCoercion} -import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryMinus} - -/** - * Resolver for [[UnaryMinus]]. Resolves children and applies type coercion to target node. - */ -class UnaryMinusResolver(expressionResolver: ExpressionResolver) - extends TreeNodeResolver[UnaryMinus, Expression] - with ResolvesExpressionChildren - with CoercesExpressionTypes { - - private val traversals = expressionResolver.getExpressionTreeTraversals - - protected override val ansiTransformations: CoercesExpressionTypes.Transformations = - UnaryMinusResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS - protected override val nonAnsiTransformations: CoercesExpressionTypes.Transformations = - UnaryMinusResolver.TYPE_COERCION_TRANSFORMATIONS - - override def resolve(unresolvedUnaryMinus: UnaryMinus): Expression = { - val unaryMinusWithResolvedChildren = - withResolvedChildren(unresolvedUnaryMinus, expressionResolver.resolve _) - coerceExpressionTypes( - expression = unaryMinusWithResolvedChildren, - expressionTreeTraversal = traversals.current - ) - } -} - -object UnaryMinusResolver { - // Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]]. 
- private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - TypeCoercion.ImplicitTypeCoercion.apply, - TypeCoercion.DateTimeOperationsTypeCoercion.apply - ) - - // Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]]. - private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq( - AnsiTypeCoercion.ImplicitTypeCoercion.apply, - AnsiTypeCoercion.AnsiDateTimeOperationsTypeCoercion.apply - ) -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnsupportedExpressionInOperatorValidation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnsupportedExpressionInOperatorValidation.scala index ae0b5d4a48019..def4e3c30a6c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnsupportedExpressionInOperatorValidation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/UnsupportedExpressionInOperatorValidation.scala @@ -19,16 +19,7 @@ package org.apache.spark.sql.catalyst.analysis.resolver import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, WindowExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{ - Aggregate, - BaseEvalPythonUDTF, - CollectMetrics, - Generate, - LateralJoin, - LogicalPlan, - Project, - Window -} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.errors.QueryCompilationErrors object UnsupportedExpressionInOperatorValidation { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala index 3470bed9cfb2e..ad1926772e7f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ViewResolver.scala @@ -165,6 +165,7 @@ class ViewResolver(resolver: Resolver, catalogManager: CatalogManager) * * @param nestedViewDepth Current nested view depth. Cannot exceed the `maxNestedViewDepth`. * @param maxNestedViewDepth Maximum allowed nested view depth. Configured in the upper context + * based on [[SQLConf.MAX_NESTED_VIEW_DEPTH]]. * @param collation View's default collation if explicitly set. * @param catalogAndNamespace Catalog and camespace under which the [[View]] was created. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 12400d66f4442..5cdbdf3f0e7c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -98,6 +98,21 @@ object ExtractValue { } } + /** + * Check that [[attribute]] can be fully extracted using the given [[nestedFields]]. 
+ */ + def isExtractable( + attribute: Attribute, nestedFields: Seq[String], resolver: Resolver): Boolean = { + nestedFields + .foldLeft(Some(attribute): Option[Expression]) { + case (Some(expression), field) => + ExtractValue.extractValue(expression, Literal(field), resolver) + case _ => + None + } + .isDefined + } + /** * Find the ordinal of StructField, report error if no desired field or over one * desired fields are found. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b17c4147b951b..e076f4ede0515 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -346,6 +346,31 @@ object SQLConf { .doubleConf .createWithDefault(if (Utils.isTesting) 1.0 else 0.001) + val ANALYZER_SINGLE_PASS_RESOLVER_EXPOSE_RESOLVER_GUARD_FAILURE = + buildConf("spark.sql.analyzer.singlePassResolver.exposeResolverGuardFailure") + .internal() + .doc( + "When true, any failure thrown from ResolverGuard will be exposed as a query failure. " + + "Otherwise we just assume that the ResolverGuard returned false and the query is not " + + "supported by the single-pass Analyzer. This is important to make dual-runs unnoticeable " + + "in production.") + .version("4.1.0") + .booleanConf + .createWithDefault(Utils.isTesting) + + val ANALYZER_SINGLE_PASS_RESOLVER_PREVENT_USING_ALIASES_FROM_NON_DIRECT_CHILDREN = + buildConf("spark.sql.analyzer.singlePassResolver.preventUsingAliasesFromNonDirectChildren") + .internal() + .doc("When true, expressions in Sort/Having/Filter operators may only be replaced with " + + "semantically equal aliased expressions from direct children. This is necessary in " + + "order to stay compatible with fixed-point, while functionality and correctness " + + "remain the same. Because enabling this flag would break some cases that are " + + "supported in single-pass but not in fixed-point, it should only be used to hide " + + "false positive logical plan mismatches during testing.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + val ANALYZER_SINGLE_PASS_RESOLVER_VALIDATION_ENABLED = buildConf("spark.sql.analyzer.singlePassResolver.validationEnabled") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolverSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolverSuite.scala deleted file mode 100644 index c8b4db3e10aa2..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/LimitExpressionResolverSuite.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis.resolver - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Literal} -import org.apache.spark.sql.errors.QueryErrorsBase -import org.apache.spark.sql.types.IntegerType - -class LimitExpressionResolverSuite extends SparkFunSuite with QueryErrorsBase { - - private val limitExpressionResolver = new LimitExpressionResolver - - test("Basic LIMIT without errors") { - val expr = Literal(42, IntegerType) - assert(limitExpressionResolver.resolve(expr) == expr) - } - - test("Unfoldable LIMIT") { - val col = AttributeReference(name = "foo", dataType = IntegerType)() - checkError( - exception = intercept[AnalysisException] { - limitExpressionResolver.resolve(col) - }, - condition = "INVALID_LIMIT_LIKE_EXPRESSION.IS_UNFOLDABLE", - parameters = Map("name" -> "limit", "expr" -> toSQLExpr(col)) - ) - } - - test("LIMIT with non-integer") { - val anyNonInteger = Literal("42") - checkError( - exception = intercept[AnalysisException] { - limitExpressionResolver.resolve(anyNonInteger) - }, - condition = "INVALID_LIMIT_LIKE_EXPRESSION.DATA_TYPE", - parameters = Map( - "name" -> "limit", - "expr" -> toSQLExpr(anyNonInteger), - "dataType" -> toSQLType(anyNonInteger.dataType) - ) - ) - } - - test("LIMIT with null") { - val expr = Cast(Literal(null), IntegerType) - checkError( - exception = intercept[AnalysisException] { - limitExpressionResolver.resolve(expr) - }, - condition = "INVALID_LIMIT_LIKE_EXPRESSION.IS_NULL", - parameters = Map( - "name" -> "limit", - "expr" -> toSQLExpr(expr) - ) - ) - } - - test("LIMIT with negative integer") { - val expr = Literal(-1, IntegerType) - checkError( - exception = intercept[AnalysisException] { - limitExpressionResolver.resolve(expr) - }, - condition = "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE", - parameters = Map( - "name" -> "limit", - "expr" -> toSQLExpr(expr), - "v" -> toSQLValue(-1, IntegerType) - ) - ) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidatorSuite.scala index 94b954c9b9a0f..e4af0df35615b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolutionValidatorSuite.scala @@ -33,14 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{ TimestampAddInterval } import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project} -import org.apache.spark.sql.types.{ - BooleanType, - DayTimeIntervalType, - DecimalType, - IntegerType, - StringType, - TimestampType -} +import org.apache.spark.sql.types._ class ResolutionValidatorSuite extends SparkFunSuite with SQLConfHelper { private val resolveMethodNamesToIgnore = Seq( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolverSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolverSuite.scala index e2e52081e8c65..8897d65654540 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolverSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/resolver/TimezoneAwareExpressionResolverSuite.scala @@ -47,6 +47,15 @@ class TimezoneAwareExpressionResolverSuite extends SparkFunSuite { AttributeReference(name = "unresolvedChild", dataType = StringType)() private val resolvedChild = AttributeReference(name = "resolvedChild", dataType = IntegerType)() private val castExpression = Cast(child = unresolvedChild, dataType = IntegerType) + private val nestedCasts = Cast( + child = Cast( + child = Cast(child = unresolvedChild, dataType = IntegerType, timeZoneId = Some("UTC")), + dataType = IntegerType, + timeZoneId = None + ), + dataType = IntegerType, + timeZoneId = None + ) private val expressionResolver = new HardCodedExpressionResolver( catalogManager = mock[CatalogManager], resolvedExpression = resolvedChild @@ -72,4 +81,14 @@ class TimezoneAwareExpressionResolverSuite extends SparkFunSuite { assert(resolvedExpression.timeZoneId.nonEmpty) assert(resolvedExpression.getTagValue(Cast.USER_SPECIFIED_CAST).nonEmpty) } + + test("Timezone is applied recursively") { + val expressionWithTimezone = + TimezoneAwareExpressionResolver.resolveTimezone(nestedCasts, "UTC") + + assert(expressionWithTimezone.asInstanceOf[Cast].timeZoneId.get == "UTC") + assert( + expressionWithTimezone.asInstanceOf[Cast].child.asInstanceOf[Cast].timeZoneId.get == "UTC" + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExplicitlyUnsupportedResolverFeatureSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExplicitlyUnsupportedResolverFeatureSuite.scala index 0e66897868d3f..222128b394449 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExplicitlyUnsupportedResolverFeatureSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ExplicitlyUnsupportedResolverFeatureSuite.scala @@ -45,22 +45,6 @@ class ExplicitlyUnsupportedResolverFeatureSuite extends QueryTest with SharedSpa } } - test("Unsupported star expansion") { - checkResolution("SELECT * FROM VALUES (1, 2) WHERE 3 IN (*)") - } - - test("Lateral column alias in Aggregate below a Sort") { - checkResolution( - "SELECT dept AS d, d, 10 AS d FROM VALUES(1) AS t(dept) GROUP BY dept ORDER BY dept" - ) - } - - test("Unsupported lambda") { - checkResolution( - "SELECT array_sort(array(2, 1), (p1, p2) -> CASE WHEN p1 > p2 THEN 1 ELSE 0 END)" - ) - } - private def checkResolution(sqlText: String, shouldPass: Boolean = false): Unit = { val unresolvedPlan = spark.sessionState.sqlParser.parsePlan(sqlText) checkResolution(unresolvedPlan, shouldPass) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/HybridAnalyzerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/HybridAnalyzerSuite.scala index 66f412c8c3195..e5ee8c8ae2b88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/HybridAnalyzerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/HybridAnalyzerSuite.scala @@ -21,30 +21,18 @@ import org.scalactic.source.Position import org.scalatest.Tag import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.catalyst.{ - ExtendedAnalysisException, - QueryPlanningTracker -} +import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{ AnalysisContext, Analyzer, UnresolvedAttribute, UnresolvedStar } -import org.apache.spark.sql.catalyst.analysis.resolver.{ - 
AnalyzerBridgeState, - ExplicitlyUnsupportedResolverFeature, - HybridAnalyzer, - Resolver, - ResolverGuard -} +import org.apache.spark.sql.catalyst.analysis.resolver._ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.NormalizePlan -import org.apache.spark.sql.catalyst.plans.logical.{ - LocalRelation, - LogicalPlan, - Project -} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -97,6 +85,13 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { } } + private class BrokenResolverGuard(catalogManager: CatalogManager) + extends ResolverGuard(catalogManager) { + override def apply(plan: LogicalPlan): Boolean = { + throw new Exception("Broken resolver guard") + } + } + private class ValidatingResolver(bridgeRelations: Boolean) extends Resolver(spark.sessionState.catalogManager) { override def lookupMetadataAndResolve( @@ -300,13 +295,7 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { } test("Explicitly unsupported resolver feature") { - val plan: LogicalPlan = { - Project( - Seq(UnresolvedStar(None)), - LocalRelation(col1Integer) - ) - } - checkAnswer( + assertPlansEqual( new HybridAnalyzer( new ValidatingAnalyzer(bridgeRelations = true), new ResolverGuard(spark.sessionState.catalogManager), @@ -314,8 +303,8 @@ class HybridAnalyzerSuite extends QueryTest with SharedSparkSession { new ExplicitlyUnsupportedResolverFeature("FAILURE"), bridgeRelations = true ) - ).apply(plan, new QueryPlanningTracker), - plan + ).apply(unresolvedPlan, new QueryPlanningTracker), + resolvedPlan ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/MetadataResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/MetadataResolverSuite.scala index be5d95633b5e3..f315f1b8a5971 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/MetadataResolverSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/MetadataResolverSuite.scala @@ -21,28 +21,19 @@ import scala.collection.mutable import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{ - AnalysisContext, - FunctionResolution, - UnresolvedRelation -} -import org.apache.spark.sql.catalyst.analysis.resolver.{ - AnalyzerBridgeState, - BridgedRelationId, - BridgedRelationMetadataProvider, - MetadataResolver, - RelationId, - Resolver, - ViewResolver -} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.resolver._ import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.execution.datasources.{FileResolver, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap -class MetadataResolverSuite extends QueryTest with SharedSparkSession with SQLTestUtils { +class MetadataResolverSuite + 
extends QueryTest + with SharedSparkSession + with SQLTestUtils { private val catalogName = "spark_catalog" private val keyValueTableSchema = StructType( @@ -57,9 +48,7 @@ class MetadataResolverSuite extends QueryTest with SharedSparkSession with SQLTe ) ) - test( - "Single CSV relation" - ) { + test("Single CSV relation") { withTable("src_csv") { spark.sql("CREATE TABLE src_csv (key INT, value STRING) USING CSV;").collect() @@ -70,9 +59,7 @@ class MetadataResolverSuite extends QueryTest with SharedSparkSession with SQLTe } } - test( - "Single ORC relation" - ) { + test("Single ORC relation") { withTable("src_orc") { spark.sql("CREATE TABLE src_orc (key INT, value STRING) USING ORC;").collect() @@ -176,9 +163,7 @@ class MetadataResolverSuite extends QueryTest with SharedSparkSession with SQLTe } } - test( - "Relation from a file" - ) { + test("Relation from a file") { val df = spark.range(100).toDF() withTempPath(f => { df.write.json(f.getCanonicalPath) @@ -341,7 +326,6 @@ class MetadataResolverSuite extends QueryTest with SharedSparkSession with SQLTe new MetadataResolver( spark.sessionState.catalogManager, relationResolution, - new FunctionResolution(spark.sessionState.catalogManager, relationResolution), Seq(new FileResolver(spark)) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/NameScopeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/NameScopeSuite.scala index 1fdf833d22de5..614c3b2e9ad3a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/NameScopeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/NameScopeSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.analysis.resolver -import java.util.HashSet +import java.util.{Arrays, HashSet} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.UnresolvedStar import org.apache.spark.sql.catalyst.analysis.resolver.{NameScope, NameScopeStack, NameTarget} import org.apache.spark.sql.catalyst.expressions.{ Attribute, @@ -34,15 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{ OuterReference } import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.types.{ - ArrayType, - BooleanType, - IntegerType, - MapType, - StringType, - StructField, - StructType -} +import org.apache.spark.sql.types._ class NameScopeSuite extends PlanTest { private val col1Integer = AttributeReference(name = "col1", dataType = IntegerType)() @@ -85,6 +77,7 @@ class NameScopeSuite extends PlanTest { name = "col10", dataType = MapType(StringType, IntegerType) )() + private val col10Integer = AttributeReference(name = "col10", dataType = IntegerType)() private val col11MapWithStruct = AttributeReference( name = "col11", dataType = MapType( @@ -644,6 +637,259 @@ class NameScopeSuite extends PlanTest { } } + test("Hidden output gets prioritized because of conflict") { + val stack = new NameScopeStack + + stack.overwriteCurrent( + output = Some(Seq(col1Integer, col1IntegerOther)), + hiddenOutput = Some(Seq(col1IntegerOther, col2Integer)), + availableAliases = Some(new HashSet[ExprId](Arrays.asList(col1Integer.exprId))) + ) + + assert( + stack.resolveMultipartName(Seq("col1")) == NameTarget( + candidates = Seq(col1Integer, col1IntegerOther), + output = Seq(col1Integer, col1IntegerOther) + ) + ) + assert( + stack.resolveMultipartName(Seq("col1"), shouldPreferHiddenOutput = true) == NameTarget( + 
candidates = Seq(col1Integer, col1IntegerOther), + output = Seq(col1Integer, col1IntegerOther) + ) + ) + assert( + stack.resolveMultipartName(Seq("col1"), canResolveNameByHiddenOutput = true) == NameTarget( + candidates = Seq(col1IntegerOther), + output = Seq(col1Integer, col1IntegerOther) + ) + ) + assert( + stack.resolveMultipartName( + Seq("col1"), + canResolveNameByHiddenOutput = true, + shouldPreferHiddenOutput = true + ) == NameTarget( + candidates = Seq(col1IntegerOther), + output = Seq(col1Integer, col1IntegerOther) + ) + ) + } + + test("Main output gets prioritized because of conflict") { + val stack = new NameScopeStack + + stack.overwriteCurrent( + output = Some(Seq(col1Integer)), + hiddenOutput = Some(Seq(col1Integer, col1IntegerOther, col2Integer)), + availableAliases = Some(new HashSet[ExprId]) + ) + + assert( + stack.resolveMultipartName(Seq("col1")) == NameTarget( + candidates = Seq(col1Integer), + output = Seq(col1Integer) + ) + ) + assert( + stack.resolveMultipartName(Seq("col1"), shouldPreferHiddenOutput = true) == NameTarget( + candidates = Seq(col1Integer), + output = Seq(col1Integer) + ) + ) + assert( + stack.resolveMultipartName(Seq("col1"), canResolveNameByHiddenOutput = true) == NameTarget( + candidates = Seq(col1Integer), + output = Seq(col1Integer) + ) + ) + assert( + stack.resolveMultipartName( + Seq("col1"), + canResolveNameByHiddenOutput = true, + shouldPreferHiddenOutput = true + ) == NameTarget( + candidates = Seq(col1Integer), + output = Seq(col1Integer) + ) + ) + } + + test("Both main and hidden outputs have a conflict") { + val stack = new NameScopeStack + + stack.overwriteCurrent( + output = Some(Seq(col1Integer, col1IntegerOther)), + hiddenOutput = Some(Seq(col1Integer, col1IntegerOther, col2Integer)), + availableAliases = Some(new HashSet[ExprId]) + ) + + assert( + stack.resolveMultipartName(Seq("col1")) == NameTarget( + candidates = Seq(col1Integer, col1IntegerOther), + output = Seq(col1Integer, col1IntegerOther) + ) + ) + assert( + stack.resolveMultipartName(Seq("col1"), shouldPreferHiddenOutput = true) == NameTarget( + candidates = Seq(col1Integer, col1IntegerOther), + output = Seq(col1Integer, col1IntegerOther) + ) + ) + assert( + stack.resolveMultipartName(Seq("col1"), canResolveNameByHiddenOutput = true) == NameTarget( + candidates = Seq(col1Integer, col1IntegerOther), + output = Seq(col1Integer, col1IntegerOther) + ) + ) + assert( + stack.resolveMultipartName( + Seq("col1"), + canResolveNameByHiddenOutput = true, + shouldPreferHiddenOutput = true + ) == NameTarget( + candidates = Seq(col1Integer, col1IntegerOther), + output = Seq(col1Integer, col1IntegerOther) + ) + ) + } + + test("Hidden output gets prioritized because of impossible extract") { + val stack = new NameScopeStack + + stack.overwriteCurrent( + output = Some(Seq(col10Integer)), + hiddenOutput = Some(Seq(col10Map)), + availableAliases = Some(new HashSet[ExprId](Arrays.asList(col10Integer.exprId))) + ) + + assert( + stack.resolveMultipartName(Seq("col10", "key")) == NameTarget( + candidates = Seq.empty, + output = Seq(col10Integer) + ) + ) + assert( + stack + .resolveMultipartName(Seq("col10", "key"), shouldPreferHiddenOutput = true) == NameTarget( + candidates = Seq.empty, + output = Seq(col10Integer) + ) + ) + assert( + stack.resolveMultipartName( + Seq("col10", "key"), + canResolveNameByHiddenOutput = true + ) == NameTarget( + candidates = Seq(GetMapValue(col10Map, Literal("key"))), + aliasName = Some("key"), + output = Seq(col10Integer) + ) + ) + assert( + 
stack.resolveMultipartName( + Seq("col10", "key"), + canResolveNameByHiddenOutput = true, + shouldPreferHiddenOutput = true + ) == NameTarget( + candidates = Seq(GetMapValue(col10Map, Literal("key"))), + aliasName = Some("key"), + output = Seq(col10Integer) + ) + ) + } + + test("Main output gets prioritized because of impossible extract") { + val stack = new NameScopeStack + + stack.overwriteCurrent( + output = Some(Seq(col10Map)), + hiddenOutput = Some(Seq(col10Integer)), + availableAliases = Some(new HashSet[ExprId](Arrays.asList(col10Map.exprId))) + ) + + assert( + stack.resolveMultipartName(Seq("col10", "key")) == NameTarget( + candidates = Seq(GetMapValue(col10Map, Literal("key"))), + aliasName = Some("key"), + output = Seq(col10Map) + ) + ) + assert( + stack.resolveMultipartName( + Seq("col10", "key"), + shouldPreferHiddenOutput = true + ) == NameTarget( + candidates = Seq(GetMapValue(col10Map, Literal("key"))), + aliasName = Some("key"), + output = Seq(col10Map) + ) + ) + assert( + stack.resolveMultipartName( + Seq("col10", "key"), + canResolveNameByHiddenOutput = true + ) == NameTarget( + candidates = Seq(GetMapValue(col10Map, Literal("key"))), + aliasName = Some("key"), + output = Seq(col10Map) + ) + ) + assert( + stack.resolveMultipartName( + Seq("col10", "key"), + canResolveNameByHiddenOutput = true, + shouldPreferHiddenOutput = true + ) == NameTarget( + candidates = Seq(GetMapValue(col10Map, Literal("key"))), + aliasName = Some("key"), + output = Seq(col10Map) + ) + ) + } + + test("Both main and hidden outputs have impossible extract") { + val stack = new NameScopeStack + + stack.overwriteCurrent( + output = Some(Seq(col1Integer)), + hiddenOutput = Some(Seq(col1IntegerOther)), + availableAliases = Some(new HashSet[ExprId](Arrays.asList(col1Integer.exprId))) + ) + + assert( + stack.resolveMultipartName(Seq("col1", "key")) == NameTarget( + candidates = Seq.empty, + output = Seq(col1Integer) + ) + ) + assert( + stack.resolveMultipartName(Seq("col1", "key"), shouldPreferHiddenOutput = true) == NameTarget( + candidates = Seq.empty, + output = Seq(col1Integer) + ) + ) + assert( + stack.resolveMultipartName( + Seq("col1", "key"), + canResolveNameByHiddenOutput = true + ) == NameTarget( + candidates = Seq.empty, + output = Seq(col1Integer) + ) + ) + assert( + stack.resolveMultipartName( + Seq("col1", "key"), + canResolveNameByHiddenOutput = true, + shouldPreferHiddenOutput = true + ) == NameTarget( + candidates = Seq.empty, + output = Seq(col1Integer) + ) + ) + } + test("Empty stack") { val stack = new NameScopeStack @@ -728,37 +974,6 @@ class NameScopeSuite extends PlanTest { assert(stack.current.output == Seq(col1Integer)) } - test( - "Name resolution should prefer table columns over aliases with same name when " + - "shouldPreferTableColumnsOverAliases is set or throw AMBIGUOUS_REFERENCE otherwise" - ) { - val stack = new NameScopeStack - val output = Seq(col1Integer, col1IntegerOther) - val availableAliases = new HashSet[ExprId](1) - availableAliases.add(col1IntegerOther.exprId) - - stack.overwriteCurrent(output = Some(output), availableAliases = Some(availableAliases)) - - assert( - stack.resolveMultipartName( - multipartName = Seq("col1"), - shouldPreferTableColumnsOverAliases = true - ) == NameTarget( - candidates = Seq(col1Integer), - output = output - ) - ) - - checkError( - exception = intercept[AnalysisException] { - val nameTarget = stack.resolveMultipartName(multipartName = Seq("col1")) - nameTarget.pickCandidate(UnresolvedAttribute(nameParts = Seq("col1"))) - }, - 
condition = "AMBIGUOUS_REFERENCE", - parameters = Map("name" -> "`col1`", "referenceNames" -> "[`col1`, `col1`]") - ) - } - /** * Check both [[resolveMultipartName]] and [[findAttributesByName]] for a single part name. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala index dbdca7d0b5d6d..7d6b321edba1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ResolverGuardSuite.scala @@ -232,6 +232,7 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { test("Group by") { checkResolverGuard("SELECT col1, count(col1) FROM VALUES(1) GROUP BY ALL", shouldPass = true) + checkResolverGuard("SELECT * FROM VALUES(1,2,3) GROUP BY ALL", shouldPass = true) checkResolverGuard("SELECT col1 FROM VALUES(1) GROUP BY 1", shouldPass = true) checkResolverGuard("SELECT col1, col1 + 1 FROM VALUES(1) GROUP BY 1, col1", shouldPass = true) } @@ -286,6 +287,17 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { ) } + test("TABLESAMPLE") { + checkResolverGuard( + "SELECT * FROM (VALUES (1), (2), (3)) TABLESAMPLE (40 PERCENT)", + shouldPass = true + ) + } + + test("Semi-structured extract") { + checkResolverGuard("SELECT PARSE_JSON('{\"a\":1}'):a", shouldPass = true) + } + // Queries that shouldn't pass the OperatorResolverGuard test("Unsupported literal functions") { @@ -306,9 +318,26 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { } } - test("UDFs") { - sql("CREATE FUNCTION supermario(x INT) RETURNS INT RETURN x + 3") - checkResolverGuard("SELECT supermario(2)", shouldPass = false) + test("UDF") { + withSqlFunction("supermario") { + sql("CREATE FUNCTION supermario(x INT) RETURNS INT RETURN x + 3") + + checkResolverGuard("SELECT supermario(2)", shouldPass = false) + } + } + + test("UDF in a database with the same name as a built-in function") { + withDatabase("upper") { + sql("CREATE DATABASE IF NOT EXISTS upper") + + withSqlFunction("supermario") { + sql("USE DATABASE upper") + + sql("CREATE FUNCTION supermario(x INT) RETURNS INT RETURN x + 3") + + checkResolverGuard("SELECT upper.supermario(2)", shouldPass = false) + } + } } test("PLAN_ID_TAG") { @@ -320,6 +349,19 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { checkResolverGuard(plan, shouldPass = false) } + test("Star outside of Project list") { + checkResolverGuard("SELECT * FROM VALUES (1, 2) WHERE 3 IN (*)", shouldPass = false) + } + + test("Lambda variable") { + checkResolverGuard( + "SELECT array_sort(array(2, 1), (p1, p2) -> IF(p1 > p2, 1, 0))", + shouldPass = false + ) + checkResolverGuard("SELECT transform(array(2, 1), x -> x * 2)", shouldPass = false) + checkResolverGuard("SELECT filter(array(2, 1), x -> x > 0)", shouldPass = false) + } + test("Catch ExplicitlyUnsupportedResolverFeature exceptions") { class ThrowsExplicitlyUnsupportedFeatureResolver @@ -348,7 +390,7 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { } private def checkResolverGuard(query: String, shouldPass: Boolean): Unit = { - checkResolverGuard(spark.sql(query).queryExecution.logical, shouldPass) + checkResolverGuard(spark.sessionState.sqlParser.parsePlan(query), shouldPass) } private def checkResolverGuard( @@ -396,4 +438,12 @@ class ResolverGuardSuite extends QueryTest with SharedSparkSession { sql("DROP TEMPORARY VARIABLE 
session_variable;") } } + + private def withSqlFunction[R](name: String)(body: => R): R = { + try { + body + } finally { + spark.sql(s"DROP FUNCTION IF EXISTS $name") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ViewResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ViewResolverSuite.scala index 9d2601e2578c3..a03e353c40b01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ViewResolverSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/analysis/resolver/ViewResolverSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.analysis.resolver import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalyst.analysis.{FunctionResolution, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.analysis.resolver.{MetadataResolver, Resolver, ResolverRunner} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -177,8 +177,7 @@ class ViewResolverSuite extends QueryTest with SharedSparkSession { val relationResolution = Resolver.createRelationResolution(spark.sessionState.catalogManager) val metadataResolver = new MetadataResolver( spark.sessionState.catalogManager, - relationResolution, - new FunctionResolution(spark.sessionState.catalogManager, relationResolution) + relationResolution ) val unresolvedPlan = spark.sessionState.sqlParser.parsePlan(sqlText) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceResolverSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceResolverSuite.scala index 11b2577f1e873..5e12948aab68e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceResolverSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceResolverSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.analysis.{FunctionResolution, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.analysis.resolver.{ MetadataResolver, ProhibitedResolver, @@ -93,8 +93,7 @@ class DataSourceResolverSuite extends QueryTest with SharedSparkSession { Resolver.createRelationResolution(spark.sessionState.catalogManager) val metadataResolver = new MetadataResolver( spark.sessionState.catalogManager, - relationResolution, - new FunctionResolution(spark.sessionState.catalogManager, relationResolution) + relationResolution ) val dataSourceResolver = new DataSourceResolver(spark) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveTableRelationResolverSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveTableRelationResolverSuite.scala index 8861d2c1f1683..674b726e94fb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveTableRelationResolverSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveTableRelationResolverSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.catalyst.analysis.{FunctionResolution, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.analysis.resolver.{ MetadataResolver, ProhibitedResolver, @@ -26,7 +26,6 @@ import 
org.apache.spark.sql.catalyst.analysis.resolver.{ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -77,7 +76,6 @@ class HiveTableRelationResolverSuite extends TestHiveSingleton with SQLTestUtils val metadataResolver = new MetadataResolver( spark.sessionState.catalogManager, relationResolution, - new FunctionResolution(spark.sessionState.catalogManager, relationResolution), extensions = spark.sessionState.analyzer.singlePassMetadataResolverExtensions ) val hiveTableRelationResolver = new HiveTableRelationResolver(