Skip to content

Commit d0e68a4

Browse files
authored
fix: Support aggregate expressions in QUALIFY (apache#17313)
* fix: aggregate references within qualify * add sql integration tests * update slts * appease clippy * update snap
1 parent b084aa4 commit d0e68a4

File tree

5 files changed

+243
-56
lines changed

5 files changed

+243
-56
lines changed

datafusion/sql/src/select.rs

Lines changed: 95 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use crate::query::to_order_by_exprs_with_select;
2424
use crate::utils::{
2525
check_columns_satisfy_exprs, extract_aliases, rebase_expr, resolve_aliases_to_exprs,
2626
resolve_columns, resolve_positions_to_exprs, rewrite_recursive_unnests_bottom_up,
27-
CheckColumnsSatisfyExprsPurpose,
27+
CheckColumnsMustReferenceAggregatePurpose, CheckColumnsSatisfyExprsPurpose,
2828
};
2929

3030
use datafusion_common::error::DataFusionErrorBuilder;
@@ -84,6 +84,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
8484
// Handle named windows before processing the projection expression
8585
check_conflicting_windows(&select.named_window)?;
8686
self.match_window_definitions(&mut select.projection, &select.named_window)?;
87+
8788
// Process the SELECT expressions
8889
let select_exprs = self.prepare_select_exprs(
8990
&base_plan,
@@ -146,39 +147,6 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
146147
})
147148
.transpose()?;
148149

149-
// Optionally the QUALIFY expression.
150-
let qualify_expr_opt = select
151-
.qualify
152-
.map::<Result<Expr>, _>(|qualify_expr| {
153-
let qualify_expr = self.sql_expr_to_logical_expr(
154-
qualify_expr,
155-
&combined_schema,
156-
planner_context,
157-
)?;
158-
// This step "dereferences" any aliases in the QUALIFY clause.
159-
//
160-
// This is how we support queries with QUALIFY expressions that
161-
// refer to aliased columns.
162-
//
163-
// For example:
164-
//
165-
// select row_number() over (PARTITION BY id) as rk from users qualify rk > 1;
166-
//
167-
// are rewritten as, respectively:
168-
//
169-
// select row_number() over (PARTITION BY id) as rk from users qualify row_number() over (PARTITION BY id) > 1;
170-
//
171-
let qualify_expr = resolve_aliases_to_exprs(qualify_expr, &alias_map)?;
172-
normalize_col(qualify_expr, &projected_plan)
173-
})
174-
.transpose()?;
175-
176-
// The outer expressions we will search through for aggregates.
177-
// Aggregates may be sourced from the SELECT list or from the HAVING expression.
178-
let aggr_expr_haystack = select_exprs.iter().chain(having_expr_opt.iter());
179-
// All of the aggregate expressions (deduplicated).
180-
let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack);
181-
182150
// All of the group by expressions
183151
let group_by_exprs = if let GroupByExpr::Expressions(exprs, _) = select.group_by {
184152
exprs
@@ -223,22 +191,61 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
223191
.collect()
224192
};
225193

194+
// Optionally the QUALIFY expression.
195+
let qualify_expr_opt = select
196+
.qualify
197+
.map::<Result<Expr>, _>(|qualify_expr| {
198+
let qualify_expr = self.sql_expr_to_logical_expr(
199+
qualify_expr,
200+
&combined_schema,
201+
planner_context,
202+
)?;
203+
// This step "dereferences" any aliases in the QUALIFY clause.
204+
//
205+
// This is how we support queries with QUALIFY expressions that
206+
// refer to aliased columns.
207+
//
208+
// For example:
209+
//
210+
// select row_number() over (PARTITION BY id) as rk from users qualify rk > 1;
211+
//
212+
// are rewritten as, respectively:
213+
//
214+
// select row_number() over (PARTITION BY id) as rk from users qualify row_number() over (PARTITION BY id) > 1;
215+
//
216+
let qualify_expr = resolve_aliases_to_exprs(qualify_expr, &alias_map)?;
217+
normalize_col(qualify_expr, &projected_plan)
218+
})
219+
.transpose()?;
220+
221+
// The outer expressions we will search through for aggregates.
222+
// Aggregates may be sourced from the SELECT list or from the HAVING expression.
223+
let aggr_expr_haystack = select_exprs
224+
.iter()
225+
.chain(having_expr_opt.iter())
226+
.chain(qualify_expr_opt.iter());
227+
// All of the aggregate expressions (deduplicated).
228+
let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack);
229+
226230
// Process group by, aggregation or having
227-
let (plan, mut select_exprs_post_aggr, having_expr_post_aggr) = if !group_by_exprs
228-
.is_empty()
229-
|| !aggr_exprs.is_empty()
230-
{
231+
let (
232+
plan,
233+
mut select_exprs_post_aggr,
234+
having_expr_post_aggr,
235+
qualify_expr_post_aggr,
236+
) = if !group_by_exprs.is_empty() || !aggr_exprs.is_empty() {
231237
self.aggregate(
232238
&base_plan,
233239
&select_exprs,
234240
having_expr_opt.as_ref(),
241+
qualify_expr_opt.as_ref(),
235242
&group_by_exprs,
236243
&aggr_exprs,
237244
)?
238245
} else {
239246
match having_expr_opt {
240247
Some(having_expr) => return plan_err!("HAVING clause references: {having_expr} must appear in the GROUP BY clause or be used in an aggregate function"),
241-
None => (base_plan.clone(), select_exprs.clone(), having_expr_opt)
248+
None => (base_plan.clone(), select_exprs.clone(), having_expr_opt, qualify_expr_opt)
242249
}
243250
};
244251

@@ -252,11 +259,15 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
252259

253260
// The outer expressions we will search through for window functions.
254261
// Window functions may be sourced from the SELECT list or from the QUALIFY expression.
255-
let windows_expr_haystack =
256-
select_exprs_post_aggr.iter().chain(qualify_expr_opt.iter());
257-
// All of the window expressions (deduplicated).
262+
let windows_expr_haystack = select_exprs_post_aggr
263+
.iter()
264+
.chain(qualify_expr_post_aggr.iter());
265+
// All of the window expressions (deduplicated and rewritten to reference aggregates as
266+
// columns from input).
258267
let window_func_exprs = find_window_exprs(windows_expr_haystack);
259268

269+
// Process window functions after aggregation as they can reference
270+
// aggregate functions in their body
260271
let plan = if window_func_exprs.is_empty() {
261272
plan
262273
} else {
@@ -273,7 +284,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
273284

274285
// Process QUALIFY clause after window functions
275286
// QUALIFY filters the results of window functions, similar to how HAVING filters aggregates
276-
let plan = if let Some(qualify_expr) = qualify_expr_opt {
287+
let plan = if let Some(qualify_expr) = qualify_expr_post_aggr {
277288
// Validate that QUALIFY is used with window functions
278289
if window_func_exprs.is_empty() {
279290
return plan_err!(
@@ -839,36 +850,42 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
839850

840851
/// Create an aggregate plan.
841852
///
842-
/// An aggregate plan consists of grouping expressions, aggregate expressions, and an
843-
/// optional HAVING expression (which is a filter on the output of the aggregate).
853+
/// An aggregate plan consists of grouping expressions, aggregate expressions, an
854+
/// optional HAVING expression (which is a filter on the output of the aggregate),
855+
/// and an optional QUALIFY clause which may reference aggregates.
844856
///
845857
/// # Arguments
846858
///
847859
/// * `input` - The input plan that will be aggregated. The grouping, aggregate, and
848860
/// "having" expressions must all be resolvable from this plan.
849861
/// * `select_exprs` - The projection expressions from the SELECT clause.
850862
/// * `having_expr_opt` - Optional HAVING clause.
863+
/// * `qualify_expr_opt` - Optional QUALIFY clause.
851864
/// * `group_by_exprs` - Grouping expressions from the GROUP BY clause. These can be column
852865
/// references or more complex expressions.
853866
/// * `aggr_exprs` - Aggregate expressions, such as `SUM(a)` or `COUNT(1)`.
854867
///
855868
/// # Return
856869
///
857-
/// The return value is a triplet of the following items:
870+
/// The return value is a quadruplet of the following items:
858871
///
859872
/// * `plan` - A [LogicalPlan::Aggregate] plan for the newly created aggregate.
860873
/// * `select_exprs_post_aggr` - The projection expressions rewritten to reference columns from
861874
/// the aggregate
862875
/// * `having_expr_post_aggr` - The "having" expression rewritten to reference a column from
863876
/// the aggregate
877+
/// * `qualify_expr_post_aggr` - The "qualify" expression rewritten to reference a column from
878+
/// the aggregate
879+
#[allow(clippy::type_complexity)]
864880
fn aggregate(
865881
&self,
866882
input: &LogicalPlan,
867883
select_exprs: &[Expr],
868884
having_expr_opt: Option<&Expr>,
885+
qualify_expr_opt: Option<&Expr>,
869886
group_by_exprs: &[Expr],
870887
aggr_exprs: &[Expr],
871-
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>)> {
888+
) -> Result<(LogicalPlan, Vec<Expr>, Option<Expr>, Option<Expr>)> {
872889
// create the aggregate plan
873890
let options =
874891
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
@@ -932,7 +949,9 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
932949
check_columns_satisfy_exprs(
933950
&column_exprs_post_aggr,
934951
&select_exprs_post_aggr,
935-
CheckColumnsSatisfyExprsPurpose::ProjectionMustReferenceAggregate,
952+
CheckColumnsSatisfyExprsPurpose::Aggregate(
953+
CheckColumnsMustReferenceAggregatePurpose::Projection,
954+
),
936955
)?;
937956

938957
// Rewrite the HAVING expression to use the columns produced by the
@@ -944,15 +963,41 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
944963
check_columns_satisfy_exprs(
945964
&column_exprs_post_aggr,
946965
std::slice::from_ref(&having_expr_post_aggr),
947-
CheckColumnsSatisfyExprsPurpose::HavingMustReferenceAggregate,
966+
CheckColumnsSatisfyExprsPurpose::Aggregate(
967+
CheckColumnsMustReferenceAggregatePurpose::Having,
968+
),
948969
)?;
949970

950971
Some(having_expr_post_aggr)
951972
} else {
952973
None
953974
};
954975

955-
Ok((plan, select_exprs_post_aggr, having_expr_post_aggr))
976+
// Rewrite the QUALIFY expression to use the columns produced by the
977+
// aggregation.
978+
let qualify_expr_post_aggr = if let Some(qualify_expr) = qualify_expr_opt {
979+
let qualify_expr_post_aggr =
980+
rebase_expr(qualify_expr, &aggr_projection_exprs, input)?;
981+
982+
check_columns_satisfy_exprs(
983+
&column_exprs_post_aggr,
984+
std::slice::from_ref(&qualify_expr_post_aggr),
985+
CheckColumnsSatisfyExprsPurpose::Aggregate(
986+
CheckColumnsMustReferenceAggregatePurpose::Qualify,
987+
),
988+
)?;
989+
990+
Some(qualify_expr_post_aggr)
991+
} else {
992+
None
993+
};
994+
995+
Ok((
996+
plan,
997+
select_exprs_post_aggr,
998+
having_expr_post_aggr,
999+
qualify_expr_post_aggr,
1000+
))
9561001
}
9571002

9581003
// If the projection is done over a named window, that window

datafusion/sql/src/utils.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,30 @@ pub(crate) fn rebase_expr(
9292
.data()
9393
}
9494

95+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
96+
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
97+
Projection,
98+
Having,
99+
Qualify,
100+
}
101+
95102
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
96103
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
97-
ProjectionMustReferenceAggregate,
98-
HavingMustReferenceAggregate,
104+
Aggregate(CheckColumnsMustReferenceAggregatePurpose),
99105
}
100106

101107
impl CheckColumnsSatisfyExprsPurpose {
102108
fn message_prefix(&self) -> &'static str {
103109
match self {
104-
CheckColumnsSatisfyExprsPurpose::ProjectionMustReferenceAggregate => {
110+
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
105111
"Column in SELECT must be in GROUP BY or an aggregate function"
106112
}
107-
CheckColumnsSatisfyExprsPurpose::HavingMustReferenceAggregate => {
113+
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
108114
"Column in HAVING must be in GROUP BY or an aggregate function"
109115
}
116+
Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
117+
"Column in QUALIFY must be in GROUP BY or an aggregate function"
118+
}
110119
}
111120
}
112121

@@ -162,7 +171,7 @@ fn check_column_satisfies_expr(
162171
purpose.diagnostic_message(expr),
163172
expr.spans().and_then(|spans| spans.first()),
164173
)
165-
.with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregare function like ANY_VALUE({expr})"), None);
174+
.with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);
166175

167176
return plan_err!(
168177
"{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",

datafusion/sql/tests/cases/diagnostic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ fn test_missing_non_aggregate_in_group_by() -> Result<()> {
184184
let diag = do_query(query);
185185
assert_snapshot!(diag.message, @"'person.first_name' must appear in GROUP BY clause because it's not an aggregate expression");
186186
assert_eq!(diag.span, Some(spans["a"]));
187-
assert_snapshot!(diag.helps[0].message, @"Either add 'person.first_name' to GROUP BY clause, or use an aggregare function like ANY_VALUE(person.first_name)");
187+
assert_snapshot!(diag.helps[0].message, @"Either add 'person.first_name' to GROUP BY clause, or use an aggregate function like ANY_VALUE(person.first_name)");
188188
Ok(())
189189
}
190190

datafusion/sql/tests/sql_integration.rs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4202,6 +4202,67 @@ Projection: person.id, row_number() PARTITION BY [person.age] ORDER BY [person.i
42024202
);
42034203
}
42044204

4205+
#[test]
4206+
fn test_select_qualify_aggregate_reference() {
4207+
let sql = "
4208+
SELECT
4209+
person.id,
4210+
ROW_NUMBER() OVER (PARTITION BY person.id ORDER BY person.id) as rn
4211+
FROM person
4212+
GROUP BY
4213+
person.id
4214+
QUALIFY rn = 1 AND SUM(person.age) > 0";
4215+
let plan = logical_plan(sql).unwrap();
4216+
assert_snapshot!(
4217+
plan,
4218+
@r"
4219+
Projection: person.id, row_number() PARTITION BY [person.id] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn
4220+
Filter: row_number() PARTITION BY [person.id] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW = Int64(1) AND sum(person.age) > Int64(0)
4221+
WindowAggr: windowExpr=[[row_number() PARTITION BY [person.id] ORDER BY [person.id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
4222+
Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]]
4223+
TableScan: person
4224+
"
4225+
);
4226+
}
4227+
4228+
#[test]
4229+
fn test_select_qualify_aggregate_reference_within_window_function() {
4230+
let sql = "
4231+
SELECT
4232+
person.id
4233+
FROM person
4234+
GROUP BY
4235+
person.id
4236+
QUALIFY ROW_NUMBER() OVER (PARTITION BY person.id ORDER BY SUM(person.age) DESC) = 1";
4237+
let plan = logical_plan(sql).unwrap();
4238+
assert_snapshot!(
4239+
plan,
4240+
@r"
4241+
Projection: person.id
4242+
Filter: row_number() PARTITION BY [person.id] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW = Int64(1)
4243+
WindowAggr: windowExpr=[[row_number() PARTITION BY [person.id] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
4244+
Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]]
4245+
TableScan: person
4246+
"
4247+
);
4248+
}
4249+
4250+
#[test]
4251+
fn test_select_qualify_aggregate_invalid_column_reference() {
4252+
let sql = "
4253+
SELECT
4254+
person.id
4255+
FROM person
4256+
GROUP BY
4257+
person.id
4258+
QUALIFY ROW_NUMBER() OVER (PARTITION BY person.id ORDER BY person.age DESC) = 1";
4259+
let err = logical_plan(sql).unwrap_err();
4260+
assert_snapshot!(
4261+
err.strip_backtrace(),
4262+
@r#"Error during planning: Column in QUALIFY must be in GROUP BY or an aggregate function: While expanding wildcard, column "person.age" must appear in the GROUP BY clause or must be part of an aggregate function, currently only "person.id" appears in the SELECT clause satisfies this requirement"#
4263+
);
4264+
}
4265+
42054266
#[test]
42064267
fn test_select_qualify_without_window_function() {
42074268
let sql = "SELECT person.id FROM person QUALIFY person.id > 1";

0 commit comments

Comments
 (0)