Skip to content

Commit 51d6d74

Browse files
committed
perf: min/max groups accumulator
1 parent 64ae03e commit 51d6d74

File tree

1 file changed

+173
-13
lines changed
  • datafusion/src/physical_plan/expressions

1 file changed

+173
-13
lines changed

datafusion/src/physical_plan/expressions/min_max.rs

Lines changed: 173 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ use std::sync::Arc;
2424
use crate::error::{DataFusionError, Result};
2525
use crate::physical_plan::groups_accumulator::GroupsAccumulator;
2626
use crate::physical_plan::groups_accumulator_flat_adapter::GroupsAccumulatorFlatAdapter;
27+
use crate::physical_plan::groups_accumulator_prim_op::PrimitiveGroupsAccumulator;
2728
use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
2829
use crate::scalar::ScalarValue;
2930
use arrow::compute;
30-
use arrow::datatypes::{DataType, TimeUnit};
31+
use arrow::datatypes::{ArrowPrimitiveType, DataType, TimeUnit};
3132
use arrow::{
3233
array::{
3334
ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
@@ -108,12 +109,92 @@ impl AggregateExpr for Max {
108109
fn create_groups_accumulator(
109110
&self,
110111
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
111-
let data_type = self.data_type.clone();
112-
Ok(Some(Box::new(
113-
GroupsAccumulatorFlatAdapter::<MaxAccumulator>::new(move || {
114-
MaxAccumulator::try_new(&data_type)
115-
}),
116-
)))
112+
use arrow::datatypes::ArrowPrimitiveType;
113+
114+
macro_rules! make_max_accumulator {
115+
($T:ty) => {
116+
Box::new(
117+
PrimitiveGroupsAccumulator::<$T, $T, _, _>::new(
118+
&<$T as ArrowPrimitiveType>::DATA_TYPE,
119+
|x: &mut <$T as ArrowPrimitiveType>::Native,
120+
y: <$T as ArrowPrimitiveType>::Native| {
121+
*x = (*x).max(y);
122+
},
123+
|x: &mut <$T as ArrowPrimitiveType>::Native,
124+
y: <$T as ArrowPrimitiveType>::Native| {
125+
*x = (*x).max(y);
126+
},
127+
)
128+
.with_starting_value(<$T as ArrowPrimitiveType>::Native::MIN),
129+
)
130+
};
131+
}
132+
let acc: Box<dyn GroupsAccumulator> = match &self.data_type {
133+
DataType::Float64 => make_max_accumulator!(arrow::datatypes::Float64Type),
134+
DataType::Float32 => make_max_accumulator!(arrow::datatypes::Float32Type),
135+
DataType::Int64 => make_max_accumulator!(arrow::datatypes::Int64Type),
136+
DataType::Int96 => make_max_accumulator!(arrow::datatypes::Int96Type),
137+
DataType::Int64Decimal(0) => {
138+
make_max_accumulator!(arrow::datatypes::Int64Decimal0Type)
139+
}
140+
DataType::Int64Decimal(1) => {
141+
make_max_accumulator!(arrow::datatypes::Int64Decimal1Type)
142+
}
143+
DataType::Int64Decimal(2) => {
144+
make_max_accumulator!(arrow::datatypes::Int64Decimal2Type)
145+
}
146+
DataType::Int64Decimal(3) => {
147+
make_max_accumulator!(arrow::datatypes::Int64Decimal3Type)
148+
}
149+
DataType::Int64Decimal(4) => {
150+
make_max_accumulator!(arrow::datatypes::Int64Decimal4Type)
151+
}
152+
DataType::Int64Decimal(5) => {
153+
make_max_accumulator!(arrow::datatypes::Int64Decimal5Type)
154+
}
155+
DataType::Int64Decimal(10) => {
156+
make_max_accumulator!(arrow::datatypes::Int64Decimal10Type)
157+
}
158+
DataType::Int96Decimal(0) => {
159+
make_max_accumulator!(arrow::datatypes::Int96Decimal0Type)
160+
}
161+
DataType::Int96Decimal(1) => {
162+
make_max_accumulator!(arrow::datatypes::Int96Decimal1Type)
163+
}
164+
DataType::Int96Decimal(2) => {
165+
make_max_accumulator!(arrow::datatypes::Int96Decimal2Type)
166+
}
167+
DataType::Int96Decimal(3) => {
168+
make_max_accumulator!(arrow::datatypes::Int96Decimal3Type)
169+
}
170+
DataType::Int96Decimal(4) => {
171+
make_max_accumulator!(arrow::datatypes::Int96Decimal4Type)
172+
}
173+
DataType::Int96Decimal(5) => {
174+
make_max_accumulator!(arrow::datatypes::Int96Decimal5Type)
175+
}
176+
DataType::Int96Decimal(10) => {
177+
make_max_accumulator!(arrow::datatypes::Int96Decimal10Type)
178+
}
179+
DataType::Int32 => make_max_accumulator!(arrow::datatypes::Int32Type),
180+
DataType::Int16 => make_max_accumulator!(arrow::datatypes::Int16Type),
181+
DataType::Int8 => make_max_accumulator!(arrow::datatypes::Int8Type),
182+
DataType::UInt64 => make_max_accumulator!(arrow::datatypes::UInt64Type),
183+
DataType::UInt32 => make_max_accumulator!(arrow::datatypes::UInt32Type),
184+
DataType::UInt16 => make_max_accumulator!(arrow::datatypes::UInt16Type),
185+
DataType::UInt8 => make_max_accumulator!(arrow::datatypes::UInt8Type),
186+
_ => {
187+
// Not all types (strings) can use primitive accumulators. And strings use
188+
// max_string as the $OP in typed_min_match_batch.
189+
190+
// Timestamps presently take this branch.
191+
let data_type = self.data_type.clone();
192+
Box::new(GroupsAccumulatorFlatAdapter::<MaxAccumulator>::new(
193+
move || MaxAccumulator::try_new(&data_type),
194+
))
195+
}
196+
};
197+
Ok(Some(acc))
117198
}
118199

119200
fn name(&self) -> &str {
@@ -547,12 +628,91 @@ impl AggregateExpr for Min {
547628
fn create_groups_accumulator(
548629
&self,
549630
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
550-
let data_type = self.data_type.clone();
551-
Ok(Some(Box::new(
552-
GroupsAccumulatorFlatAdapter::<MinAccumulator>::new(move || {
553-
MinAccumulator::try_new(&data_type)
554-
}),
555-
)))
631+
macro_rules! make_min_accumulator {
632+
($T:ty) => {
633+
Box::new(
634+
PrimitiveGroupsAccumulator::<$T, $T, _, _>::new(
635+
&<$T as ArrowPrimitiveType>::DATA_TYPE,
636+
|x: &mut <$T as ArrowPrimitiveType>::Native,
637+
y: <$T as ArrowPrimitiveType>::Native| {
638+
*x = (*x).min(y);
639+
},
640+
|x: &mut <$T as ArrowPrimitiveType>::Native,
641+
y: <$T as ArrowPrimitiveType>::Native| {
642+
*x = (*x).min(y);
643+
},
644+
)
645+
.with_starting_value(<$T as ArrowPrimitiveType>::Native::MAX),
646+
)
647+
};
648+
}
649+
650+
let acc: Box<dyn GroupsAccumulator> = match &self.data_type {
651+
DataType::Float64 => make_min_accumulator!(arrow::datatypes::Float64Type),
652+
DataType::Float32 => make_min_accumulator!(arrow::datatypes::Float32Type),
653+
DataType::Int64 => make_min_accumulator!(arrow::datatypes::Int64Type),
654+
DataType::Int96 => make_min_accumulator!(arrow::datatypes::Int96Type),
655+
DataType::Int64Decimal(0) => {
656+
make_min_accumulator!(arrow::datatypes::Int64Decimal0Type)
657+
}
658+
DataType::Int64Decimal(1) => {
659+
make_min_accumulator!(arrow::datatypes::Int64Decimal1Type)
660+
}
661+
DataType::Int64Decimal(2) => {
662+
make_min_accumulator!(arrow::datatypes::Int64Decimal2Type)
663+
}
664+
DataType::Int64Decimal(3) => {
665+
make_min_accumulator!(arrow::datatypes::Int64Decimal3Type)
666+
}
667+
DataType::Int64Decimal(4) => {
668+
make_min_accumulator!(arrow::datatypes::Int64Decimal4Type)
669+
}
670+
DataType::Int64Decimal(5) => {
671+
make_min_accumulator!(arrow::datatypes::Int64Decimal5Type)
672+
}
673+
DataType::Int64Decimal(10) => {
674+
make_min_accumulator!(arrow::datatypes::Int64Decimal10Type)
675+
}
676+
DataType::Int96Decimal(0) => {
677+
make_min_accumulator!(arrow::datatypes::Int96Decimal0Type)
678+
}
679+
DataType::Int96Decimal(1) => {
680+
make_min_accumulator!(arrow::datatypes::Int96Decimal1Type)
681+
}
682+
DataType::Int96Decimal(2) => {
683+
make_min_accumulator!(arrow::datatypes::Int96Decimal2Type)
684+
}
685+
DataType::Int96Decimal(3) => {
686+
make_min_accumulator!(arrow::datatypes::Int96Decimal3Type)
687+
}
688+
DataType::Int96Decimal(4) => {
689+
make_min_accumulator!(arrow::datatypes::Int96Decimal4Type)
690+
}
691+
DataType::Int96Decimal(5) => {
692+
make_min_accumulator!(arrow::datatypes::Int96Decimal5Type)
693+
}
694+
DataType::Int96Decimal(10) => {
695+
make_min_accumulator!(arrow::datatypes::Int96Decimal10Type)
696+
}
697+
DataType::Int32 => make_min_accumulator!(arrow::datatypes::Int32Type),
698+
DataType::Int16 => make_min_accumulator!(arrow::datatypes::Int16Type),
699+
DataType::Int8 => make_min_accumulator!(arrow::datatypes::Int8Type),
700+
DataType::UInt64 => make_min_accumulator!(arrow::datatypes::UInt64Type),
701+
DataType::UInt32 => make_min_accumulator!(arrow::datatypes::UInt32Type),
702+
DataType::UInt16 => make_min_accumulator!(arrow::datatypes::UInt16Type),
703+
DataType::UInt8 => make_min_accumulator!(arrow::datatypes::UInt8Type),
704+
_ => {
705+
// Not all types (strings) can use primitive accumulators. And strings use
706+
// min_string as the $OP in typed_min_match_batch.
707+
708+
// Timestamps presently take this branch.
709+
let data_type = self.data_type.clone();
710+
Box::new(GroupsAccumulatorFlatAdapter::<MinAccumulator>::new(
711+
move || MinAccumulator::try_new(&data_type),
712+
))
713+
}
714+
};
715+
Ok(Some(acc))
556716
}
557717

558718
fn name(&self) -> &str {

0 commit comments

Comments
 (0)