17
17
18
18
package org .apache .spark .sql .catalyst .analysis .resolver
19
19
20
- import java .util .{ HashSet , LinkedHashMap }
20
+ import java .util .HashSet
21
21
22
- import scala .jdk .CollectionConverters ._
23
-
24
- import org .apache .spark .sql .catalyst .analysis .{
25
- AnalysisErrorAt ,
26
- NondeterministicExpressionCollection ,
27
- UnresolvedAttribute
28
- }
22
+ import org .apache .spark .sql .catalyst .analysis .{AnalysisErrorAt , UnresolvedAttribute }
29
23
import org .apache .spark .sql .catalyst .expressions .{
30
24
Alias ,
25
+ AliasHelper ,
31
26
AttributeReference ,
32
27
Expression ,
33
28
ExprId ,
34
- ExprUtils ,
35
- NamedExpression
29
+ ExprUtils
36
30
}
37
- import org .apache .spark .sql .catalyst .plans .logical .{Aggregate , LogicalPlan , Project }
31
+ import org .apache .spark .sql .catalyst .plans .logical .{Aggregate , LogicalPlan }
38
32
39
33
/**
40
34
* Resolves an [[Aggregate ]] by resolving its child, aggregate expressions and grouping
41
35
* expressions. Updates the [[NameScopeStack ]] with its output and performs validation
42
36
* related to [[Aggregate ]] resolution.
43
37
*/
44
38
class AggregateResolver (operatorResolver : Resolver , expressionResolver : ExpressionResolver )
45
- extends TreeNodeResolver [Aggregate , LogicalPlan ] {
39
+ extends TreeNodeResolver [Aggregate , LogicalPlan ]
40
+ with AliasHelper {
46
41
private val scopes = operatorResolver.getNameScopes
47
42
private val lcaResolver = expressionResolver.getLcaResolver
48
43
49
44
/**
50
45
* Resolve [[Aggregate ]] operator.
51
46
*
52
47
* 1. Resolve the child (inline table).
53
- * 2. Resolve aggregate expressions using [[ExpressionResolver.resolveAggregateExpressions ]] and
48
+ * 2. Clear [[NameScope.availableAliases ]]. Those are only relevant for the immediate aggregate
49
+ * expressions for output prioritization to work correctly in
50
+ * [[NameScope.tryResolveMultipartNameByOutput ]].
51
+ * 3. Resolve aggregate expressions using [[ExpressionResolver.resolveAggregateExpressions ]] and
54
52
* set [[NameScope.ordinalReplacementExpressions ]] for grouping expressions resolution.
55
- * 3 . If there's just one [[UnresolvedAttribute ]] with a single-part name "ALL", expand it using
53
+ * 4 . If there's just one [[UnresolvedAttribute ]] with a single-part name "ALL", expand it using
56
54
* aggregate expressions which don't contain aggregate functions. There should not exist a
57
55
* column with that name in the lower operator's output, otherwise it takes precedence.
58
- * 4 . Resolve grouping expressions using [[ExpressionResolver.resolveGroupingExpressions ]]. This
56
+ * 5 . Resolve grouping expressions using [[ExpressionResolver.resolveGroupingExpressions ]]. This
59
57
* includes alias references to aggregate expressions, which is done in
60
58
* [[NameScope.resolveMultipartName ]] and replacing [[UnresolvedOrdinals ]] with corresponding
61
59
* expressions from aggregate list, done in [[OrdinalResolver ]].
62
- * 5. Substitute non-deterministic expressions with derived attribute references to an
63
- * artificial [[Project ]] list.
60
+ * 6. Remove all the unnecessary [[Alias ]]es from the grouping (all the aliases) and aggregate
61
+ * (keep the outermost one) expressions. This is needed to stay compatible with the
62
+ * fixed-point implementation. For example:
63
+ *
64
+ * {{{ SELECT timestamp(col1:str) FROM VALUES('a') GROUP BY timestamp(col1:str); }}}
65
+ *
66
+ * Here we end up having inner [[Alias ]]es in both the grouping and aggregate expressions
67
+ * lists which are uncomparable because they have different expression IDs (thus we have to
68
+ * strip them).
64
69
*
65
70
* If the resulting [[Aggregate ]] contains lateral column references, delegate the resolution of
66
71
* these columns to [[LateralColumnAliasResolver.handleLcaInAggregate ]]. Otherwise, validate the
@@ -73,6 +78,8 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi
73
78
val resolvedAggregate = try {
74
79
val resolvedChild = operatorResolver.resolve(unresolvedAggregate.child)
75
80
81
+ scopes.current.availableAliases.clear()
82
+
76
83
val resolvedAggregateExpressions = expressionResolver.resolveAggregateExpressions(
77
84
unresolvedAggregate.aggregateExpressions,
78
85
unresolvedAggregate
@@ -100,21 +107,25 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi
100
107
)
101
108
}
102
109
103
- val partiallyResolvedAggregate = unresolvedAggregate.copy(
104
- groupingExpressions = resolvedGroupingExpressions,
105
- aggregateExpressions = resolvedAggregateExpressions.expressions,
110
+ val resolvedGroupingExpressionsWithoutAliases = resolvedGroupingExpressions.map(trimAliases)
111
+ val resolvedAggregateExpressionsWithoutAliases =
112
+ resolvedAggregateExpressions.expressions.map(trimNonTopLevelAliases)
113
+
114
+ val resolvedAggregate = unresolvedAggregate.copy(
115
+ groupingExpressions = resolvedGroupingExpressionsWithoutAliases,
116
+ aggregateExpressions = resolvedAggregateExpressionsWithoutAliases,
106
117
child = resolvedChild
107
118
)
108
119
109
- val resolvedAggregate = tryPullOutNondeterministic(partiallyResolvedAggregate)
110
-
111
120
if (resolvedAggregateExpressions.hasLateralColumnAlias) {
112
121
val aggregateWithLcaResolutionResult = lcaResolver.handleLcaInAggregate(resolvedAggregate)
113
122
AggregateResolutionResult (
114
123
operator = aggregateWithLcaResolutionResult.resolvedOperator,
115
124
outputList = aggregateWithLcaResolutionResult.outputList,
116
- groupingAttributeIds = None ,
117
- aggregateListAliases = aggregateWithLcaResolutionResult.aggregateListAliases
125
+ groupingAttributeIds =
126
+ getGroupingAttributeIds(aggregateWithLcaResolutionResult.baseAggregate),
127
+ aggregateListAliases = aggregateWithLcaResolutionResult.aggregateListAliases,
128
+ baseAggregate = aggregateWithLcaResolutionResult.baseAggregate
118
129
)
119
130
} else {
120
131
// TODO: This validation function does a post-traversal. This is discouraged in single-pass
@@ -124,8 +135,9 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi
124
135
AggregateResolutionResult (
125
136
operator = resolvedAggregate,
126
137
outputList = resolvedAggregate.aggregateExpressions,
127
- groupingAttributeIds = Some (getGroupingAttributeIds(resolvedAggregate)),
128
- aggregateListAliases = scopes.current.getTopAggregateExpressionAliases
138
+ groupingAttributeIds = getGroupingAttributeIds(resolvedAggregate),
139
+ aggregateListAliases = scopes.current.getTopAggregateExpressionAliases,
140
+ baseAggregate = resolvedAggregate
129
141
)
130
142
}
131
143
} finally {
@@ -134,8 +146,9 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi
134
146
135
147
scopes.overwriteOutputAndExtendHiddenOutput(
136
148
output = resolvedAggregate.outputList.map(_.toAttribute),
137
- groupingAttributeIds = resolvedAggregate.groupingAttributeIds,
138
- aggregateListAliases = resolvedAggregate.aggregateListAliases
149
+ groupingAttributeIds = Some (resolvedAggregate.groupingAttributeIds),
150
+ aggregateListAliases = resolvedAggregate.aggregateListAliases,
151
+ baseAggregate = Some (resolvedAggregate.baseAggregate)
139
152
)
140
153
141
154
resolvedAggregate.operator
@@ -208,53 +221,6 @@ class AggregateResolver(operatorResolver: Resolver, expressionResolver: Expressi
208
221
}
209
222
}
210
223
211
- /**
212
- * In case there are non-deterministic expressions in either `groupingExpressions` or
213
- * `aggregateExpressions` replace them with attributes created out of corresponding
214
- * non-deterministic expression. Example:
215
- *
216
- * {{{ SELECT RAND() GROUP BY 1; }}}
217
- *
218
- * This query would have the following analyzed plan:
219
- * Aggregate(
220
- * groupingExpressions = [AttributeReference(_nonDeterministic)]
221
- * aggregateExpressions = [Alias(AttributeReference(_nonDeterministic), `rand()`)]
222
- * child = Project(
223
- * projectList = [Alias(Rand(...), `_nondeterministic`)]
224
- * child = OneRowRelation
225
- * )
226
- * )
227
- */
228
- private def tryPullOutNondeterministic (aggregate : Aggregate ): Aggregate = {
229
- val nondeterministicToAttributes : LinkedHashMap [Expression , NamedExpression ] =
230
- NondeterministicExpressionCollection .getNondeterministicToAttributes(
231
- aggregate.groupingExpressions
232
- )
233
-
234
- if (! nondeterministicToAttributes.isEmpty) {
235
- val newChild = Project (
236
- scopes.current.output ++ nondeterministicToAttributes.values.asScala.toSeq,
237
- aggregate.child
238
- )
239
- val resolvedAggregateExpressions = aggregate.aggregateExpressions.map { expression =>
240
- PullOutNondeterministicExpressionInExpressionTree (expression, nondeterministicToAttributes)
241
- }
242
- val resolvedGroupingExpressions = aggregate.groupingExpressions.map { expression =>
243
- PullOutNondeterministicExpressionInExpressionTree (
244
- expression,
245
- nondeterministicToAttributes
246
- )
247
- }
248
- aggregate.copy(
249
- groupingExpressions = resolvedGroupingExpressions,
250
- aggregateExpressions = resolvedAggregateExpressions,
251
- child = newChild
252
- )
253
- } else {
254
- aggregate
255
- }
256
- }
257
-
258
224
private def canGroupByAll (expressions : Seq [Expression ]): Boolean = {
259
225
val isOrderByAll = expressions match {
260
226
case Seq (unresolvedAttribute : UnresolvedAttribute ) =>
0 commit comments