Commit 9e7737d (1 parent: 5724c71)

[SPARK-52449] Make datatypes for Expression.Literal.Map/Expression.Literal.Array optional

8 files changed: +481 −198 lines

python/pyspark/sql/connect/expressions.py (25 additions, 6 deletions)

@@ -29,6 +29,7 @@
     Optional,
 )
 
+import os
 import json
 import decimal
 import datetime
@@ -84,6 +85,9 @@
 from pyspark.sql.connect.window import WindowSpec
 from pyspark.sql.connect.plan import LogicalPlan
 
+_optional_data_types_for_map_and_array_literals_enabled = (
+    os.getenv("CONNECT_OPTIONAL_DATATYPE_FOR_MAP_AND_ARRAY_LITERALS_ENABLED", "false") == "true"
+)
 
 class Expression:
     """
@@ -426,11 +430,26 @@ def _to_value(
             assert dataType is None or isinstance(dataType, DayTimeIntervalType)
             return DayTimeIntervalType().fromInternal(literal.day_time_interval)
         elif literal.HasField("array"):
-            elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
+            elements = literal.array.elements
+            result = []
             if dataType is not None:
                 assert isinstance(dataType, ArrayType)
-                assert elementType == dataType.elementType
-            return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
+                elementType = dataType.elementType
+            elif literal.array.HasField("element_type"):
+                elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
+            elif len(elements) > 0:
+                result.append(LiteralExpression._to_value(elements[0], None))
+                elements = elements[1:]
+                elementType = LiteralExpression._infer_type(result[0])
+            else:
+                raise PySparkTypeError(
+                    errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
+                    messageParameters={},
+                )
+
+            for element in elements:
+                result.append(LiteralExpression._to_value(element, elementType))
+            return result
 
         raise PySparkTypeError(
             errorClass="UNSUPPORTED_LITERAL",
@@ -475,11 +494,11 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         elif isinstance(self._dataType, DayTimeIntervalType):
             expr.literal.day_time_interval = int(self._value)
         elif isinstance(self._dataType, ArrayType):
-            element_type = self._dataType.elementType
-            expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type))
+            if not _optional_data_types_for_map_and_array_literals_enabled or len(self._value) == 0:
+                expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(self._dataType.elementType))
             for v in self._value:
                 expr.literal.array.elements.append(
-                    LiteralExpression(v, element_type).to_plan(session).literal
+                    LiteralExpression(v, self._dataType.elementType).to_plan(session).literal
                 )
         else:
             raise PySparkTypeError(
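
Taken together, the two hunks above give the client a round trip in which a non-empty array literal travels without an explicit element type. A minimal sketch of that round trip, assuming a pyspark.sql.connect build that includes this change; note that the flag is read once at import time, so the environment variable must be set before the module is imported:

    import os

    # The opt-in flag is evaluated when expressions.py is imported, so set it first.
    os.environ["CONNECT_OPTIONAL_DATATYPE_FOR_MAP_AND_ARRAY_LITERALS_ENABLED"] = "true"

    from pyspark.sql.connect.expressions import LiteralExpression

    lit = LiteralExpression._from_value([1, 2, 3])
    proto_lit = lit.to_plan(None).literal

    # Non-empty array: element_type is omitted on the wire ...
    assert not proto_lit.array.HasField("element_type")
    # ... and _to_value infers it back from the first element.
    assert LiteralExpression._to_value(proto_lit) == [1, 2, 3]

An empty array still carries element_type, since there is no element to infer from.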

python/pyspark/sql/connect/proto/expressions_pb2.py (60 additions, 60 deletions)

Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/expressions_pb2.pyi (45 additions, 6 deletions)

@@ -475,7 +475,10 @@ class Expression(google.protobuf.message.Message):
         ELEMENT_TYPE_FIELD_NUMBER: builtins.int
         ELEMENTS_FIELD_NUMBER: builtins.int
         @property
-        def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+            """(Optional) The element type of the array. Only need to set this when the elements are
+            empty, since Spark 4.1+ supports inferring the element type from the elements.
+            """
         @property
         def elements(
             self,
@@ -489,14 +492,25 @@ class Expression(google.protobuf.message.Message):
             elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
         ) -> None: ...
         def HasField(
-            self, field_name: typing_extensions.Literal["element_type", b"element_type"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_element_type", b"_element_type", "element_type", b"element_type"
+            ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
-                "element_type", b"element_type", "elements", b"elements"
+                "_element_type",
+                b"_element_type",
+                "element_type",
+                b"element_type",
+                "elements",
+                b"elements",
             ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_element_type", b"_element_type"]
+        ) -> typing_extensions.Literal["element_type"] | None: ...
 
     class Map(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -506,9 +520,15 @@ class Expression(google.protobuf.message.Message):
         KEYS_FIELD_NUMBER: builtins.int
         VALUES_FIELD_NUMBER: builtins.int
         @property
-        def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+            """(Optional) The key type of the map. Only need to set this when the keys are
+            empty, since Spark 4.1+ supports inferring the key type from the keys.
+            """
         @property
-        def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+            """(Optional) The value type of the map. Only need to set this when the values are
+            empty, since Spark 4.1+ supports inferring the value type from the values.
+            """
         @property
         def keys(
             self,
@@ -532,12 +552,23 @@ class Expression(google.protobuf.message.Message):
         def HasField(
             self,
             field_name: typing_extensions.Literal[
-                "key_type", b"key_type", "value_type", b"value_type"
+                "_key_type",
+                b"_key_type",
+                "_value_type",
+                b"_value_type",
+                "key_type",
+                b"key_type",
+                "value_type",
+                b"value_type",
             ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_key_type",
+                b"_key_type",
+                "_value_type",
+                b"_value_type",
                 "key_type",
                 b"key_type",
                 "keys",
@@ -548,6 +579,14 @@ class Expression(google.protobuf.message.Message):
                 b"values",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_key_type", b"_key_type"]
+        ) -> typing_extensions.Literal["key_type"] | None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_value_type", b"_value_type"]
+        ) -> typing_extensions.Literal["value_type"] | None: ...
 
     class Struct(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
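
The _element_type, _key_type, and _value_type names in the regenerated stub are not new fields: declaring a proto3 field optional gives it explicit presence through a synthetic one-field oneof, which is what makes HasField and WhichOneof applicable to it. A hedged sketch of the presence check against the regenerated bindings:

    from pyspark.sql.connect import proto

    arr = proto.Expression.Literal.Array()

    # An unset optional field now reports as absent instead of a default value.
    assert not arr.HasField("element_type")
    assert arr.WhichOneof("_element_type") is None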

python/pyspark/sql/tests/connect/test_connect_plan.py (60 additions, 47 deletions)

@@ -19,6 +19,7 @@
 import datetime
 import decimal
 import math
+from contextlib import contextmanager
 
 from pyspark.testing.connectutils import (
     PlanOnlyTestFixture,
@@ -1010,54 +1011,66 @@ def test_literal_expression_with_arrays(self):
         self.assertEqual(len(l4.array.elements[1].array.elements), 2)
         self.assertEqual(len(l4.array.elements[2].array.elements), 0)
 
-    def test_literal_to_any_conversion(self):
-        for value in [
-            b"binary\0\0asas",
-            True,
-            False,
-            0,
-            12,
-            -1,
-            0.0,
-            1.234567,
-            decimal.Decimal(0.0),
-            decimal.Decimal(1.234567),
-            "sss",
-            datetime.date(2022, 12, 13),
-            datetime.datetime.now(),
-            datetime.timedelta(1, 2, 3),
-            [1, 2, 3, 4, 5, 6],
-            [-1.0, 2.0, 3.0],
-            ["x", "y", "z"],
-            [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]],
-        ]:
-            lit = LiteralExpression._from_value(value)
-            proto_lit = lit.to_plan(None).literal
-            value2 = LiteralExpression._to_value(proto_lit)
-            self.assertEqual(value, value2)
-
-        with self.assertRaises(AssertionError):
-            lit = LiteralExpression._from_value(1.234567)
-            proto_lit = lit.to_plan(None).literal
-            LiteralExpression._to_value(proto_lit, StringType())
-
-        with self.assertRaises(AssertionError):
-            lit = LiteralExpression._from_value("1.234567")
-            proto_lit = lit.to_plan(None).literal
-            LiteralExpression._to_value(proto_lit, DoubleType())
-
-        with self.assertRaises(AssertionError):
-            # build a array<string> proto literal, but with incorrect elements
-            proto_lit = proto.Expression().literal
-            proto_lit.array.element_type.CopyFrom(pyspark_types_to_proto_types(StringType()))
-            proto_lit.array.elements.append(
-                LiteralExpression("string", StringType()).to_plan(None).literal
-            )
-            proto_lit.array.elements.append(
-                LiteralExpression(1.234, DoubleType()).to_plan(None).literal
-            )
+    @contextmanager
+    def _optional_data_types_enabled(self, enabled: bool):
+        # Patch the flag on the expressions module itself: rebinding a name
+        # imported into this scope would only change a local variable, and
+        # to_plan() would never see the toggle.
+        import pyspark.sql.connect.expressions as expressions
+        previous_value = expressions._optional_data_types_for_map_and_array_literals_enabled
+        expressions._optional_data_types_for_map_and_array_literals_enabled = enabled
+        try:
+            yield
+        finally:
+            expressions._optional_data_types_for_map_and_array_literals_enabled = previous_value
 
-        LiteralExpression._to_value(proto_lit, DoubleType)
+    def test_literal_to_any_conversion(self):
+        for optional_data_types_enabled in [True, False]:
+            with self._optional_data_types_enabled(optional_data_types_enabled):
+                for value in [
+                    b"binary\0\0asas",
+                    True,
+                    False,
+                    0,
+                    12,
+                    -1,
+                    0.0,
+                    1.234567,
+                    decimal.Decimal(0.0),
+                    decimal.Decimal(1.234567),
+                    "sss",
+                    datetime.date(2022, 12, 13),
+                    datetime.datetime.now(),
+                    datetime.timedelta(1, 2, 3),
+                    [1, 2, 3, 4, 5, 6],
+                    [-1.0, 2.0, 3.0],
+                    ["x", "y", "z"],
+                    [[1.0, 2.0, 3.0], [4.0, 5.0], [6.0]],
+                ]:
+                    lit = LiteralExpression._from_value(value)
+                    proto_lit = lit.to_plan(None).literal
+                    value2 = LiteralExpression._to_value(proto_lit)
+                    self.assertEqual(value, value2)
+
+                with self.assertRaises(AssertionError):
+                    lit = LiteralExpression._from_value(1.234567)
+                    proto_lit = lit.to_plan(None).literal
+                    LiteralExpression._to_value(proto_lit, StringType())
+
+                with self.assertRaises(AssertionError):
+                    lit = LiteralExpression._from_value("1.234567")
+                    proto_lit = lit.to_plan(None).literal
+                    LiteralExpression._to_value(proto_lit, DoubleType())
+
+                with self.assertRaises(AssertionError):
+                    # build an array<string> proto literal, but with incorrect elements
+                    proto_lit = proto.Expression().literal
+                    proto_lit.array.element_type.CopyFrom(pyspark_types_to_proto_types(StringType()))
+                    proto_lit.array.elements.append(
+                        LiteralExpression("string", StringType()).to_plan(None).literal
+                    )
+                    proto_lit.array.elements.append(
+                        LiteralExpression(1.234, DoubleType()).to_plan(None).literal
+                    )
+
+                    LiteralExpression._to_value(proto_lit, DoubleType())
 
 
 if __name__ == "__main__":
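
The hand-rolled context manager keeps the toggle explicit; unittest.mock.patch would be an equivalent, more conventional way to swap the module-level flag and restore it automatically (a sketch, not part of this commit):

    from unittest.mock import patch

    with patch(
        "pyspark.sql.connect.expressions."
        "_optional_data_types_for_map_and_array_literals_enabled",
        True,
    ):
        pass  # code under test sees the flag as enabled here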

sql/connect/common/src/main/protobuf/spark/connect/expressions.proto (9 additions, 3 deletions)

@@ -215,13 +215,19 @@ message Expression {
   }
 
   message Array {
-    DataType element_type = 1;
+    // (Optional) The element type of the array. Only need to set this when the elements are
+    // empty, since Spark 4.1+ supports inferring the element type from the elements.
+    optional DataType element_type = 1;
     repeated Literal elements = 2;
   }
 
   message Map {
-    DataType key_type = 1;
-    DataType value_type = 2;
+    // (Optional) The key type of the map. Only need to set this when the keys are
+    // empty, since Spark 4.1+ supports inferring the key type from the keys.
+    optional DataType key_type = 1;
+    // (Optional) The value type of the map. Only need to set this when the values are
+    // empty, since Spark 4.1+ supports inferring the value type from the values.
+    optional DataType value_type = 2;
     repeated Literal keys = 3;
     repeated Literal values = 4;
   }
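
At the wire level, the relaxed contract means a client targeting a Spark 4.1+ server may omit these types whenever there is at least one element (or key/value pair) to infer from. A hedged sketch using the Python bindings; the server-side inference itself is not part of this commit:

    from pyspark.sql.connect import proto

    m = proto.Expression.Literal.Map()
    m.keys.append(proto.Expression.Literal(string="a"))
    m.values.append(proto.Expression.Literal(integer=1))

    # key_type and value_type are omitted; a Spark 4.1+ server can infer
    # map<string, int> from the first key/value pair.
    assert not m.HasField("key_type")
    assert not m.HasField("value_type")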
