[WIP][SPARK-52821][PYTHON] Add int->DecimalType PySpark UDF return type coercion #51538

Open: wants to merge 4 commits into master
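For orientation, the change adds an opt-in coercion so pandas UDFs that return plain Python/NumPy integers for DecimalType result columns can still be serialized to Arrow. A minimal usage sketch follows; the conf name is taken from the tests in this diff, while the sample data and the sum_as_decimal function are illustrative only, not part of the PR:

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Conf name as exercised by the tests below; disabled by default because the
# extra Python-side conversion has a performance cost.
spark.conf.set("spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled", "true")

df = spark.createDataFrame([(1, 10), (1, 20), (2, 30)], "id long, v long")

def sum_as_decimal(key, pdf):
    # Returns a plain int for the decimal(10,2) column; with the conf enabled,
    # the serializer converts it to Decimal before building the Arrow array.
    return pd.DataFrame([{"id": key[0], "total": int(pdf["v"].sum())}])

df.groupby("id").applyInPandas(sum_as_decimal, "id long, total decimal(10,2)").show()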
python/pyspark/sql/connect/session.py (2 changes: 1 addition & 1 deletion)
@@ -619,7 +619,7 @@ def createDataFrame(

safecheck = configs["spark.sql.execution.pandas.convertToArrowArraySafely"]

ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true", False)

_table = pa.Table.from_batches(
[
python/pyspark/sql/pandas/conversion.py (2 changes: 1 addition & 1 deletion)
@@ -739,7 +739,7 @@ def _create_from_pandas_with_arrow(
jsparkSession = self._jsparkSession

safecheck = self._jconf.arrowSafeTypeConversion()
ser = ArrowStreamPandasSerializer(timezone, safecheck)
ser = ArrowStreamPandasSerializer(timezone, safecheck, False)

@no_type_check
def reader_func(temp_filename):
python/pyspark/sql/pandas/serializers.py (92 changes: 84 additions & 8 deletions)
@@ -19,6 +19,7 @@
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
"""

from decimal import Decimal
from itertools import groupby
from typing import TYPE_CHECKING, Optional

@@ -250,12 +251,51 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
If True, conversion from Arrow to Pandas checks for overflow/truncation
assign_cols_by_name : bool
If True, then Pandas DataFrames will get columns by name
int_to_decimal_coercion_enabled : bool
If True, applies additional coercions in Python before converting to Arrow.
This has performance penalties.
"""

def __init__(self, timezone, safecheck):
def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled

@staticmethod
def _apply_python_coercions(series, arrow_type):
"""
Apply additional coercions to the series in Python before converting to Arrow:
- Convert integer series to decimal type.
When we have a pandas series of integers that needs to be converted to
pyarrow.decimal128 (with precision < 20), PyArrow fails with precision errors.
Explicitly cast to Decimal first.

Parameters
----------
series : pandas.Series
The series to potentially convert
arrow_type : pyarrow.DataType
The target arrow type

Returns
-------
pandas.Series
The potentially converted pandas series
"""
import pyarrow.types as types
import pandas as pd
from decimal import Decimal

# Convert integer series to Decimal objects
if (
types.is_decimal(arrow_type)
and series.dtype.kind in ["i", "u"]  # integer types (signed/unsigned)
and not series.empty
):
series = series.apply(lambda x: Decimal(x) if pd.notna(x) else None)

return series

def arrow_to_pandas(
self, arrow_column, idx, struct_in_pandas="dict", ndarray_as_list=False, spark_type=None
@@ -325,6 +365,9 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
)
series = conv(series)

if self._int_to_decimal_coercion_enabled:
series = self._apply_python_coercions(series, arrow_type)

if hasattr(series.array, "__arrow_array__"):
mask = None
else:
@@ -443,8 +486,11 @@ def __init__(
ndarray_as_list=False,
arrow_cast=False,
input_types=None,
int_to_decimal_coercion_enabled=False,
):
super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
super(ArrowStreamPandasUDFSerializer, self).__init__(
timezone, safecheck, int_to_decimal_coercion_enabled
)
self._assign_cols_by_name = assign_cols_by_name
self._df_for_struct = df_for_struct
self._struct_in_pandas = struct_in_pandas
@@ -700,7 +746,7 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
"""

def __init__(self, timezone, safecheck):
def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
super(ArrowStreamPandasUDTFSerializer, self).__init__(
timezone=timezone,
safecheck=safecheck,
@@ -720,6 +766,8 @@ def __init__(self, timezone, safecheck):
ndarray_as_list=True,
# Enables explicit casting for mismatched return types of Arrow Python UDTFs.
arrow_cast=True,
# Enable additional coercions for UDTF serialization
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
self._converter_map = dict()

@@ -806,6 +854,9 @@ def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
conv = self._get_or_create_converter_from_pandas(dt)
series = conv(series)

if self._int_to_decimal_coercion_enabled:
series = self._apply_python_coercions(series, arrow_type)

if hasattr(series.array, "__arrow_array__"):
mask = None
else:
@@ -937,9 +988,13 @@ def __init__(
state_object_schema,
arrow_max_records_per_batch,
prefers_large_var_types,
int_to_decimal_coercion_enabled,
):
super(ApplyInPandasWithStateSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name
timezone,
safecheck,
assign_cols_by_name,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
self.pickleSer = CPickleSerializer()
self.utf8_deserializer = UTF8Deserializer()
@@ -1307,9 +1362,19 @@ class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
"""

def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
def __init__(
self,
timezone,
safecheck,
assign_cols_by_name,
arrow_max_records_per_batch,
int_to_decimal_coercion_enabled,
):
super(TransformWithStateInPandasSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name
timezone,
safecheck,
assign_cols_by_name,
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch
self.key_offsets = None
@@ -1383,9 +1448,20 @@ class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
Same as input parameters in TransformWithStateInPandasSerializer.
"""

def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
def __init__(
self,
timezone,
safecheck,
assign_cols_by_name,
arrow_max_records_per_batch,
int_to_decimal_coercion_enabled,
):
super(TransformWithStateInPandasInitStateSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch
timezone,
safecheck,
assign_cols_by_name,
arrow_max_records_per_batch,
int_to_decimal_coercion_enabled,
)
self.init_key_offsets = None

python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py (48 changes: 48 additions & 0 deletions)
@@ -281,6 +281,54 @@ def check_apply_in_pandas_returning_incompatible_type(self):
error_message_regex=expected,
)

def test_cogroup_apply_int_to_decimal_coercion(self):
left = self.data1.limit(3)
right = self.data2.limit(3)

def int_to_decimal_merge(lft, rgt):
return pd.DataFrame(
[
{
"id": 1,
"decimal_result": 98765,
"left_count": len(lft),
"right_count": len(rgt),
}
]
)

with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
):
result = (
left.groupby("id")
.cogroup(right.groupby("id"))
.applyInPandas(
int_to_decimal_merge,
"id long, decimal_result decimal(10,2), left_count long, right_count long",
)
.collect()
)
self.assertTrue(len(result) > 0)
for row in result:
self.assertEqual(row.decimal_result, 98765.00)

with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
):
with self.assertRaisesRegex(
PythonException, "Exception thrown when converting pandas.Series"
):
(
left.groupby("id")
.cogroup(right.groupby("id"))
.applyInPandas(
int_to_decimal_merge,
"id long, decimal_result decimal(10,2), left_count long, right_count long",
)
.collect()
)

def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self):
df = self.spark.range(0, 10).toDF("v1")
df = df.withColumn("v2", udf(lambda x: x + 1, "int")(df["v1"])).withColumn(
python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py (31 changes: 31 additions & 0 deletions)
@@ -387,6 +387,37 @@ def check_apply_in_pandas_returning_incompatible_type(self):
output_schema="id long, mean string",
)

def test_apply_in_pandas_int_to_decimal_coercion(self):
def int_to_decimal_func(key, pdf):
return pd.DataFrame([{"id": key[0], "decimal_result": 12345}])

with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": True}
):
result = (
self.data.groupby("id")
.applyInPandas(int_to_decimal_func, schema="id long, decimal_result decimal(10,2)")
.collect()
)

self.assertTrue(len(result) > 0)
for row in result:
self.assertEqual(row.decimal_result, 12345.00)

with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": False}
):
with self.assertRaisesRegex(
PythonException, "Exception thrown when converting pandas.Series"
):
(
self.data.groupby("id")
.applyInPandas(
int_to_decimal_func, schema="id long, decimal_result decimal(10,2)"
)
.collect()
)

def test_datatype_string(self):
df = self.data

python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -23,6 +23,7 @@

import unittest
from typing import cast
from decimal import Decimal

from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
from pyspark.sql.types import (
@@ -31,6 +32,7 @@
StructType,
StructField,
Row,
DecimalType,
)
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -314,6 +316,90 @@ def assert_test():
finally:
q.stop()

def _test_apply_in_pandas_with_state_decimal_coercion(self, coercion_enabled, should_succeed):
input_path = tempfile.mkdtemp()

with open(input_path + "/numeric-test.txt", "w") as fw:
fw.write("group1,123\ngroup2,456\ngroup1,789\n")

df = (
self.spark.readStream.format("csv")
.option("header", "false")
.schema("key string, value int")
.load(input_path)
)

for q in self.spark.streams.active:
q.stop()
self.assertTrue(df.isStreaming)

output_type = StructType(
[
StructField("key", StringType()),
StructField("decimal_sum", DecimalType(10, 2)),
StructField("count", LongType()),
]
)
state_type = StructType([StructField("sum", LongType()), StructField("count", LongType())])

def stateful_func(key, pdf_iter, state):
current_sum = state.get[0] if state.exists else 0
current_count = state.get[1] if state.exists else 0

for pdf in pdf_iter:
current_sum += pdf["value"].sum()
current_count += len(pdf)

state.update((current_sum, current_count))
yield pd.DataFrame(
{"key": [key[0]], "decimal_sum": [current_sum], "count": [current_count]}
)

def check_results(batch_df, _):
if should_succeed:
results = batch_df.sort("key").collect()
for row in results:
assert isinstance(row["decimal_sum"], Decimal)
if row["key"] == "group1":
assert row["decimal_sum"] == Decimal("912.00")
elif row["key"] == "group2":
assert row["decimal_sum"] == Decimal("456.00")

with self.sql_conf(
{"spark.sql.execution.pythonUDF.pandas.intToDecimalCoercionEnabled": coercion_enabled}
):
q = (
df.groupBy(df["key"])
.applyInPandasWithState(
stateful_func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout
)
.writeStream.queryName(f"test_coercion_{coercion_enabled}")
.foreachBatch(check_results)
.outputMode("update")
.start()
)

self.assertTrue(q.isActive)

if should_succeed:
q.processAllAvailable()
self.assertTrue(q.exception() is None)
else:
with self.assertRaises(Exception) as context:
q.processAllAvailable()
self.assertIn("STREAM_FAILED", str(context.exception))

q.stop()

def test_apply_in_pandas_with_state_int_to_decimal_coercion(self):
self._test_apply_in_pandas_with_state_decimal_coercion(
coercion_enabled=True, should_succeed=True
)

self._test_apply_in_pandas_with_state_decimal_coercion(
coercion_enabled=False, should_succeed=False
)


class GroupedApplyInPandasWithStateTests(
GroupedApplyInPandasWithStateTestsMixin, ReusedSQLTestCase