Skip to content

Commit d1d1836

Browse files
committed
fix: Fix f64/f32 group min/max accumulator in case of +/-infinity
1 parent 7940cf9 commit d1d1836

File tree

3 files changed

+124
-8
lines changed

3 files changed

+124
-8
lines changed

datafusion/src/physical_plan/expressions/count.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
//! Defines physical expressions that can be evaluated at runtime during query execution
1919
2020
use std::any::Any;
21-
use std::sync::Arc;
2221
use std::mem::size_of;
22+
use std::sync::Arc;
2323

2424
use crate::error::{DataFusionError, Result};
2525
use crate::physical_plan::groups_accumulator::{EmitTo, GroupsAccumulator};

datafusion/src/physical_plan/expressions/min_max.rs

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
//! Defines physical expressions that can be evaluated at runtime during query execution
1919
20+
use core::{f32, f64};
2021
use std::any::Any;
2122
use std::convert::TryFrom;
2223
use std::sync::Arc;
@@ -111,6 +112,9 @@ impl AggregateExpr for Max {
111112
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
112113
macro_rules! make_max_accumulator {
113114
($T:ty) => {
115+
make_max_accumulator!($T, <$T as ArrowPrimitiveType>::Native::MIN)
116+
};
117+
($T:ty, $STARTING_VALUE:expr) => {
114118
Box::new(
115119
PrimitiveGroupsAccumulator::<$T, $T, _, _>::new(
116120
&<$T as ArrowPrimitiveType>::DATA_TYPE,
@@ -123,13 +127,17 @@ impl AggregateExpr for Max {
123127
*x = (*x).max(y);
124128
},
125129
)
126-
.with_starting_value(<$T as ArrowPrimitiveType>::Native::MIN),
130+
.with_starting_value($STARTING_VALUE),
127131
)
128132
};
129133
}
130134
let acc: Box<dyn GroupsAccumulator> = match &self.data_type {
131-
DataType::Float64 => make_max_accumulator!(arrow::datatypes::Float64Type),
132-
DataType::Float32 => make_max_accumulator!(arrow::datatypes::Float32Type),
135+
DataType::Float64 => {
136+
make_max_accumulator!(arrow::datatypes::Float64Type, f64::NEG_INFINITY)
137+
}
138+
DataType::Float32 => {
139+
make_max_accumulator!(arrow::datatypes::Float32Type, f32::NEG_INFINITY)
140+
}
133141
DataType::Int64 => make_max_accumulator!(arrow::datatypes::Int64Type),
134142
DataType::Int96 => make_max_accumulator!(arrow::datatypes::Int96Type),
135143
DataType::Int64Decimal(0) => {
@@ -628,6 +636,9 @@ impl AggregateExpr for Min {
628636
) -> arrow::error::Result<Option<Box<dyn GroupsAccumulator>>> {
629637
macro_rules! make_min_accumulator {
630638
($T:ty) => {
639+
make_min_accumulator!($T, <$T as ArrowPrimitiveType>::Native::MAX)
640+
};
641+
($T:ty, $STARTING_VALUE:expr) => {
631642
Box::new(
632643
PrimitiveGroupsAccumulator::<$T, $T, _, _>::new(
633644
&<$T as ArrowPrimitiveType>::DATA_TYPE,
@@ -640,14 +651,18 @@ impl AggregateExpr for Min {
640651
*x = (*x).min(y);
641652
},
642653
)
643-
.with_starting_value(<$T as ArrowPrimitiveType>::Native::MAX),
654+
.with_starting_value($STARTING_VALUE),
644655
)
645656
};
646657
}
647658

648659
let acc: Box<dyn GroupsAccumulator> = match &self.data_type {
649-
DataType::Float64 => make_min_accumulator!(arrow::datatypes::Float64Type),
650-
DataType::Float32 => make_min_accumulator!(arrow::datatypes::Float32Type),
660+
DataType::Float64 => {
661+
make_min_accumulator!(arrow::datatypes::Float64Type, f64::INFINITY)
662+
}
663+
DataType::Float32 => {
664+
make_min_accumulator!(arrow::datatypes::Float32Type, f32::INFINITY)
665+
}
651666
DataType::Int64 => make_min_accumulator!(arrow::datatypes::Int64Type),
652667
DataType::Int96 => make_min_accumulator!(arrow::datatypes::Int96Type),
653668
DataType::Int64Decimal(0) => {
@@ -770,9 +785,13 @@ impl Accumulator for MinAccumulator {
770785

771786
#[cfg(test)]
772787
mod tests {
788+
use core::f64;
789+
773790
use super::*;
791+
use crate::generic_grouped_test_op;
774792
use crate::physical_plan::expressions::col;
775793
use crate::physical_plan::expressions::tests::aggregate;
794+
use crate::physical_plan::expressions::tests::grouped_aggregate;
776795
use crate::{error::Result, generic_test_op};
777796
use arrow::datatypes::*;
778797
use arrow::record_batch::RecordBatch;
@@ -974,6 +993,30 @@ mod tests {
974993
)
975994
}
976995

996+
/// `Max` over a single `f64::NEG_INFINITY` value must report
/// `f64::NEG_INFINITY`, checked through `generic_test_op!`.
#[test]
fn max_f64_infinity() -> Result<()> {
    let neg_inf_column: ArrayRef =
        Arc::new(Float64Array::from(vec![f64::NEG_INFINITY]));
    generic_test_op!(
        neg_inf_column,
        DataType::Float64,
        Max,
        ScalarValue::from(f64::NEG_INFINITY),
        DataType::Float64
    )
}
1007+
1008+
/// Same input as `max_f64_infinity`, but driven through
/// `generic_grouped_test_op!`: a lone `f64::NEG_INFINITY` must come back as
/// `f64::NEG_INFINITY`.
#[test]
fn max_f64_infinity_grouped() -> Result<()> {
    let neg_inf_column: ArrayRef =
        Arc::new(Float64Array::from(vec![f64::NEG_INFINITY]));
    generic_grouped_test_op!(
        neg_inf_column,
        DataType::Float64,
        Max,
        ScalarValue::from(f64::NEG_INFINITY),
        DataType::Float64
    )
}
1019+
9771020
#[test]
9781021
fn min_f64() -> Result<()> {
9791022
let a: ArrayRef =
@@ -986,4 +1029,28 @@ mod tests {
9861029
DataType::Float64
9871030
)
9881031
}
1032+
1033+
/// `Min` over a single `f64::INFINITY` value must report `f64::INFINITY`,
/// checked through `generic_test_op!`.
#[test]
fn min_f64_infinity() -> Result<()> {
    let pos_inf_column: ArrayRef =
        Arc::new(Float64Array::from(vec![f64::INFINITY]));
    generic_test_op!(
        pos_inf_column,
        DataType::Float64,
        Min,
        ScalarValue::from(f64::INFINITY),
        DataType::Float64
    )
}
1044+
1045+
/// Same input as `min_f64_infinity`, but driven through
/// `generic_grouped_test_op!`: a lone `f64::INFINITY` must come back as
/// `f64::INFINITY`.
#[test]
fn min_f64_infinity_grouped() -> Result<()> {
    let pos_inf_column: ArrayRef =
        Arc::new(Float64Array::from(vec![f64::INFINITY]));
    generic_grouped_test_op!(
        pos_inf_column,
        DataType::Float64,
        Min,
        ScalarValue::from(f64::INFINITY),
        DataType::Float64
    )
}
9891056
}

datafusion/src/physical_plan/expressions/mod.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,14 @@ impl PhysicalSortExpr {
121121
#[cfg(test)]
122122
mod tests {
123123
use super::*;
124-
use crate::{error::Result, physical_plan::AggregateExpr, scalar::ScalarValue};
124+
use crate::{
125+
error::Result,
126+
physical_plan::{
127+
groups_accumulator::{EmitTo, GroupsAccumulator},
128+
AggregateExpr,
129+
},
130+
scalar::ScalarValue,
131+
};
125132

126133
/// macro to perform an aggregation and verify the result.
127134
#[macro_export]
@@ -159,4 +166,46 @@ mod tests {
159166
accum.update_batch(&values)?;
160167
accum.evaluate()
161168
}
169+
170+
/// Macro that builds a one-column batch from `$ARRAY`, runs the `$OP`
/// aggregate over it via `grouped_aggregate`, and asserts that the result
/// equals `$EXPECTED`.
///
/// Expands to a block ending in `Ok(())` and uses `?` internally, so it must
/// be invoked as the tail expression of a function returning `Result<()>`.
#[macro_export]
macro_rules! generic_grouped_test_op {
    ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{
        // Single non-nullable input column named "a".
        let input_schema = Schema::new(vec![Field::new("a", $DATATYPE, false)]);
        let input_batch =
            RecordBatch::try_new(Arc::new(input_schema.clone()), vec![$ARRAY])?;

        let aggregate_expr = Arc::new(<$OP>::new(
            col("a", &input_schema)?,
            "bla".to_string(),
            $EXPECTED_DATATYPE,
        ));

        let got = grouped_aggregate(&input_batch, aggregate_expr)?;
        let want = ScalarValue::from($EXPECTED);
        assert_eq!(want, got);

        Ok(())
    }};
}
191+
192+
pub fn grouped_aggregate(
193+
batch: &RecordBatch,
194+
agg: Arc<dyn AggregateExpr>,
195+
) -> Result<ScalarValue> {
196+
let accum = agg.create_groups_accumulator()?;
197+
let mut accum: Box<dyn GroupsAccumulator> =
198+
accum.ok_or(DataFusionError::Internal(
199+
"create_groups_accumulator not supported".to_owned(),
200+
))?;
201+
let expr = agg.expressions();
202+
let values = expr
203+
.iter()
204+
.map(|e| e.evaluate(batch))
205+
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
206+
.collect::<Result<Vec<_>>>()?;
207+
accum.update_batch(&values, &vec![0; values[0].len()], None, 1)?;
208+
let results = accum.evaluate(EmitTo::All)?;
209+
ScalarValue::try_from_array(&results, 0)
210+
}
162211
}

0 commit comments

Comments
 (0)