diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 1be61b7ce8fda..46d9eba1029eb 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -526,6 +526,7 @@ def __hash__(self):
         "pyspark.sql.tests.test_dataframe",
         "pyspark.sql.tests.test_collection",
         "pyspark.sql.tests.test_creation",
+        "pyspark.sql.tests.test_conversion",
         "pyspark.sql.tests.test_listener",
         "pyspark.sql.tests.test_observation",
         "pyspark.sql.tests.test_repartition",
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index 20e27c5c044ab..bf36522a423a2 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -18,7 +18,7 @@
 import array
 import datetime
 import decimal
-from typing import TYPE_CHECKING, Any, Callable, List, Sequence
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, overload
 
 from pyspark.errors import PySparkValueError
 from pyspark.sql.pandas.types import _dedup_names, _deduplicate_field_names, to_arrow_schema
@@ -91,16 +91,33 @@ def _need_converter(
         else:
             return False
 
+    @overload
+    @staticmethod
+    def _create_converter(dataType: DataType, nullable: bool = True) -> Callable:
+        pass
+
+    @overload
+    @staticmethod
+    def _create_converter(
+        dataType: DataType, nullable: bool = True, *, none_on_identity: bool = True
+    ) -> Optional[Callable]:
+        pass
+
     @staticmethod
     def _create_converter(
         dataType: DataType,
         nullable: bool = True,
-    ) -> Callable:
+        *,
+        none_on_identity: bool = False,
+    ) -> Optional[Callable]:
         assert dataType is not None and isinstance(dataType, DataType)
         assert isinstance(nullable, bool)
 
         if not LocalDataToArrowConversion._need_converter(dataType, nullable):
-            return lambda value: value
+            if none_on_identity:
+                return None
+            else:
+                return lambda value: value
 
         if isinstance(dataType, NullType):
@@ -113,10 +130,13 @@ def convert_null(value: Any) -> Any:
 
         elif isinstance(dataType, StructType):
             field_names = dataType.fieldNames()
+            len_field_names = len(field_names)
             dedup_field_names = _dedup_names(dataType.names)
 
             field_convs = [
-                LocalDataToArrowConversion._create_converter(field.dataType, field.nullable)
+                LocalDataToArrowConversion._create_converter(
+                    field.dataType, field.nullable, none_on_identity=True
+                )
                 for field in dataType.fields
             ]
@@ -126,71 +146,105 @@ def convert_struct(value: Any) -> Any:
                         raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
-                    assert isinstance(value, (tuple, dict)) or hasattr(
-                        value, "__dict__"
-                    ), f"{type(value)} {value}"
-
-                    _dict = {}
-                    if (
-                        not isinstance(value, Row)
-                        and not isinstance(value, tuple)  # inherited namedtuple
-                        and hasattr(value, "__dict__")
-                    ):
-                        value = value.__dict__
-                    if isinstance(value, dict):
-                        for i, field in enumerate(field_names):
-                            _dict[dedup_field_names[i]] = field_convs[i](value.get(field))
-                    else:
-                        if len(value) != len(field_names):
+                    # The `value` should be tuple, dict, or have `__dict__`.
+                    if isinstance(value, tuple):  # `Row` inherits `tuple`
+                        if len(value) != len_field_names:
                             raise PySparkValueError(
                                 errorClass="AXIS_LENGTH_MISMATCH",
                                 messageParameters={
-                                    "expected_length": str(len(field_names)),
+                                    "expected_length": str(len_field_names),
                                     "actual_length": str(len(value)),
                                 },
                             )
-                        for i in range(len(field_names)):
-                            _dict[dedup_field_names[i]] = field_convs[i](value[i])
-
-                    return _dict
+                        return {
+                            dedup_field_names[i]: (
+                                field_convs[i](value[i])  # type: ignore[misc]
+                                if field_convs[i] is not None
+                                else value[i]
+                            )
+                            for i in range(len_field_names)
+                        }
+                    elif isinstance(value, dict):
+                        return {
+                            dedup_field_names[i]: (
+                                field_convs[i](value.get(field))  # type: ignore[misc]
+                                if field_convs[i] is not None
+                                else value.get(field)
+                            )
+                            for i, field in enumerate(field_names)
+                        }
+                    else:
+                        assert hasattr(value, "__dict__"), f"{type(value)} {value}"
+                        value = value.__dict__
+                        return {
+                            dedup_field_names[i]: (
+                                field_convs[i](value.get(field))  # type: ignore[misc]
+                                if field_convs[i] is not None
+                                else value.get(field)
+                            )
+                            for i, field in enumerate(field_names)
+                        }
 
             return convert_struct
 
         elif isinstance(dataType, ArrayType):
             element_conv = LocalDataToArrowConversion._create_converter(
-                dataType.elementType, dataType.containsNull
+                dataType.elementType, dataType.containsNull, none_on_identity=True
             )
 
-            def convert_array(value: Any) -> Any:
-                if value is None:
-                    if not nullable:
-                        raise PySparkValueError(f"input for {dataType} must not be None")
-                    return None
-                else:
-                    assert isinstance(value, (list, array.array))
-                    return [element_conv(v) for v in value]
+            if element_conv is None:
+
+                def convert_array(value: Any) -> Any:
+                    if value is None:
+                        if not nullable:
+                            raise PySparkValueError(f"input for {dataType} must not be None")
+                        return None
+                    else:
+                        assert isinstance(value, (list, array.array))
+                        return list(value)
+
+            else:
+
+                def convert_array(value: Any) -> Any:
+                    if value is None:
+                        if not nullable:
+                            raise PySparkValueError(f"input for {dataType} must not be None")
+                        return None
+                    else:
+                        assert isinstance(value, (list, array.array))
+                        return [element_conv(v) for v in value]
 
             return convert_array
 
         elif isinstance(dataType, MapType):
-            key_conv = LocalDataToArrowConversion._create_converter(dataType.keyType)
+            key_conv = LocalDataToArrowConversion._create_converter(
+                dataType.keyType, nullable=False
+            )
             value_conv = LocalDataToArrowConversion._create_converter(
-                dataType.valueType, dataType.valueContainsNull
+                dataType.valueType, dataType.valueContainsNull, none_on_identity=True
             )
 
-            def convert_map(value: Any) -> Any:
-                if value is None:
-                    if not nullable:
-                        raise PySparkValueError(f"input for {dataType} must not be None")
-                    return None
-                else:
-                    assert isinstance(value, dict)
+            if value_conv is None:
+
+                def convert_map(value: Any) -> Any:
+                    if value is None:
+                        if not nullable:
+                            raise PySparkValueError(f"input for {dataType} must not be None")
+                        return None
+                    else:
+                        assert isinstance(value, dict)
+                        return [(key_conv(k), v) for k, v in value.items()]
 
-                    _tuples = []
-                    for k, v in value.items():
-                        _tuples.append((key_conv(k), value_conv(v)))
+            else:
 
-                    return _tuples
+                def convert_map(value: Any) -> Any:
+                    if value is None:
+                        if not nullable:
+                            raise PySparkValueError(f"input for {dataType} must not be None")
+                        return None
+                    else:
+                        assert isinstance(value, dict)
+                        return [(key_conv(k), value_conv(v)) for k, v in value.items()]
 
             return convert_map
@@ -266,15 +320,29 @@ def convert_string(value: Any) -> Any:
 
         elif isinstance(dataType, UserDefinedType):
             udt: UserDefinedType = dataType
-            conv = LocalDataToArrowConversion._create_converter(udt.sqlType())
+            conv = LocalDataToArrowConversion._create_converter(
+                udt.sqlType(), nullable=nullable, none_on_identity=True
+            )
 
-            def convert_udt(value: Any) -> Any:
-                if value is None:
-                    if not nullable:
-                        raise PySparkValueError(f"input for {dataType} must not be None")
-                    return None
-                else:
-                    return conv(udt.serialize(value))
+            if conv is None:
+
+                def convert_udt(value: Any) -> Any:
+                    if value is None:
+                        if not nullable:
+                            raise PySparkValueError(f"input for {dataType} must not be None")
+                        return None
+                    else:
+                        return udt.serialize(value)
+
+            else:
+
+                def convert_udt(value: Any) -> Any:
+                    if value is None:
+                        if not nullable:
+                            raise PySparkValueError(f"input for {dataType} must not be None")
+                        return None
+                    else:
+                        return conv(udt.serialize(value))
 
             return convert_udt
@@ -301,7 +369,10 @@ def convert_other(value: Any) -> Any:
 
             return convert_other
         else:
-            return lambda value: value
+            if none_on_identity:
+                return None
+            else:
+                return lambda value: value
 
     @staticmethod
     def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool) -> "pa.Table":
@@ -318,7 +389,7 @@ def convert(data: Sequence[Any], schema: StructType, use_large_var_types: bool)
         def to_row(item: Any) -> tuple:
             if item is None:
                 return tuple([None] * len_column_names)
-            elif isinstance(item, (Row, tuple)):
+            elif isinstance(item, tuple):  # `Row` inherits `tuple`
                 if len(item) != len_column_names:
                     raise PySparkValueError(
                         errorClass="AXIS_LENGTH_MISMATCH",
@@ -350,11 +421,16 @@ def to_row(item: Any) -> tuple:
 
         if len_column_names > 0:
             column_convs = [
-                LocalDataToArrowConversion._create_converter(field.dataType, field.nullable)
+                LocalDataToArrowConversion._create_converter(
+                    field.dataType, field.nullable, none_on_identity=True
+                )
                 for field in schema.fields
            ]
 
-            pylist = [[conv(row[i]) for row in rows] for i, conv in enumerate(column_convs)]
+            pylist = [
+                [conv(row[i]) for row in rows] if conv is not None else [row[i] for row in rows]
+                for i, conv in enumerate(column_convs)
+            ]
 
             pa_schema = to_arrow_schema(
                 StructType(
@@ -402,12 +478,29 @@ def _need_converter(dataType: DataType) -> bool:
         else:
             return False
 
+    @overload
     @staticmethod
     def _create_converter(dataType: DataType) -> Callable:
+        pass
+
+    @overload
+    @staticmethod
+    def _create_converter(
+        dataType: DataType, *, none_on_identity: bool = True
+    ) -> Optional[Callable]:
+        pass
+
+    @staticmethod
+    def _create_converter(
+        dataType: DataType, *, none_on_identity: bool = False
+    ) -> Optional[Callable]:
         assert dataType is not None and isinstance(dataType, DataType)
 
         if not ArrowTableToRowsConversion._need_converter(dataType):
-            return lambda value: value
+            if none_on_identity:
+                return None
+            else:
+                return lambda value: value
 
         if isinstance(dataType, NullType):
             return lambda value: None
@@ -417,7 +510,8 @@ def _create_converter(dataType: DataType) -> Callable:
             dedup_field_names = _dedup_names(field_names)
 
             field_convs = [
-                ArrowTableToRowsConversion._create_converter(f.dataType) for f in dataType.fields
+                ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
+                for f in dataType.fields
             ]
 
             def convert_struct(value: Any) -> Any:
@@ -427,7 +521,9 @@ def convert_struct(value: Any) -> Any:
                     assert isinstance(value, dict)
 
                     _values = [
-                        field_convs[i](value.get(name, None))
+                        field_convs[i](value.get(name, None))  # type: ignore[misc]
+                        if field_convs[i] is not None
+                        else value.get(name, None)
                         for i, name in enumerate(dedup_field_names)
                     ]
                     return _create_row(field_names, _values)
@@ -435,28 +531,79 @@ def convert_struct(value: Any) -> Any:
             return convert_struct
 
         elif isinstance(dataType, ArrayType):
-            element_conv = ArrowTableToRowsConversion._create_converter(dataType.elementType)
+            element_conv = ArrowTableToRowsConversion._create_converter(
+                dataType.elementType, none_on_identity=True
+            )
 
-            def convert_array(value: Any) -> Any:
-                if value is None:
-                    return None
-                else:
-                    assert isinstance(value, list)
-                    return [element_conv(v) for v in value]
+            if element_conv is None:
+
+                def convert_array(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    else:
+                        assert isinstance(value, list)
+                        return value
+
+            else:
+
+                def convert_array(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    else:
+                        assert isinstance(value, list)
+                        return [element_conv(v) for v in value]
 
             return convert_array
 
         elif isinstance(dataType, MapType):
-            key_conv = ArrowTableToRowsConversion._create_converter(dataType.keyType)
-            value_conv = ArrowTableToRowsConversion._create_converter(dataType.valueType)
+            key_conv = ArrowTableToRowsConversion._create_converter(
+                dataType.keyType, none_on_identity=True
+            )
+            value_conv = ArrowTableToRowsConversion._create_converter(
+                dataType.valueType, none_on_identity=True
+            )
+
+            if key_conv is None:
+                if value_conv is None:
+
+                    def convert_map(value: Any) -> Any:
+                        if value is None:
+                            return None
+                        else:
+                            assert isinstance(value, list)
+                            assert all(isinstance(t, tuple) and len(t) == 2 for t in value)
+                            return dict(value)
+
+                else:
+
+                    def convert_map(value: Any) -> Any:
+                        if value is None:
+                            return None
+                        else:
+                            assert isinstance(value, list)
+                            assert all(isinstance(t, tuple) and len(t) == 2 for t in value)
+                            return dict((t[0], value_conv(t[1])) for t in value)
+
+            else:
+                if value_conv is None:
+
+                    def convert_map(value: Any) -> Any:
+                        if value is None:
+                            return None
+                        else:
+                            assert isinstance(value, list)
+                            assert all(isinstance(t, tuple) and len(t) == 2 for t in value)
+                            return dict((key_conv(t[0]), t[1]) for t in value)
 
-            def convert_map(value: Any) -> Any:
-                if value is None:
-                    return None
                 else:
-                    assert isinstance(value, list)
-                    assert all(isinstance(t, tuple) and len(t) == 2 for t in value)
-                    return dict((key_conv(t[0]), value_conv(t[1])) for t in value)
+
+                    def convert_map(value: Any) -> Any:
+                        if value is None:
+                            return None
+                        else:
+                            assert isinstance(value, list)
+                            assert all(isinstance(t, tuple) and len(t) == 2 for t in value)
+                            return dict((key_conv(t[0]), value_conv(t[1])) for t in value)
 
             return convert_map
@@ -496,13 +643,25 @@ def convert_timestample_ntz(value: Any) -> Any:
 
         elif isinstance(dataType, UserDefinedType):
             udt: UserDefinedType = dataType
-            conv = ArrowTableToRowsConversion._create_converter(udt.sqlType())
+            conv = ArrowTableToRowsConversion._create_converter(
+                udt.sqlType(), none_on_identity=True
+            )
 
-            def convert_udt(value: Any) -> Any:
-                if value is None:
-                    return None
-                else:
-                    return udt.deserialize(conv(value))
+            if conv is None:
+
+                def convert_udt(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    else:
+                        return udt.deserialize(value)
+
+            else:
+
+                def convert_udt(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    else:
+                        return udt.deserialize(conv(value))
 
             return convert_udt
@@ -523,7 +682,10 @@ def convert_variant(value: Any) -> Any:
             return convert_variant
 
         else:
-            return lambda value: value
+            if none_on_identity:
+                return None
+            else:
+                return lambda value: value
 
     @staticmethod
     def convert(table: "pa.Table", schema: StructType) -> List[Row]:
@@ -538,11 +700,12 @@ def convert(table: "pa.Table", schema: StructType) -> List[Row]:
 
         if len(fields) > 0:
             field_converters = [
-                ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields
+                ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
+                for f in schema.fields
             ]
 
             columnar_data = [
-                [conv(v) for v in column.to_pylist()]
+                [conv(v) for v in column.to_pylist()] if conv is not None else column.to_pylist()
                 for column, conv in zip(table.columns, field_converters)
             ]
diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py
new file mode 100644
index 0000000000000..2b18fe8d04d7a
--- /dev/null
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -0,0 +1,122 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
+from pyspark.sql.types import (
+    ArrayType,
+    BinaryType,
+    IntegerType,
+    MapType,
+    Row,
+    StringType,
+    StructType,
+)
+from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
+from pyspark.testing.utils import have_pyarrow, pyarrow_requirement_message
+
+
+@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+class ConversionTests(unittest.TestCase):
+    def test_conversion(self):
+        data = [
+            (
+                i if i % 2 == 0 else None,
+                str(i),
+                i,
+                str(i).encode(),
+                [j if j % 2 == 0 else None for j in range(i)],
+                list(range(i)),
+                [str(j).encode() for j in range(i)],
+                {str(j): j if j % 2 == 0 else None for j in range(i)},
+                {str(j): j for j in range(i)},
+                {str(j): str(j).encode() for j in range(i)},
+                (i if i % 2 == 0 else None, str(i), i, str(i).encode()),
+                {"i": i if i % 2 == 0 else None, "s": str(i), "ii": i, "b": str(i).encode()},
+                ExamplePoint(float(i), float(i)),
+            )
+            for i in range(5)
+        ]
+        schema = (
+            StructType()
+            .add("i", IntegerType())
+            .add("s", StringType())
+            .add("ii", IntegerType(), nullable=False)
+            .add("b", BinaryType())
+            .add("arr_i", ArrayType(IntegerType()))
+            .add("arr_ii", ArrayType(IntegerType(), containsNull=False))
+            .add("arr_b", ArrayType(BinaryType()))
+            .add("map_i", MapType(StringType(), IntegerType()))
+            .add("map_ii", MapType(StringType(), IntegerType(), valueContainsNull=False))
+            .add("map_b", MapType(StringType(), BinaryType()))
+            .add(
+                "struct_t",
+                StructType()
+                .add("i", IntegerType())
+                .add("s", StringType())
+                .add("ii", IntegerType(), nullable=False)
+                .add("b", BinaryType()),
+            )
+            .add(
+                "struct_d",
+                StructType()
+                .add("i", IntegerType())
+                .add("s", StringType())
+                .add("ii", IntegerType(), nullable=False)
+                .add("b", BinaryType()),
+            )
+            .add("udt", ExamplePointUDT())
+        )
+
+        tbl = LocalDataToArrowConversion.convert(data, schema, use_large_var_types=False)
+        actual = ArrowTableToRowsConversion.convert(tbl, schema)
+
+        for a, e in zip(
+            actual,
+            [
+                Row(
+                    i=i if i % 2 == 0 else None,
+                    s=str(i),
+                    ii=i,
+                    b=str(i).encode(),
+                    arr_i=[j if j % 2 == 0 else None for j in range(i)],
+                    arr_ii=list(range(i)),
+                    arr_b=[str(j).encode() for j in range(i)],
+                    map_i={str(j): j if j % 2 == 0 else None for j in range(i)},
+                    map_ii={str(j): j for j in range(i)},
+                    map_b={str(j): str(j).encode() for j in range(i)},
+                    struct_t=Row(i=i if i % 2 == 0 else None, s=str(i), ii=i, b=str(i).encode()),
+                    struct_d=Row(i=i if i % 2 == 0 else None, s=str(i), ii=i, b=str(i).encode()),
+                    udt=ExamplePoint(float(i), float(i)),
+                )
+                for i in range(5)
+            ],
+        ):
+            with self.subTest(expected=e):
+                self.assertEqual(a, e)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_conversion import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
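For context (not part of the patch): a minimal sketch of the round-trip that the new test exercises, using only the `convert` entry points whose signatures appear in the diff above. The two-column schema and sample data here are illustrative, and pyarrow must be installed.

    from pyspark.sql.conversion import ArrowTableToRowsConversion, LocalDataToArrowConversion
    from pyspark.sql.types import IntegerType, StringType, StructType

    # Illustrative schema and local data (not taken from the patch).
    schema = StructType().add("i", IntegerType()).add("s", StringType())
    data = [(1, "a"), (2, "b")]

    # Local rows -> pyarrow.Table, then back to pyspark Rows.
    tbl = LocalDataToArrowConversion.convert(data, schema, use_large_var_types=False)
    rows = ArrowTableToRowsConversion.convert(tbl, schema)
    print(rows)  # expected: [Row(i=1, s='a'), Row(i=2, s='b')]

With the `none_on_identity=True` code path introduced by this change, columns whose types need no conversion (such as the integer and string columns above) skip the per-value identity lambda entirely.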