Skip to content

Commit c65754a

Browse files
authored
fix: Add overflow check to evaluate of sum decimal accumulator (#1922)
1 parent eab58d4 commit c65754a

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

native/spark-expr/src/agg_funcs/sum_decimal.rs

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,10 @@ impl Accumulator for SumDecimalAccumulator {
210210
// are null, in this case we'll return null
211211
// 2. if `is_empty` is false but `is_not_null` is also false, it means there's an overflow. In
212212
// non-ANSI mode Spark returns null.
213-
if self.is_empty || !self.is_not_null {
213+
if self.is_empty
214+
|| !self.is_not_null
215+
|| !is_valid_decimal_precision(self.sum, self.precision)
216+
{
214217
ScalarValue::new_primitive::<Decimal128Type>(
215218
None,
216219
&DataType::Decimal128(self.precision, self.scale),
@@ -375,11 +378,17 @@ impl GroupsAccumulator for SumDecimalGroupsAccumulator {
375378
// are null, in this case we'll return null
376379
// 2. if `is_empty` is false but `is_not_null` is also false, it means there's an overflow. In
377380
// non-ANSI mode Spark returns null.
381+
let result = emit_to.take_needed(&mut self.sum);
382+
result.iter().enumerate().for_each(|(i, &v)| {
383+
if !is_valid_decimal_precision(v, self.precision) {
384+
self.is_not_null.set_bit(i, false);
385+
}
386+
});
387+
378388
let nulls = build_bool_state(&mut self.is_not_null, &emit_to);
379389
let is_empty = build_bool_state(&mut self.is_empty, &emit_to);
380390
let x = (!&is_empty).bitand(&nulls);
381391

382-
let result = emit_to.take_needed(&mut self.sum);
383392
let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x)))
384393
.with_data_type(self.result_type.clone());
385394

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -958,6 +958,21 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
958958
}
959959
}
960960

961+
test("sum overflow on decimal(38, _)") {
962+
withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
963+
val table = "overflow_decimal_38"
964+
withTable(table) {
965+
sql(s"create table $table(a decimal(38, 2), b INT) using parquet")
966+
sql(s"insert into $table values(42.00, 1), (999999999999999999999999999999999999.99, 1)")
967+
checkSparkAnswerAndNumOfAggregates(s"select sum(a) from $table", 2)
968+
sql(s"insert into $table values(42.00, 2), (99999999999999999999999999999999.99, 2)")
969+
sql(s"insert into $table values(999999999999999999999999999999999999.99, 3)")
970+
sql(s"insert into $table values(99999999999999999999999999999999.99, 4)")
971+
checkSparkAnswerAndNumOfAggregates(s"select sum(a) from $table group by b order by b", 2)
972+
}
973+
}
974+
}
975+
961976
test("test final avg") {
962977
withSQLConf(
963978
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",

0 commit comments

Comments
 (0)