Skip to content

Commit 64ae03e

Browse files
committed
perf: Make Sum use PrimitiveGroupsAccumulator
1 parent 07cdc38 commit 64ae03e

File tree

10 files changed

+1549
-133
lines changed

10 files changed

+1549
-133
lines changed

datafusion/src/cube_ext/joinagg.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ impl ExecutionPlan for CrossJoinAggExec {
245245
&AggregateMode::Full,
246246
self.group_expr.len(),
247247
)?;
248-
let mut accumulators = create_accumulation_state(&self.agg_expr)?;
248+
let mut accumulators: hash_aggregate::AccumulationState =
249+
create_accumulation_state(&self.agg_expr)?;
249250
for partition in 0..self.join.right.output_partitioning().partition_count() {
250251
let mut batches = self.join.right.execute(partition).await?;
251252
while let Some(right) = batches.next().await {
@@ -273,7 +274,7 @@ impl ExecutionPlan for CrossJoinAggExec {
273274
let out_schema = self.schema.clone();
274275
let r = hash_aggregate::create_batch_from_map(
275276
&AggregateMode::Full,
276-
&accumulators,
277+
accumulators,
277278
self.group_expr.len(),
278279
&out_schema,
279280
)?;

datafusion/src/physical_plan/aggregates.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ pub fn create_aggregate_expr(
144144
))
145145
}
146146
(AggregateFunction::Sum, false) => {
147-
Arc::new(expressions::Sum::new(arg, name, return_type))
147+
Arc::new(expressions::Sum::new(arg, name, return_type, &arg_types[0]))
148148
}
149149
(AggregateFunction::Sum, true) => {
150150
return Err(DataFusionError::NotImplemented(

datafusion/src/physical_plan/expressions/sum.rs

Lines changed: 168 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use std::sync::Arc;
2424
use crate::error::{DataFusionError, Result};
2525
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
2626
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
27+
use crate::physical_plan::groups_accumulator_prim_op::PrimitiveGroupsAccumulator;
2728
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
2829
use crate::scalar::ScalarValue;
2930
use arrow::compute;
@@ -49,6 +50,7 @@ use smallvec::SmallVec;
4950
pub struct Sum {
5051
name: String,
5152
data_type: DataType,
53+
input_data_type: DataType,
5254
expr: Arc<dyn PhysicalExpr>,
5355
nullable: bool,
5456
}
@@ -80,11 +82,16 @@ impl Sum {
8082
expr: Arc<dyn PhysicalExpr>,
8183
name: impl Into<String>,
8284
data_type: DataType,
85+
input_data_type: &DataType,
8386
) -> Self {
87+
// Note: data_type = sum_return_type(input_data_type) in the actual caller, so we don't
88+
// really need both params. But we keep the four params to break symmetry with other
89+
// accumulators and any code that might use 3 params, such as the generic_test_op macro.
8490
Self {
8591
name: name.into(),
8692
expr,
8793
data_type,
94+
input_data_type: input_data_type.clone(),
8895
nullable: true,
8996
}
9097
}
@@ -127,12 +134,147 @@ impl AggregateExpr for Sum {
127134
fn create_groups_accumulator(
128135
&self,
129136
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
130-
let data_type = self.data_type.clone();
131-
Ok(Some(Box::new(
132-
GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(move || {
133-
SumAccumulator::try_new(&data_type)
134-
}),
135-
)))
137+
use arrow::datatypes::ArrowPrimitiveType;
138+
139+
macro_rules! make_accumulator {
140+
($T:ty, $U:ty) => {
141+
Box::new(PrimitiveGroupsAccumulator::<$T, $U, _, _>::new(
142+
&<$T as ArrowPrimitiveType>::DATA_TYPE,
143+
|x: &mut <$T as ArrowPrimitiveType>::Native,
144+
y: <$U as ArrowPrimitiveType>::Native| {
145+
*x = *x + (y as <$T as ArrowPrimitiveType>::Native);
146+
},
147+
|x: &mut <$T as ArrowPrimitiveType>::Native,
148+
y: <$T as ArrowPrimitiveType>::Native| {
149+
*x = *x + y;
150+
},
151+
))
152+
};
153+
}
154+
155+
// Note that upstream uses x.add_wrapping(y) for the sum functions -- but here we just mimic
156+
// the current datafusion Sum accumulator implementation using native +. (That native +
157+
// specifically is the one in the expressions *x = *x + ... above.)
158+
Ok(Some(match (&self.data_type, &self.input_data_type) {
159+
(DataType::Int64, DataType::Int64) => make_accumulator!(
160+
arrow::datatypes::Int64Type,
161+
arrow::datatypes::Int64Type
162+
),
163+
(DataType::Int64, DataType::Int32) => make_accumulator!(
164+
arrow::datatypes::Int64Type,
165+
arrow::datatypes::Int32Type
166+
),
167+
(DataType::Int64, DataType::Int16) => make_accumulator!(
168+
arrow::datatypes::Int64Type,
169+
arrow::datatypes::Int16Type
170+
),
171+
(DataType::Int64, DataType::Int8) => {
172+
make_accumulator!(arrow::datatypes::Int64Type, arrow::datatypes::Int8Type)
173+
}
174+
175+
(DataType::Int96, DataType::Int96) => make_accumulator!(
176+
arrow::datatypes::Int96Type,
177+
arrow::datatypes::Int96Type
178+
),
179+
180+
(DataType::Int64Decimal(0), DataType::Int64Decimal(0)) => make_accumulator!(
181+
arrow::datatypes::Int64Decimal0Type,
182+
arrow::datatypes::Int64Decimal0Type
183+
),
184+
(DataType::Int64Decimal(1), DataType::Int64Decimal(1)) => make_accumulator!(
185+
arrow::datatypes::Int64Decimal1Type,
186+
arrow::datatypes::Int64Decimal1Type
187+
),
188+
(DataType::Int64Decimal(2), DataType::Int64Decimal(2)) => make_accumulator!(
189+
arrow::datatypes::Int64Decimal2Type,
190+
arrow::datatypes::Int64Decimal2Type
191+
),
192+
(DataType::Int64Decimal(3), DataType::Int64Decimal(3)) => make_accumulator!(
193+
arrow::datatypes::Int64Decimal3Type,
194+
arrow::datatypes::Int64Decimal3Type
195+
),
196+
(DataType::Int64Decimal(4), DataType::Int64Decimal(4)) => make_accumulator!(
197+
arrow::datatypes::Int64Decimal4Type,
198+
arrow::datatypes::Int64Decimal4Type
199+
),
200+
(DataType::Int64Decimal(5), DataType::Int64Decimal(5)) => make_accumulator!(
201+
arrow::datatypes::Int64Decimal5Type,
202+
arrow::datatypes::Int64Decimal5Type
203+
),
204+
(DataType::Int64Decimal(10), DataType::Int64Decimal(10)) => {
205+
make_accumulator!(
206+
arrow::datatypes::Int64Decimal10Type,
207+
arrow::datatypes::Int64Decimal10Type
208+
)
209+
}
210+
211+
(DataType::Int96Decimal(0), DataType::Int96Decimal(0)) => make_accumulator!(
212+
arrow::datatypes::Int96Decimal0Type,
213+
arrow::datatypes::Int96Decimal0Type
214+
),
215+
(DataType::Int96Decimal(1), DataType::Int96Decimal(1)) => make_accumulator!(
216+
arrow::datatypes::Int96Decimal1Type,
217+
arrow::datatypes::Int96Decimal1Type
218+
),
219+
(DataType::Int96Decimal(2), DataType::Int96Decimal(2)) => make_accumulator!(
220+
arrow::datatypes::Int96Decimal2Type,
221+
arrow::datatypes::Int96Decimal2Type
222+
),
223+
(DataType::Int96Decimal(3), DataType::Int96Decimal(3)) => make_accumulator!(
224+
arrow::datatypes::Int96Decimal3Type,
225+
arrow::datatypes::Int96Decimal3Type
226+
),
227+
(DataType::Int96Decimal(4), DataType::Int96Decimal(4)) => make_accumulator!(
228+
arrow::datatypes::Int96Decimal4Type,
229+
arrow::datatypes::Int96Decimal4Type
230+
),
231+
(DataType::Int96Decimal(5), DataType::Int96Decimal(5)) => make_accumulator!(
232+
arrow::datatypes::Int96Decimal5Type,
233+
arrow::datatypes::Int96Decimal5Type
234+
),
235+
(DataType::Int96Decimal(10), DataType::Int96Decimal(10)) => {
236+
make_accumulator!(
237+
arrow::datatypes::Int96Decimal10Type,
238+
arrow::datatypes::Int96Decimal10Type
239+
)
240+
}
241+
242+
(DataType::UInt64, DataType::UInt64) => make_accumulator!(
243+
arrow::datatypes::UInt64Type,
244+
arrow::datatypes::UInt64Type
245+
),
246+
(DataType::UInt64, DataType::UInt32) => make_accumulator!(
247+
arrow::datatypes::UInt64Type,
248+
arrow::datatypes::UInt32Type
249+
),
250+
(DataType::UInt64, DataType::UInt16) => make_accumulator!(
251+
arrow::datatypes::UInt64Type,
252+
arrow::datatypes::UInt16Type
253+
),
254+
(DataType::UInt64, DataType::UInt8) => make_accumulator!(
255+
arrow::datatypes::UInt64Type,
256+
arrow::datatypes::UInt8Type
257+
),
258+
259+
(DataType::Float32, DataType::Float32) => make_accumulator!(
260+
arrow::datatypes::Float32Type,
261+
arrow::datatypes::Float32Type
262+
),
263+
(DataType::Float64, DataType::Float64) => make_accumulator!(
264+
arrow::datatypes::Float64Type,
265+
arrow::datatypes::Float64Type
266+
),
267+
268+
_ => {
269+
// This case should never be reached because we've handled all sum_return_type
270+
// arg_type values. Nonetheless:
271+
let data_type = self.data_type.clone();
272+
273+
Box::new(GroupsAccumulatorFlatAdapter::<SumAccumulator>::new(
274+
move || SumAccumulator::try_new(&data_type),
275+
))
276+
}
277+
}))
136278
}
137279

138280
fn name(&self) -> &str {
@@ -416,13 +558,27 @@ mod tests {
416558
use arrow::datatypes::*;
417559
use arrow::record_batch::RecordBatch;
418560

561+
// A wrapper to make Sum::new, which now has an input_data_type argument, work with
562+
// generic_test_op!.
563+
struct SumTestStandin;
564+
impl SumTestStandin {
565+
fn new(
566+
expr: Arc<dyn PhysicalExpr>,
567+
name: impl Into<String>,
568+
data_type: DataType,
569+
) -> Sum {
570+
Sum::new(expr, name, data_type.clone(), &data_type)
571+
}
572+
}
573+
419574
#[test]
420575
fn sum_i32() -> Result<()> {
421576
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
577+
422578
generic_test_op!(
423579
a,
424580
DataType::Int32,
425-
Sum,
581+
SumTestStandin,
426582
ScalarValue::from(15i64),
427583
DataType::Int64
428584
)
@@ -440,7 +596,7 @@ mod tests {
440596
generic_test_op!(
441597
a,
442598
DataType::Int32,
443-
Sum,
599+
SumTestStandin,
444600
ScalarValue::from(13i64),
445601
DataType::Int64
446602
)
@@ -452,7 +608,7 @@ mod tests {
452608
generic_test_op!(
453609
a,
454610
DataType::Int32,
455-
Sum,
611+
SumTestStandin,
456612
ScalarValue::Int64(None),
457613
DataType::Int64
458614
)
@@ -465,7 +621,7 @@ mod tests {
465621
generic_test_op!(
466622
a,
467623
DataType::UInt32,
468-
Sum,
624+
SumTestStandin,
469625
ScalarValue::from(15u64),
470626
DataType::UInt64
471627
)
@@ -478,7 +634,7 @@ mod tests {
478634
generic_test_op!(
479635
a,
480636
DataType::Float32,
481-
Sum,
637+
SumTestStandin,
482638
ScalarValue::from(15_f32),
483639
DataType::Float32
484640
)
@@ -491,7 +647,7 @@ mod tests {
491647
generic_test_op!(
492648
a,
493649
DataType::Float64,
494-
Sum,
650+
SumTestStandin,
495651
ScalarValue::from(15_f64),
496652
DataType::Float64
497653
)

datafusion/src/physical_plan/groups_accumulator.rs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
//! Vectorized [`GroupsAccumulator`]
1919
2020
use crate::error::{DataFusionError, Result};
21-
use crate::scalar::ScalarValue;
2221
use arrow::array::{ArrayRef, BooleanArray};
23-
use smallvec::SmallVec;
2422

2523
/// From upstream: This replaces a datafusion_common::{not_impl_err} import.
2624
macro_rules! not_impl_err {
@@ -194,10 +192,6 @@ pub trait GroupsAccumulator: Send {
194192
/// `n`. See [`EmitTo::First`] for more details.
195193
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef>;
196194

197-
// TODO: Remove this?
198-
/// evaluate for a particular group index.
199-
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue>;
200-
201195
/// Returns the intermediate aggregate state for this accumulator,
202196
/// used for multi-phase grouping, resetting its internal state.
203197
///
@@ -216,10 +210,6 @@ pub trait GroupsAccumulator: Send {
216210
/// [`Accumulator::state`]: crate::accumulator::Accumulator::state
217211
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
218212

219-
// TODO: Remove this?
220-
/// Looks at the state for a particular group index.
221-
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>>;
222-
223213
/// Merges intermediate state (the output from [`Self::state`])
224214
/// into this accumulator's current state.
225215
///

datafusion/src/physical_plan/groups_accumulator_adapter.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ use arrow::{
3333
compute,
3434
datatypes::UInt32Type,
3535
};
36-
use smallvec::SmallVec;
3736

3837
/// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`]
3938
///
@@ -345,10 +344,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
345344
result
346345
}
347346

348-
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue> {
349-
self.states[group_index].accumulator.evaluate()
350-
}
351-
352347
// filtered_null_mask(opt_filter, &values);
353348
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
354349
let vec_size_pre = self.states.allocated_size();
@@ -385,10 +380,6 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
385380
Ok(arrays)
386381
}
387382

388-
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>> {
389-
self.states[group_index].accumulator.state()
390-
}
391-
392383
fn merge_batch(
393384
&mut self,
394385
values: &[ArrayRef],

datafusion/src/physical_plan/groups_accumulator_flat_adapter.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -387,10 +387,6 @@ impl<AccumulatorType: Accumulator> GroupsAccumulator
387387
result
388388
}
389389

390-
fn peek_evaluate(&self, group_index: usize) -> Result<ScalarValue> {
391-
self.accumulators[group_index].evaluate()
392-
}
393-
394390
// filtered_null_mask(opt_filter, &values);
395391
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
396392
let vec_size_pre = self.accumulators.allocated_size();
@@ -428,10 +424,6 @@ impl<AccumulatorType: Accumulator> GroupsAccumulator
428424
Ok(arrays)
429425
}
430426

431-
fn peek_state(&self, group_index: usize) -> Result<SmallVec<[ScalarValue; 2]>> {
432-
self.accumulators[group_index].state()
433-
}
434-
435427
fn merge_batch(
436428
&mut self,
437429
values: &[ArrayRef],

0 commit comments

Comments
 (0)