Skip to content

Commit 1f3af8f

Browse files
authored
fix: Fix overflow handling when casting float to decimal (#1914)
1 parent c65754a commit 1f3af8f

File tree

1 file changed

+51
-32
lines changed
  • native/spark-expr/src/conversion_funcs

1 file changed

+51
-32
lines changed

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 51 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ use arrow::{
3131
},
3232
compute::{cast_with_options, take, unary, CastOptions},
3333
datatypes::{
34-
ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type,
35-
TimestampMicrosecondType,
34+
is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type,
35+
Float64Type, Int64Type, TimestampMicrosecondType,
3636
},
3737
error::ArrowError,
3838
record_batch::RecordBatch,
@@ -1287,38 +1287,25 @@ where
12871287
for i in 0..input.len() {
12881288
if input.is_null(i) {
12891289
cast_array.append_null();
1290-
} else {
1291-
let input_value = input.value(i).as_();
1292-
let value = (input_value * mul).round().to_i128();
1293-
1294-
match value {
1295-
Some(v) => {
1296-
if Decimal128Type::validate_decimal_precision(v, precision).is_err() {
1297-
if eval_mode == EvalMode::Ansi {
1298-
return Err(SparkError::NumericValueOutOfRange {
1299-
value: input_value.to_string(),
1300-
precision,
1301-
scale,
1302-
});
1303-
} else {
1304-
cast_array.append_null();
1305-
}
1306-
}
1307-
cast_array.append_value(v);
1308-
}
1309-
None => {
1310-
if eval_mode == EvalMode::Ansi {
1311-
return Err(SparkError::NumericValueOutOfRange {
1312-
value: input_value.to_string(),
1313-
precision,
1314-
scale,
1315-
});
1316-
} else {
1317-
cast_array.append_null();
1318-
}
1319-
}
1290+
continue;
1291+
}
1292+
1293+
let input_value = input.value(i).as_();
1294+
if let Some(v) = (input_value * mul).round().to_i128() {
1295+
if is_validate_decimal_precision(v, precision) {
1296+
cast_array.append_value(v);
1297+
continue;
13201298
}
1299+
};
1300+
1301+
if eval_mode == EvalMode::Ansi {
1302+
return Err(SparkError::NumericValueOutOfRange {
1303+
value: input_value.to_string(),
1304+
precision,
1305+
scale,
1306+
});
13211307
}
1308+
cast_array.append_null();
13221309
}
13231310

13241311
let res = Arc::new(
@@ -2203,6 +2190,7 @@ mod tests {
22032190
use arrow::array::StringArray;
22042191
use arrow::datatypes::TimestampMicrosecondType;
22052192
use arrow::datatypes::{Field, Fields, TimeUnit};
2193+
use core::f64;
22062194
use std::str::FromStr;
22072195

22082196
use super::*;
@@ -2671,4 +2659,35 @@ mod tests {
26712659
unreachable!()
26722660
}
26732661
}
2662+
2663+
#[test]
2664+
fn test_cast_float_to_decimal() {
2665+
let a: ArrayRef = Arc::new(Float64Array::from(vec![
2666+
Some(42.),
2667+
Some(0.5153125),
2668+
Some(-42.4242415),
2669+
Some(42e-314),
2670+
Some(0.),
2671+
Some(-4242.424242),
2672+
Some(f64::INFINITY),
2673+
Some(f64::NEG_INFINITY),
2674+
Some(f64::NAN),
2675+
None,
2676+
]));
2677+
let b =
2678+
cast_floating_point_to_decimal128::<Float64Type>(&a, 8, 6, EvalMode::Legacy).unwrap();
2679+
assert_eq!(b.len(), a.len());
2680+
let casted = b.as_primitive::<Decimal128Type>();
2681+
assert_eq!(casted.value(0), 42000000);
2682+
// https://github.com/apache/datafusion-comet/issues/1371
2683+
// assert_eq!(casted.value(1), 515313);
2684+
assert_eq!(casted.value(2), -42424242);
2685+
assert_eq!(casted.value(3), 0);
2686+
assert_eq!(casted.value(4), 0);
2687+
assert!(casted.is_null(5));
2688+
assert!(casted.is_null(6));
2689+
assert!(casted.is_null(7));
2690+
assert!(casted.is_null(8));
2691+
assert!(casted.is_null(9));
2692+
}
26742693
}

0 commit comments

Comments
 (0)