Skip to content

[SPARK-52842][SQL] New functionality and bugfixes for single-pass analyzer #51513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.resolver.{
AnalyzerBridgeState,
HybridAnalyzer,
Resolver => OperatorResolver,
ResolverExtension,
ResolverGuard
ResolverExtension
}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
Expand Down Expand Up @@ -297,17 +295,17 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
def getRelationResolution: RelationResolution = relationResolution

def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
if (plan.analyzed) return plan
AnalysisHelper.markInAnalyzer {
new HybridAnalyzer(
this,
new ResolverGuard(catalogManager),
new OperatorResolver(
catalogManager,
singlePassResolverExtensions,
singlePassMetadataResolverExtensions
)
).apply(plan, tracker)
if (plan.analyzed) {
plan
} else {
AnalysisContext.reset()
try {
AnalysisHelper.markInAnalyzer {
HybridAnalyzer.fromLegacyAnalyzer(legacyAnalyzer = this).apply(plan, tracker)
}
} finally {
AnalysisContext.reset()
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@
package org.apache.spark.sql.catalyst.analysis.resolver

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{
AnsiTypeCoercion,
CollationTypeCoercion,
TypeCoercion
}
import org.apache.spark.sql.catalyst.expressions.{Expression, OuterReference, SubExprUtils}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ListAgg}
import org.apache.spark.sql.catalyst.util.toPrettySQL
Expand All @@ -41,11 +36,6 @@ class AggregateExpressionResolver(

private val traversals = expressionResolver.getExpressionTreeTraversals

protected override val ansiTransformations: CoercesExpressionTypes.Transformations =
AggregateExpressionResolver.ANSI_TYPE_COERCION_TRANSFORMATIONS
protected override val nonAnsiTransformations: CoercesExpressionTypes.Transformations =
AggregateExpressionResolver.TYPE_COERCION_TRANSFORMATIONS

private val expressionResolutionContextStack =
expressionResolver.getExpressionResolutionContextStack
private val subqueryRegistry = operatorResolver.getSubqueryRegistry
Expand All @@ -58,6 +48,7 @@ class AggregateExpressionResolver(
* resolving its children recursively and validating the resolved expression.
*/
override def resolve(aggregateExpression: AggregateExpression): Expression = {
expressionResolutionContextStack.peek().resolvingTreeUnderAggregateExpression = true
val aggregateExpressionWithChildrenResolved =
withResolvedChildren(aggregateExpression, expressionResolver.resolve _)
.asInstanceOf[AggregateExpression]
Expand Down Expand Up @@ -132,15 +123,13 @@ class AggregateExpressionResolver(
throwNestedAggregateFunction(aggregateExpression)
}

val nonDeterministicChild =
aggregateExpression.aggregateFunction.children.collectFirst {
case child if !child.deterministic => child
aggregateExpression.aggregateFunction.children.foreach { child =>
if (!child.deterministic) {
throwAggregateFunctionWithNondeterministicExpression(
aggregateExpression,
child
)
}
if (nonDeterministicChild.nonEmpty) {
throwAggregateFunctionWithNondeterministicExpression(
aggregateExpression,
nonDeterministicChild.get
)
}
}

Expand Down Expand Up @@ -249,23 +238,3 @@ class AggregateExpressionResolver(
)
}
}

object AggregateExpressionResolver {
// Ordering in the list of type coercions should be in sync with the list in [[TypeCoercion]].
private val TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq(
CollationTypeCoercion.apply,
TypeCoercion.InTypeCoercion.apply,
TypeCoercion.FunctionArgumentTypeCoercion.apply,
TypeCoercion.IfTypeCoercion.apply,
TypeCoercion.ImplicitTypeCoercion.apply
)

// Ordering in the list of type coercions should be in sync with the list in [[AnsiTypeCoercion]].
private val ANSI_TYPE_COERCION_TRANSFORMATIONS: Seq[Expression => Expression] = Seq(
CollationTypeCoercion.apply,
AnsiTypeCoercion.InTypeCoercion.apply,
AnsiTypeCoercion.FunctionArgumentTypeCoercion.apply,
AnsiTypeCoercion.IfTypeCoercion.apply,
AnsiTypeCoercion.ImplicitTypeCoercion.apply
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ package org.apache.spark.sql.catalyst.analysis.resolver
import java.util.HashSet

import org.apache.spark.sql.catalyst.expressions.{Alias, ExprId, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}

/**
* Stores the resulting operator, output list, grouping attributes and list of aliases from
* aggregate list, obtained by resolving an [[Aggregate]] operator.
* Stores the resulting operator, output list, grouping attributes, list of aliases from
* aggregate list and base [[Aggregate]], obtained by resolving an [[Aggregate]] operator.
*/
case class AggregateResolutionResult(
operator: LogicalPlan,
outputList: Seq[NamedExpression],
groupingAttributeIds: Option[HashSet[ExprId]],
aggregateListAliases: Seq[Alias])
groupingAttributeIds: HashSet[ExprId],
aggregateListAliases: Seq[Alias],
baseAggregate: Aggregate)
Original file line number Diff line number Diff line change
Expand Up @@ -17,50 +17,55 @@

package org.apache.spark.sql.catalyst.analysis.resolver

import java.util.{HashSet, LinkedHashMap}
import java.util.HashSet

import scala.jdk.CollectionConverters._

import org.apache.spark.sql.catalyst.analysis.{
AnalysisErrorAt,
NondeterministicExpressionCollection,
UnresolvedAttribute
}
import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.{
Alias,
AliasHelper,
AttributeReference,
Expression,
ExprId,
ExprUtils,
NamedExpression
ExprUtils
}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}

/**
* Resolves an [[Aggregate]] by resolving its child, aggregate expressions and grouping
* expressions. Updates the [[NameScopeStack]] with its output and performs validation
* related to [[Aggregate]] resolution.
*/
class AggregateResolver(operatorResolver: Resolver, expressionResolver: ExpressionResolver)
extends TreeNodeResolver[Aggregate, LogicalPlan] {
extends TreeNodeResolver[Aggregate, LogicalPlan]
with AliasHelper {
private val scopes = operatorResolver.getNameScopes
private val lcaResolver = expressionResolver.getLcaResolver

/**
* Resolve [[Aggregate]] operator.
*
* 1. Resolve the child (inline table).
* 2. Resolve aggregate expressions using [[ExpressionResolver.resolveAggregateExpressions]] and
* 2. Clear [[NameScope.availableAliases]]. Those are only relevant for the immediate aggregate
* expressions for output prioritization to work correctly in
* [[NameScope.tryResolveMultipartNameByOutput]].
* 3. Resolve aggregate expressions using [[ExpressionResolver.resolveAggregateExpressions]] and
* set [[NameScope.ordinalReplacementExpressions]] for grouping expressions resolution.
* 3. If there's just one [[UnresolvedAttribute]] with a single-part name "ALL", expand it using
* 4. If there's just one [[UnresolvedAttribute]] with a single-part name "ALL", expand it using
* aggregate expressions which don't contain aggregate functions. There should not exist a
* column with that name in the lower operator's output, otherwise it takes precedence.
* 4. Resolve grouping expressions using [[ExpressionResolver.resolveGroupingExpressions]]. This
* 5. Resolve grouping expressions using [[ExpressionResolver.resolveGroupingExpressions]]. This
* includes alias references to aggregate expressions, which is done in
* [[NameScope.resolveMultipartName]] and replacing [[UnresolvedOrdinals]] with corresponding
* expressions from aggregate list, done in [[OrdinalResolver]].
* 5. Substitute non-deterministic expressions with derived attribute references to an
* artificial [[Project]] list.
* 6. Remove all the unnecessary [[Alias]]es from the grouping (all the aliases) and aggregate
* (keep the outermost one) expressions. This is needed to stay compatible with the
* fixed-point implementation. For example:
*
* {{{ SELECT timestamp(col1:str) FROM VALUES('a') GROUP BY timestamp(col1:str); }}}
*
* Here we end up having inner [[Alias]]es in both the grouping and aggregate expressions
* lists which are uncomparable because they have different expression IDs (thus we have to
* strip them).
*
* If the resulting [[Aggregate]] contains lateral columns references, delegate the resolution of
* these columns to [[LateralColumnAliasResolver.handleLcaInAggregate]]. Otherwise, validate the
Expand All @@ -73,6 +78,8 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi
val resolvedAggregate = try {
val resolvedChild = operatorResolver.resolve(unresolvedAggregate.child)

scopes.current.availableAliases.clear()

val resolvedAggregateExpressions = expressionResolver.resolveAggregateExpressions(
unresolvedAggregate.aggregateExpressions,
unresolvedAggregate
Expand Down Expand Up @@ -100,21 +107,25 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi
)
}

val partiallyResolvedAggregate = unresolvedAggregate.copy(
groupingExpressions = resolvedGroupingExpressions,
aggregateExpressions = resolvedAggregateExpressions.expressions,
val resolvedGroupingExpressionsWithoutAliases = resolvedGroupingExpressions.map(trimAliases)
val resolvedAggregateExpressionsWithoutAliases =
resolvedAggregateExpressions.expressions.map(trimNonTopLevelAliases)

val resolvedAggregate = unresolvedAggregate.copy(
groupingExpressions = resolvedGroupingExpressionsWithoutAliases,
aggregateExpressions = resolvedAggregateExpressionsWithoutAliases,
child = resolvedChild
)

val resolvedAggregate = tryPullOutNondeterministic(partiallyResolvedAggregate)

if (resolvedAggregateExpressions.hasLateralColumnAlias) {
val aggregateWithLcaResolutionResult = lcaResolver.handleLcaInAggregate(resolvedAggregate)
AggregateResolutionResult(
operator = aggregateWithLcaResolutionResult.resolvedOperator,
outputList = aggregateWithLcaResolutionResult.outputList,
groupingAttributeIds = None,
aggregateListAliases = aggregateWithLcaResolutionResult.aggregateListAliases
groupingAttributeIds =
getGroupingAttributeIds(aggregateWithLcaResolutionResult.baseAggregate),
aggregateListAliases = aggregateWithLcaResolutionResult.aggregateListAliases,
baseAggregate = aggregateWithLcaResolutionResult.baseAggregate
)
} else {
// TODO: This validation function does a post-traversal. This is discouraged in single-pass
Expand All @@ -124,8 +135,9 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi
AggregateResolutionResult(
operator = resolvedAggregate,
outputList = resolvedAggregate.aggregateExpressions,
groupingAttributeIds = Some(getGroupingAttributeIds(resolvedAggregate)),
aggregateListAliases = scopes.current.getTopAggregateExpressionAliases
groupingAttributeIds = getGroupingAttributeIds(resolvedAggregate),
aggregateListAliases = scopes.current.getTopAggregateExpressionAliases,
baseAggregate = resolvedAggregate
)
}
} finally {
Expand All @@ -134,8 +146,9 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi

scopes.overwriteOutputAndExtendHiddenOutput(
output = resolvedAggregate.outputList.map(_.toAttribute),
groupingAttributeIds = resolvedAggregate.groupingAttributeIds,
aggregateListAliases = resolvedAggregate.aggregateListAliases
groupingAttributeIds = Some(resolvedAggregate.groupingAttributeIds),
aggregateListAliases = resolvedAggregate.aggregateListAliases,
baseAggregate = Some(resolvedAggregate.baseAggregate)
)

resolvedAggregate.operator
Expand Down Expand Up @@ -208,53 +221,6 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi
}
}

/**
* In case there are non-deterministic expressions in either `groupingExpressions` or
* `aggregateExpressions` replace them with attributes created out of corresponding
* non-deterministic expression. Example:
*
* {{{ SELECT RAND() GROUP BY 1; }}}
*
* This query would have the following analyzed plan:
* Aggregate(
* groupingExpressions = [AttributeReference(_nonDeterministic)]
* aggregateExpressions = [Alias(AttributeReference(_nonDeterministic), `rand()`)]
* child = Project(
* projectList = [Alias(Rand(...), `_nondeterministic`)]
* child = OneRowRelation
* )
* )
*/
private def tryPullOutNondeterministic(aggregate: Aggregate): Aggregate = {
val nondeterministicToAttributes: LinkedHashMap[Expression, NamedExpression] =
NondeterministicExpressionCollection.getNondeterministicToAttributes(
aggregate.groupingExpressions
)

if (!nondeterministicToAttributes.isEmpty) {
val newChild = Project(
scopes.current.output ++ nondeterministicToAttributes.values.asScala.toSeq,
aggregate.child
)
val resolvedAggregateExpressions = aggregate.aggregateExpressions.map { expression =>
PullOutNondeterministicExpressionInExpressionTree(expression, nondeterministicToAttributes)
}
val resolvedGroupingExpressions = aggregate.groupingExpressions.map { expression =>
PullOutNondeterministicExpressionInExpressionTree(
expression,
nondeterministicToAttributes
)
}
aggregate.copy(
groupingExpressions = resolvedGroupingExpressions,
aggregateExpressions = resolvedAggregateExpressions,
child = newChild
)
} else {
aggregate
}
}

private def canGroupByAll(expressions: Seq[Expression]): Boolean = {
val isOrderByAll = expressions match {
case Seq(unresolvedAttribute: UnresolvedAttribute) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
package org.apache.spark.sql.catalyst.analysis.resolver

import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}

/**
* Stores the result of resolution of lateral column aliases in an [[Aggregate]].
* @param resolvedOperator The resolved operator.
* @param outputList The output list of the resolved operator.
* @param aggregateListAliases List of aliases from aggregate list and all artificially inserted
* [[Project]] nodes.
* @param baseAggregate [[Aggregate]] node constructed by [[LateralColumnAliasResolver]] while
* resolving lateral column references in [[Aggregate]].
*/
case class AggregateWithLcaResolutionResult(
resolvedOperator: LogicalPlan,
outputList: Seq[NamedExpression],
aggregateListAliases: Seq[Alias])
aggregateListAliases: Seq[Alias],
baseAggregate: Aggregate)
Loading