 Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
 """
 
+from decimal import Decimal
 from itertools import groupby
 from typing import TYPE_CHECKING, Optional
 
@@ -251,12 +252,50 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         If True, conversion from Arrow to Pandas checks for overflow/truncation
     assign_cols_by_name : bool
         If True, then Pandas DataFrames will get columns by name
+    int_to_decimal_coercion_enabled : bool
+        If True, applies additional coercions in Python before converting to Arrow
+        This has performance penalties.
     """
 
-    def __init__(self, timezone, safecheck):
+    def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
         self._safecheck = safecheck
+        self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
+
+    @staticmethod
+    def _apply_python_coercions(series, arrow_type):
+        """
+        Apply additional coercions to the series in Python before converting to Arrow:
+        - Convert integer series to decimal type.
+          When we have a pandas series of integers that needs to be converted to
+          pyarrow.decimal128 (with precision < 20), PyArrow fails with precision errors.
+          Explicitly cast to Decimal first.
+
+        Parameters
+        ----------
+        series : pandas.Series
+            The series to potentially convert
+        arrow_type : pyarrow.DataType
+            The target arrow type
+
+        Returns
+        -------
+        pandas.Series
+            The potentially converted pandas series
+        """
+        import pyarrow.types as types
+        import pandas as pd
+
+        # Convert integer series to Decimal objects
+        if (
+            types.is_decimal(arrow_type)
+            and series.dtype.kind in ["i", "u"]  # integer types (signed/unsigned)
+            and not series.empty
+        ):
+            series = series.apply(lambda x: Decimal(x) if pd.notna(x) else None)
+
+        return series
 
     def arrow_to_pandas(
         self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None
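The docstring above explains the motivation; the following standalone sketch (illustrative, not part of the patch) reproduces the conversion outside Spark. The names `series` and `target` are placeholders, and the direct conversion is left commented out because the exact error depends on the PyArrow version.

from decimal import Decimal

import pandas as pd
import pyarrow as pa

series = pd.Series([1, 2, 3], dtype="int64")
target = pa.decimal128(10, 2)  # precision < 20, the case described above

# Direct conversion of the integer series may fail with a precision error:
# pa.Array.from_pandas(series, type=target)

# Mirroring _apply_python_coercions: cast to Decimal in Python first.
coerced = series.apply(lambda x: Decimal(x) if pd.notna(x) else None)
print(pa.Array.from_pandas(coerced, type=target))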
@@ -326,6 +365,9 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
             )
             series = conv(series)
 
+        if self._int_to_decimal_coercion_enabled:
+            series = self._apply_python_coercions(series, arrow_type)
+
         if hasattr(series.array, "__arrow_array__"):
             mask = None
         else:
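For context on where this hook fires (an illustrative sketch, not part of the patch): a pandas UDF that declares a decimal return type but produces a plain integer series is the user-facing case. This assumes an active SparkSession bound to `spark` and the coercion flag enabled in the session configuration.

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DecimalType

@pandas_udf(DecimalType(10, 2))
def as_price(v: pd.Series) -> pd.Series:
    # Returns an integer series; with the coercion enabled, _create_array
    # casts it to Decimal before handing it to Arrow.
    return v + 1

spark.range(3).select(as_price("id").alias("price")).show()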
@@ -444,8 +486,11 @@ def __init__(
         ndarray_as_list=False,
         arrow_cast=False,
         input_types=None,
+        int_to_decimal_coercion_enabled=False,
     ):
-        super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
+        super(ArrowStreamPandasUDFSerializer, self).__init__(
+            timezone, safecheck, int_to_decimal_coercion_enabled
+        )
         self._assign_cols_by_name = assign_cols_by_name
         self._df_for_struct = df_for_struct
         self._struct_in_pandas = struct_in_pandas
@@ -799,7 +844,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
     Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
     """
 
-    def __init__(self, timezone, safecheck):
+    def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
         super(ArrowStreamPandasUDTFSerializer, self).__init__(
             timezone=timezone,
             safecheck=safecheck,
@@ -819,6 +864,8 @@ def __init__(self, timezone, safecheck):
             ndarray_as_list=True,
             # Enables explicit casting for mismatched return types of Arrow Python UDTFs.
             arrow_cast=True,
+            # Enable additional coercions for UDTF serialization
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
         self._converter_map = dict()
@@ -905,6 +952,9 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
             conv = self._get_or_create_converter_from_pandas(dt)
             series = conv(series)
 
+        if self._int_to_decimal_coercion_enabled:
+            series = self._apply_python_coercions(series, arrow_type)
+
         if hasattr(series.array, "__arrow_array__"):
             mask = None
         else:
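The same guard in the UDTF serializer covers Arrow-optimized Python UDTFs whose declared DECIMAL output column is produced as plain integers. A hedged sketch, assuming an active SparkSession:

from pyspark.sql.functions import udtf

@udtf(returnType="amount: decimal(10, 2)", useArrow=True)
class Amounts:
    def eval(self):
        for i in range(3):
            # Yields ints for a decimal(10, 2) column; the coercion above
            # converts them to Decimal before Arrow serialization.
            yield (i,)

Amounts().show()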
@@ -1036,9 +1086,13 @@ def __init__(
         state_object_schema,
         arrow_max_records_per_batch,
         prefers_large_var_types,
+        int_to_decimal_coercion_enabled,
     ):
         super(ApplyInPandasWithStateSerializer, self).__init__(
-            timezone, safecheck, assign_cols_by_name
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
         self.pickleSer = CPickleSerializer()
         self.utf8_deserializer = UTF8Deserializer()
@@ -1406,9 +1460,19 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
     """
 
-    def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_max_records_per_batch,
+        int_to_decimal_coercion_enabled,
+    ):
         super(TransformWithStateInPandasSerializer, self).__init__(
-            timezone, safecheck, assign_cols_by_name
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
         self.arrow_max_records_per_batch = arrow_max_records_per_batch
         self.key_offsets = None
@@ -1482,9 +1546,20 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
     Same as input parameters in TransformWithStateInPandasSerializer.
     """
 
-    def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_max_records_per_batch,
+        int_to_decimal_coercion_enabled,
+    ):
         super(TransformWithStateInPandasInitStateSerializer, self).__init__(
-            timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            arrow_max_records_per_batch,
+            int_to_decimal_coercion_enabled,
         )
         self.init_key_offsets = None
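Taken together, every serializer below ArrowStreamPandasSerializer now threads the flag through its constructor down to the base class. These classes are internal worker-side APIs, so the construction sketch below is purely illustrative and the argument values are placeholders.

from pyspark.sql.pandas.serializers import TransformWithStateInPandasSerializer

ser = TransformWithStateInPandasSerializer(
    timezone="UTC",
    safecheck=True,
    assign_cols_by_name=True,
    arrow_max_records_per_batch=10000,
    int_to_decimal_coercion_enabled=True,  # stored as self._int_to_decimal_coercion_enabled
)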