Skip to content

Commit ca2bc8b

Browse files
committed
fix: Add non-missing aggregates to projection, update Union schema
Signed-off-by: Alex Qyoun-ae <[email protected]>
1 parent 46de680 commit ca2bc8b

File tree

1 file changed

+110
-3
lines changed

1 file changed

+110
-3
lines changed

datafusion/core/src/logical_plan/builder.rs

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -698,11 +698,11 @@ impl LogicalPlanBuilder {
698698

699699
let mut missing_exprs = Vec::with_capacity(missing_aggr_exprs.len());
700700
for missing_aggr_expr in missing_aggr_exprs {
701+
let expr_name = missing_aggr_expr.name(input_schema)?;
702+
alias_map.insert(expr_name.clone(), expr_name);
701703
if aggr_expr.contains(missing_aggr_expr) {
702704
continue;
703705
}
704-
let expr_name = missing_aggr_expr.name(input_schema)?;
705-
alias_map.insert(expr_name.clone(), expr_name);
706706
missing_exprs.push(missing_aggr_expr.clone());
707707
}
708708

@@ -725,6 +725,31 @@ impl LogicalPlanBuilder {
725725
schema: Arc::new(new_schema),
726726
}))
727727
}
728+
LogicalPlan::Union(Union {
729+
inputs,
730+
schema: _,
731+
alias,
732+
}) => {
733+
let inputs = inputs
734+
.into_iter()
735+
.map(|input_plan| {
736+
self.add_missing_aggr_exprs(
737+
input_plan,
738+
missing_aggr_exprs,
739+
alias_map,
740+
)
741+
})
742+
.collect::<Result<Vec<_>>>()?;
743+
let Some(first_input) = inputs.first() else {
744+
return Err(DataFusionError::Internal("Inputs in union are empty".to_string()));
745+
};
746+
let schema = Arc::clone(first_input.schema());
747+
Ok(LogicalPlan::Union(Union {
748+
inputs,
749+
schema,
750+
alias,
751+
}))
752+
}
728753
_ => {
729754
let new_inputs = curr_plan
730755
.inputs()
@@ -1482,7 +1507,7 @@ pub(crate) fn table_udfs(plan: LogicalPlan, udtf_expr: Vec<Expr>) -> Result<Logi
14821507

14831508
#[cfg(test)]
14841509
mod tests {
1485-
use arrow::datatypes::{DataType, Field};
1510+
use arrow::datatypes::{DataType, Field, TimeUnit};
14861511

14871512
use crate::logical_plan::StringifiedPlan;
14881513

@@ -1725,6 +1750,76 @@ mod tests {
17251750
}
17261751
}
17271752

1753+
#[test]
1754+
fn plan_builder_order_by_missing_aggr() -> Result<()> {
1755+
let builder = LogicalPlanBuilder::scan_empty(Some("Ecom"), &ecom_schema(), None)?
1756+
.filter(col("Ecom.status").eq(lit("completed")))?;
1757+
1758+
let first_plan = builder
1759+
.aggregate(
1760+
[col("Ecom.created"), col("Ecom.status")],
1761+
[sum(col("Ecom.sumPrice"))],
1762+
)?
1763+
.filter(col("SUM(Ecom.sumPrice)").is_not_null())?
1764+
.project([
1765+
col("Ecom.created").alias("Ecom[created]"),
1766+
col("Ecom.status").alias("Ecom[status]"),
1767+
lit(false).alias("[IsGrandTotalRowTotal]"),
1768+
col("SUM(Ecom.sumPrice)").alias("[count]"),
1769+
])?
1770+
.build()?;
1771+
1772+
let second_plan = builder
1773+
.aggregate(Vec::<Expr>::new(), [sum(col("Ecom.sumPrice"))])?
1774+
.filter(col("SUM(Ecom.sumPrice)").is_not_null())?
1775+
.project([
1776+
Expr::Literal(ScalarValue::Null).alias("Ecom[created]"),
1777+
Expr::Literal(ScalarValue::Null).alias("Ecom[status]"),
1778+
lit(true).alias("[IsGrandTotalRowTotal]"),
1779+
col("SUM(Ecom.sumPrice)").alias("[count]"),
1780+
])?
1781+
.build()?;
1782+
1783+
let plan_before_sort = LogicalPlanBuilder::from(first_plan)
1784+
.union(second_plan)?
1785+
.sort([
1786+
col("[IsGrandTotalRowTotal]").sort(false, true),
1787+
col("Ecom[created]").sort(true, false),
1788+
col("Ecom[status]").sort(true, false),
1789+
])?
1790+
.limit(None, Some(502))?
1791+
.build()?;
1792+
1793+
let plan_with_sort = LogicalPlanBuilder::from(plan_before_sort)
1794+
.sort([
1795+
col("[IsGrandTotalRowTotal]").sort(false, true),
1796+
sum(col("Ecom.sumPrice")).sort(true, false),
1797+
col("Ecom[status]").sort(true, false),
1798+
])?
1799+
.build()?;
1800+
1801+
let expected = "\
1802+
Projection: #Ecom[created], #Ecom[status], #[IsGrandTotalRowTotal], #[count]\
1803+
\n Sort: #[IsGrandTotalRowTotal] DESC NULLS FIRST, #SUM(Ecom.sumPrice) ASC NULLS LAST, #Ecom[status] ASC NULLS LAST\
1804+
\n Limit: skip=None, fetch=502\
1805+
\n Sort: #[IsGrandTotalRowTotal] DESC NULLS FIRST, #Ecom[created] ASC NULLS LAST, #Ecom[status] ASC NULLS LAST\
1806+
\n Union\
1807+
\n Projection: #Ecom.created AS Ecom[created], #Ecom.status AS Ecom[status], Boolean(false) AS [IsGrandTotalRowTotal], #SUM(Ecom.sumPrice) AS [count], #SUM(Ecom.sumPrice)\
1808+
\n Filter: #SUM(Ecom.sumPrice) IS NOT NULL\
1809+
\n Aggregate: groupBy=[[#Ecom.created, #Ecom.status]], aggr=[[SUM(#Ecom.sumPrice)]]\
1810+
\n Filter: #Ecom.status = Utf8(\"completed\")\
1811+
\n TableScan: Ecom projection=None\
1812+
\n Projection: CAST(NULL AS Timestamp(Nanosecond, None)) AS Ecom[created], CAST(NULL AS Utf8) AS Ecom[status], Boolean(true) AS [IsGrandTotalRowTotal], #SUM(Ecom.sumPrice) AS [count], #SUM(Ecom.sumPrice)\
1813+
\n Filter: #SUM(Ecom.sumPrice) IS NOT NULL\
1814+
\n Aggregate: groupBy=[[]], aggr=[[SUM(#Ecom.sumPrice)]]\
1815+
\n Filter: #Ecom.status = Utf8(\"completed\")\
1816+
\n TableScan: Ecom projection=None";
1817+
1818+
assert_eq!(expected, format!("{:?}", plan_with_sort));
1819+
1820+
Ok(())
1821+
}
1822+
17281823
fn employee_schema() -> Schema {
17291824
Schema::new(vec![
17301825
Field::new("id", DataType::Int32, false),
@@ -1735,6 +1830,18 @@ mod tests {
17351830
])
17361831
}
17371832

1833+
fn ecom_schema() -> Schema {
1834+
Schema::new(vec![
1835+
Field::new(
1836+
"created",
1837+
DataType::Timestamp(TimeUnit::Nanosecond, None),
1838+
true,
1839+
),
1840+
Field::new("status", DataType::Utf8, true),
1841+
Field::new("sumPrice", DataType::Float64, true),
1842+
])
1843+
}
1844+
17381845
#[test]
17391846
fn stringified_plan() {
17401847
let stringified_plan =

0 commit comments

Comments
 (0)