@@ -31,8 +31,8 @@ use arrow::{
31
31
} ,
32
32
compute:: { cast_with_options, take, unary, CastOptions } ,
33
33
datatypes:: {
34
- ArrowPrimitiveType , Decimal128Type , DecimalType , Float32Type , Float64Type , Int64Type ,
35
- TimestampMicrosecondType ,
34
+ is_validate_decimal_precision , ArrowPrimitiveType , Decimal128Type , Float32Type ,
35
+ Float64Type , Int64Type , TimestampMicrosecondType ,
36
36
} ,
37
37
error:: ArrowError ,
38
38
record_batch:: RecordBatch ,
@@ -1287,38 +1287,25 @@ where
1287
1287
for i in 0 ..input. len ( ) {
1288
1288
if input. is_null ( i) {
1289
1289
cast_array. append_null ( ) ;
1290
- } else {
1291
- let input_value = input. value ( i) . as_ ( ) ;
1292
- let value = ( input_value * mul) . round ( ) . to_i128 ( ) ;
1293
-
1294
- match value {
1295
- Some ( v) => {
1296
- if Decimal128Type :: validate_decimal_precision ( v, precision) . is_err ( ) {
1297
- if eval_mode == EvalMode :: Ansi {
1298
- return Err ( SparkError :: NumericValueOutOfRange {
1299
- value : input_value. to_string ( ) ,
1300
- precision,
1301
- scale,
1302
- } ) ;
1303
- } else {
1304
- cast_array. append_null ( ) ;
1305
- }
1306
- }
1307
- cast_array. append_value ( v) ;
1308
- }
1309
- None => {
1310
- if eval_mode == EvalMode :: Ansi {
1311
- return Err ( SparkError :: NumericValueOutOfRange {
1312
- value : input_value. to_string ( ) ,
1313
- precision,
1314
- scale,
1315
- } ) ;
1316
- } else {
1317
- cast_array. append_null ( ) ;
1318
- }
1319
- }
1290
+ continue ;
1291
+ }
1292
+
1293
+ let input_value = input. value ( i) . as_ ( ) ;
1294
+ if let Some ( v) = ( input_value * mul) . round ( ) . to_i128 ( ) {
1295
+ if is_validate_decimal_precision ( v, precision) {
1296
+ cast_array. append_value ( v) ;
1297
+ continue ;
1320
1298
}
1299
+ } ;
1300
+
1301
+ if eval_mode == EvalMode :: Ansi {
1302
+ return Err ( SparkError :: NumericValueOutOfRange {
1303
+ value : input_value. to_string ( ) ,
1304
+ precision,
1305
+ scale,
1306
+ } ) ;
1321
1307
}
1308
+ cast_array. append_null ( ) ;
1322
1309
}
1323
1310
1324
1311
let res = Arc :: new (
@@ -2203,6 +2190,7 @@ mod tests {
2203
2190
use arrow:: array:: StringArray ;
2204
2191
use arrow:: datatypes:: TimestampMicrosecondType ;
2205
2192
use arrow:: datatypes:: { Field , Fields , TimeUnit } ;
2193
+ use core:: f64;
2206
2194
use std:: str:: FromStr ;
2207
2195
2208
2196
use super :: * ;
@@ -2671,4 +2659,35 @@ mod tests {
2671
2659
unreachable ! ( )
2672
2660
}
2673
2661
}
2662
+
2663
+ #[ test]
2664
+ fn test_cast_float_to_decimal ( ) {
2665
+ let a: ArrayRef = Arc :: new ( Float64Array :: from ( vec ! [
2666
+ Some ( 42. ) ,
2667
+ Some ( 0.5153125 ) ,
2668
+ Some ( -42.4242415 ) ,
2669
+ Some ( 42e-314 ) ,
2670
+ Some ( 0. ) ,
2671
+ Some ( -4242.424242 ) ,
2672
+ Some ( f64 :: INFINITY ) ,
2673
+ Some ( f64 :: NEG_INFINITY ) ,
2674
+ Some ( f64 :: NAN ) ,
2675
+ None ,
2676
+ ] ) ) ;
2677
+ let b =
2678
+ cast_floating_point_to_decimal128 :: < Float64Type > ( & a, 8 , 6 , EvalMode :: Legacy ) . unwrap ( ) ;
2679
+ assert_eq ! ( b. len( ) , a. len( ) ) ;
2680
+ let casted = b. as_primitive :: < Decimal128Type > ( ) ;
2681
+ assert_eq ! ( casted. value( 0 ) , 42000000 ) ;
2682
+ // https://github.com/apache/datafusion-comet/issues/1371
2683
+ // assert_eq!(casted.value(1), 515313);
2684
+ assert_eq ! ( casted. value( 2 ) , -42424242 ) ;
2685
+ assert_eq ! ( casted. value( 3 ) , 0 ) ;
2686
+ assert_eq ! ( casted. value( 4 ) , 0 ) ;
2687
+ assert ! ( casted. is_null( 5 ) ) ;
2688
+ assert ! ( casted. is_null( 6 ) ) ;
2689
+ assert ! ( casted. is_null( 7 ) ) ;
2690
+ assert ! ( casted. is_null( 8 ) ) ;
2691
+ assert ! ( casted. is_null( 9 ) ) ;
2692
+ }
2674
2693
}
0 commit comments