Skip to content

Commit a8f045a

Browse files
authored
feat: Add GroupsAccumulator and GroupsAccumulatorFlatAdapter, extending API (#174)
This is based on the upstream GroupsAccumulator and GroupsAccumulatorAdapter, but extends the API so that the existing hash aggregation works with it. At this time we do not really use the upstream interface itself. We still use the basic Accumulator for types that do not implement GroupsAccumulator, and the hash aggregation code handles this case poorly.
1 parent b3acc9f commit a8f045a

File tree

8 files changed

+1780
-67
lines changed

8 files changed

+1780
-67
lines changed

datafusion/src/cube_ext/joinagg.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use crate::execution::context::{ExecutionContextState, ExecutionProps};
2525
use crate::logical_plan::{DFSchemaRef, Expr, LogicalPlan, UserDefinedLogicalNode};
2626
use crate::optimizer::optimizer::OptimizerRule;
2727
use crate::optimizer::utils::from_plan;
28-
use crate::physical_plan::hash_aggregate::{Accumulators, AggregateMode};
28+
use crate::physical_plan::hash_aggregate::{create_accumulation_state, AggregateMode};
2929
use crate::physical_plan::planner::{physical_name, ExtensionPlanner};
3030
use crate::physical_plan::{hash_aggregate, PhysicalPlanner};
3131
use crate::physical_plan::{
@@ -245,7 +245,7 @@ impl ExecutionPlan for CrossJoinAggExec {
245245
&AggregateMode::Full,
246246
self.group_expr.len(),
247247
)?;
248-
let mut accumulators = Accumulators::new();
248+
let mut accumulators = create_accumulation_state(&self.agg_expr)?;
249249
for partition in 0..self.join.right.output_partitioning().partition_count() {
250250
let mut batches = self.join.right.execute(partition).await?;
251251
while let Some(right) = batches.next().await {

datafusion/src/physical_plan/expressions/average.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ use std::convert::TryFrom;
2222
use std::sync::Arc;
2323

2424
use crate::error::{DataFusionError, Result};
25+
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
26+
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
2527
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
2628
use crate::scalar::ScalarValue;
2729
use arrow::compute;
@@ -112,6 +114,23 @@ impl AggregateExpr for Avg {
112114
)?))
113115
}
114116

117+
fn uses_groups_accumulator(&self) -> bool {
118+
return true;
119+
}
120+
121+
/// Returns the groups accumulator used to accumulate values from the expression. If this returns None,
122+
/// create_accumulator must be used.
123+
fn create_groups_accumulator(
124+
&self,
125+
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
126+
Ok(Some(Box::new(
127+
GroupsAccumulatorFlatAdapter::<AvgAccumulator>::new(|| {
128+
// avg is f64 (as in create_accumulator)
129+
AvgAccumulator::try_new(&DataType::Float64)
130+
}),
131+
)))
132+
}
133+
115134
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
116135
vec![self.expr.clone()]
117136
}

datafusion/src/physical_plan/expressions/sum.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ use std::convert::TryFrom;
2222
use std::sync::Arc;
2323

2424
use crate::error::{DataFusionError, Result};
25+
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
26+
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
2527
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
2628
use crate::scalar::ScalarValue;
2729
use arrow::compute;
@@ -42,7 +44,7 @@ use super::format_state_name;
4244
use smallvec::smallvec;
4345
use smallvec::SmallVec;
4446

45-
// SUM aggregate expression
47+
/// SUM aggregate expression
4648
#[derive(Debug)]
4749
pub struct Sum {
4850
name: String,
@@ -118,6 +120,23 @@ impl AggregateExpr for Sum {
118120
Ok(Box::new(SumAccumulator::try_new(&self.data_type)?))
119121
}
120122

123+
fn uses_groups_accumulator(&self) -> bool {
124+
return true;
125+
}
126+
127+
/// Returns the groups accumulator used to accumulate values from the expression. If this returns None,
128+
/// create_accumulator must be used.
129+
fn create_groups_accumulator(
130+
&self,
131+
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
132+
let data_type = self.data_type.clone();
133+
Ok(Some(Box::new(
134+
GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(move || {
135+
SumAccumulator::try_new(&data_type)
136+
}),
137+
)))
138+
}
139+
121140
fn name(&self) -> &str {
122141
&self.name
123142
}

0 commit comments

Comments
 (0)