
Commit 829d095

[SPARK-52821] Add int->DecimalType PySpark UDF return type coercion
Parent: c7c1021 · Commit: 829d095

11 files changed, +362 −18 lines

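The change addresses a conversion gap: when a pandas UDF declares a DecimalType result but returns plain integers, the Arrow conversion of the result series can fail. Below is a minimal repro sketch of that failure, separate from the diff itself; it assumes only pandas and pyarrow, and the exact exception type and message vary by PyArrow version.

from decimal import Decimal

import pandas as pd
import pyarrow as pa

ints = pd.Series([12345, 98765])   # e.g. a UDF returned plain ints
target = pa.decimal128(10, 2)      # Arrow shape of Spark's DecimalType(10, 2)

try:
    # int64 -> decimal128 with precision < 20: rejected with a precision error
    pa.Array.from_pandas(ints, type=target)
except (pa.ArrowInvalid, pa.ArrowTypeError, pa.ArrowNotImplementedError) as e:
    print("conversion failed:", e)

# Coercing to Decimal first, as the new code path does, succeeds:
print(pa.Array.from_pandas(ints.apply(lambda x: Decimal(int(x))), type=target))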

python/pyspark/sql/connect/session.py

Lines changed: 1 addition & 1 deletion
@@ -619,7 +619,7 @@ def createDataFrame(
 
         safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]
 
-        ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
+        ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true", False)
 
         _table = pa.Table.from_batches(
             [

python/pyspark/sql/pandas/conversion.py

Lines changed: 1 addition & 1 deletion
@@ -739,7 +739,7 @@ def _create_from_pandas_with_arrow(
         jsparkSession = self._jsparkSession
 
         safecheck = self._jconf.arrowSafeTypeConversion()
-        ser = ArrowStreamPandasSerializer(timezone, safecheck)
+        ser = ArrowStreamPandasSerializer(timezone, safecheck, False)
 
         @no_type_check
         def reader_func(temp_filename):
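Both driver-side createDataFrame call sites above pass the new third argument as a literal False: they convert user-supplied input rows rather than UDF results, so the coercion and its per-row cost presumably stay off on those paths. Only the UDF serializers below thread the flag through from the new spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled conf exercised in the tests.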

python/pyspark/sql/pandas/serializers.py

Lines changed: 83 additions & 8 deletions
@@ -19,6 +19,7 @@
 Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
 """
 
+from decimal import Decimal
 from itertools import groupby
 from typing import TYPE_CHECKING, Optional

@@ -251,12 +252,50 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         If True, conversion from Arrow to Pandas checks for overflow/truncation
     assign_cols_by_name : bool
         If True, then Pandas DataFrames will get columns by name
+    int_to_decimal_coercion_enabled : bool
+        If True, applies additional coercions in Python before converting to Arrow
+        This has performance penalties.
     """
 
-    def __init__(self, timezone, safecheck):
+    def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
         self._safecheck = safecheck
+        self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
+
+    @staticmethod
+    def _apply_python_coercions(series, arrow_type):
+        """
+        Apply additional coercions to the series in Python before converting to Arrow:
+        - Convert integer series to decimal type.
+          When we have a pandas series of integers that needs to be converted to
+          pyarrow.decimal128 (with precision < 20), PyArrow fails with precision errors.
+          Explicitly cast to Decimal first.
+
+        Parameters
+        ----------
+        series : pandas.Series
+            The series to potentially convert
+        arrow_type : pyarrow.DataType
+            The target arrow type
+
+        Returns
+        -------
+        pandas.Series
+            The potentially converted pandas series
+        """
+        import pyarrow.types as types
+        import pandas as pd
+
+        # Convert integer series to Decimal objects
+        if (
+            types.is_decimal(arrow_type)
+            and series.dtype.kind in ["i", "u"]  # integer types (signed/unsigned)
+            and not series.empty
+        ):
+            series = series.apply(lambda x: Decimal(x) if pd.notna(x) else None)
+
+        return series
 
     def arrow_to_pandas(
         self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None
@@ -326,6 +365,9 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
             )
             series = conv(series)
 
+        if self._int_to_decimal_coercion_enabled:
+            series = self._apply_python_coercions(series, arrow_type)
+
         if hasattr(series.array, "__arrow_array__"):
             mask = None
         else:
@@ -444,8 +486,11 @@ def __init__(
         ndarray_as_list=False,
         arrow_cast=False,
         input_types=None,
+        int_to_decimal_coercion_enabled=False,
     ):
-        super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
+        super(ArrowStreamPandasUDFSerializer, self).__init__(
+            timezone, safecheck, int_to_decimal_coercion_enabled
+        )
         self._assign_cols_by_name = assign_cols_by_name
         self._df_for_struct = df_for_struct
         self._struct_in_pandas = struct_in_pandas
@@ -799,7 +844,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
     Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
     """
 
-    def __init__(self, timezone, safecheck):
+    def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
         super(ArrowStreamPandasUDTFSerializer, self).__init__(
             timezone=timezone,
             safecheck=safecheck,
@@ -819,6 +864,8 @@ def __init__(self, timezone, safecheck):
             ndarray_as_list=True,
             # Enables explicit casting for mismatched return types of Arrow Python UDTFs.
             arrow_cast=True,
+            # Enable additional coercions for UDTF serialization
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
         self._converter_map = dict()

@@ -905,6 +952,9 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
             conv = self._get_or_create_converter_from_pandas(dt)
             series = conv(series)
 
+        if self._int_to_decimal_coercion_enabled:
+            series = self._apply_python_coercions(series, arrow_type)
+
         if hasattr(series.array, "__arrow_array__"):
             mask = None
         else:
@@ -1036,9 +1086,13 @@ def __init__(
         state_object_schema,
         arrow_max_records_per_batch,
         prefers_large_var_types,
+        int_to_decimal_coercion_enabled,
     ):
         super(ApplyInPandasWithStateSerializer, self).__init__(
-            timezone, safecheck, assign_cols_by_name
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
         self.pickleSer = CPickleSerializer()
         self.utf8_deserializer = UTF8Deserializer()
@@ -1406,9 +1460,19 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
         Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
     """
 
-    def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_max_records_per_batch,
+        int_to_decimal_coercion_enabled,
+    ):
         super(TransformWithStateInPandasSerializer, self).__init__(
-            timezone, safecheck, assign_cols_by_name
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
         self.arrow_max_records_per_batch = arrow_max_records_per_batch
         self.key_offsets = None
@@ -1482,9 +1546,20 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
         Same as input parameters in TransformWithStateInPandasSerializer.
     """
 
-    def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        arrow_max_records_per_batch,
+        int_to_decimal_coercion_enabled,
+    ):
         super(TransformWithStateInPandasInitStateSerializer, self).__init__(
-            timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            arrow_max_records_per_batch,
+            int_to_decimal_coercion_enabled,
        )
        self.init_key_offsets = None
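The guard in _apply_python_coercions can be exercised outside Spark. Here is a standalone sketch of the same logic, assuming only pandas and pyarrow are installed; Decimal(int(x)) is used here, rather than the method's Decimal(x), to normalize numpy scalars in this stripped-down setting.

from decimal import Decimal

import pandas as pd
import pyarrow as pa
import pyarrow.types as types


def apply_python_coercions(series, arrow_type):
    # Same guard as the new static method: only non-empty signed/unsigned
    # integer series targeting a decimal Arrow type are rewritten.
    if (
        types.is_decimal(arrow_type)
        and series.dtype.kind in ["i", "u"]
        and not series.empty
    ):
        series = series.apply(lambda x: Decimal(int(x)) if pd.notna(x) else None)
    return series


coerced = apply_python_coercions(pd.Series([12345, 98765]), pa.decimal128(10, 2))
print(pa.Array.from_pandas(coerced, type=pa.decimal128(10, 2)))  # 12345.00, 98765.00
print(apply_python_coercions(pd.Series([1.5]), pa.decimal128(10, 2)).dtype)  # float64: untouched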

python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py

Lines changed: 48 additions & 0 deletions
@@ -281,6 +281,54 @@ def check_apply_in_pandas_returning_incompatible_type(self):
             error_message_regex=expected,
         )
 
+    def test_cogroup_apply_int_to_decimal_coercion(self):
+        left = self.data1.limit(3)
+        right = self.data2.limit(3)
+
+        def int_to_decimal_merge(lft, rgt):
+            return pd.DataFrame(
+                [
+                    {
+                        "id": 1,
+                        "decimal_result": 98765,
+                        "left_count": len(lft),
+                        "right_count": len(rgt),
+                    }
+                ]
+            )
+
+        with self.sql_conf(
+            {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
+        ):
+            result = (
+                left.groupby("id")
+                .cogroup(right.groupby("id"))
+                .applyInPandas(
+                    int_to_decimal_merge,
+                    "id long, decimal_result decimal(10,2), left_count long, right_count long",
+                )
+                .collect()
+            )
+            self.assertTrue(len(result) > 0)
+            for row in result:
+                self.assertEqual(row.decimal_result, 98765.00)
+
+        with self.sql_conf(
+            {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
+        ):
+            with self.assertRaisesRegex(
+                PythonException, "Exception thrown when converting pandas.Series"
+            ):
+                (
+                    left.groupby("id")
+                    .cogroup(right.groupby("id"))
+                    .applyInPandas(
+                        int_to_decimal_merge,
+                        "id long, decimal_result decimal(10,2), left_count long, right_count long",
+                    )
+                    .collect()
+                )
+
     def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self):
         df = self.spark.range(0, 10).toDF("v1")
         df = df.withColumn("v2", udf(lambda x: x + 1, "int")(df["v1"])).withColumn(
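Note that both halves of this test hinge on the same serializer hook: with the conf off, _create_array hands the raw integer series to Arrow unchanged, and the failure surfaces wrapped in the serializer's "Exception thrown when converting pandas.Series" message that the assertRaisesRegex checks here, and in the grouped-map test below, match on.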

python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py

Lines changed: 31 additions & 0 deletions
@@ -387,6 +387,37 @@ def check_apply_in_pandas_returning_incompatible_type(self):
             output_schema="id long, mean string",
         )
 
+    def test_apply_in_pandas_int_to_decimal_coercion(self):
+        def int_to_decimal_func(key, pdf):
+            return pd.DataFrame([{"id": key[0], "decimal_result": 12345}])
+
+        with self.sql_conf(
+            {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
+        ):
+            result = (
+                self.data.groupby("id")
+                .applyInPandas(int_to_decimal_func, schema="id long, decimal_result decimal(10,2)")
+                .collect()
+            )
+
+            self.assertTrue(len(result) > 0)
+            for row in result:
+                self.assertEqual(row.decimal_result, 12345.00)
+
+        with self.sql_conf(
+            {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
+        ):
+            with self.assertRaisesRegex(
+                PythonException, "Exception thrown when converting pandas.Series"
+            ):
+                (
+                    self.data.groupby("id")
+                    .applyInPandas(
+                        int_to_decimal_func, schema="id long, decimal_result decimal(10,2)"
+                    )
+                    .collect()
+                )
+
     def test_datatype_string(self):
         df = self.data

python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py

Lines changed: 28 additions & 4 deletions
@@ -23,6 +23,7 @@
 
 import unittest
 from typing import cast
+from decimal import Decimal
 
 from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
 from pyspark.sql.types import (
@@ -31,6 +32,7 @@
     StructType,
     StructField,
     Row,
+    DecimalType,
 )
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -59,7 +61,12 @@ def conf(cls):
         cfg.set("spark.sql.shuffle.partitions", "5")
         return cfg
 
-    def _test_apply_in_pandas_with_state_basic(self, func, check_results):
+    def _test_apply_in_pandas_with_state_basic(self, func, check_results, output_type=None):
+        if output_type is None:
+            output_type = StructType(
+                [StructField("key", StringType()), StructField("countAsString", StringType())]
+            )
+
         input_path = tempfile.mkdtemp()
 
         def prepare_test_resource():
@@ -75,9 +82,6 @@ def prepare_test_resource():
             q.stop()
         self.assertTrue(df.isStreaming)
 
-        output_type = StructType(
-            [StructField("key", StringType()), StructField("countAsString", StringType())]
-        )
         state_type = StructType([StructField("c", LongType())])
 
         q = (
@@ -314,6 +318,26 @@ def assert_test():
         finally:
             q.stop()
 
+    def test_apply_in_pandas_with_state_int_to_decimal_coercion(self):
+        def func(key, pdf_iter, state):
+            assert isinstance(state, GroupState)
+            yield pd.DataFrame({"key": [key[0]], "decimal_sum": [1]})
+
+        def check_results(batch_df, _):
+            assert set(batch_df.sort("key").collect()) == {
+                Row(key="hello", decimal_sum=Decimal("1.00")),
+                Row(key="this", decimal_sum=Decimal("1.00")),
+            }, "Decimal coercion failed: " + str(batch_df.sort("key").collect())
+
+        output_type = StructType(
+            [StructField("key", StringType()), StructField("decimal_sum", DecimalType(10, 2))]
+        )
+
+        with self.sql_conf(
+            {"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
+        ):
+            self._test_apply_in_pandas_with_state_basic(func, check_results, output_type)
+
 
 class GroupedApplyInPandasWithStateTests(
     GroupedApplyInPandasWithStateTestsMixin, ReusedSQLTestCase