Commit 0682838

[SPARK-52449] Make datatypes for Expression.Literal.Map/Expression.Literal.Array optional
1 parent e888e37 commit 0682838

File tree

15 files changed, +352 -389 lines

15 files changed

+352
-389
lines changed

python/pyspark/sql/connect/expressions.py

Lines changed: 71 additions & 9 deletions
```diff
@@ -29,9 +29,9 @@
     Optional,
 )
 
-import json
-import decimal
 import datetime
+import decimal
+import json
 import warnings
 from threading import Lock
 
@@ -377,6 +377,52 @@ def _infer_type(cls, value: Any) -> DataType:
     def _from_value(cls, value: Any) -> "LiteralExpression":
         return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))
 
+    @classmethod
+    def _infer_type_from_literal(cls, literal: "proto.Expression.Literal") -> Optional[DataType]:
+        if literal.HasField("null"):
+            return NullType()
+        elif literal.HasField("binary"):
+            return BinaryType()
+        elif literal.HasField("boolean"):
+            return BooleanType()
+        elif literal.HasField("byte"):
+            return ByteType()
+        elif literal.HasField("short"):
+            return ShortType()
+        elif literal.HasField("integer"):
+            return IntegerType()
+        elif literal.HasField("long"):
+            return LongType()
+        elif literal.HasField("float"):
+            return FloatType()
+        elif literal.HasField("double"):
+            return DoubleType()
+        elif literal.HasField("date"):
+            return DateType()
+        elif literal.HasField("timestamp"):
+            return TimestampType()
+        elif literal.HasField("timestamp_ntz"):
+            return TimestampNTZType()
+        elif literal.HasField("array"):
+            if literal.array.HasField("element_type"):
+                return ArrayType(
+                    proto_schema_to_pyspark_data_type(literal.array.element_type), True
+                )
+            element_type = None
+            if len(literal.array.elements) > 0:
+                element_type = LiteralExpression._infer_type_from_literal(literal.array.elements[0])
+
+            if element_type is None:
+                raise PySparkTypeError(
+                    errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
+                    messageParameters={},
+                )
+            return ArrayType(element_type, True)
+        # Not all data types support inferring the data type from the literal at the moment.
+        # e.g. the type of DayTimeInterval contains extra information like start_field and
+        # end_field and cannot be inferred from the literal.
+        return None
+
     @classmethod
     def _to_value(
         cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
@@ -426,10 +472,20 @@ def _to_value(
             assert dataType is None or isinstance(dataType, DayTimeIntervalType)
             return DayTimeIntervalType().fromInternal(literal.day_time_interval)
         elif literal.HasField("array"):
-            elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
-            if dataType is not None:
-                assert isinstance(dataType, ArrayType)
-                assert elementType == dataType.elementType
+            elementType = None
+            if literal.array.HasField("element_type"):
+                elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
+                if dataType is not None:
+                    assert isinstance(dataType, ArrayType)
+                    assert elementType == dataType.elementType
+            elif len(literal.array.elements) > 0:
+                elementType = LiteralExpression._infer_type_from_literal(literal.array.elements[0])
+
+            if elementType is None:
+                raise PySparkTypeError(
+                    errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
+                    messageParameters={},
+                )
             return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
 
         raise PySparkTypeError(
@@ -475,11 +531,17 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         elif isinstance(self._dataType, DayTimeIntervalType):
             expr.literal.day_time_interval = int(self._value)
         elif isinstance(self._dataType, ArrayType):
-            element_type = self._dataType.elementType
-            expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type))
             for v in self._value:
                 expr.literal.array.elements.append(
-                    LiteralExpression(v, element_type).to_plan(session).literal
+                    LiteralExpression(v, self._dataType.elementType).to_plan(session).literal
+                )
+            if (
+                len(self._value) == 0
+                or LiteralExpression._infer_type_from_literal(expr.literal.array.elements[0])
+                is None
+            ):
+                expr.literal.array.element_type.CopyFrom(
+                    pyspark_types_to_proto_types(self._dataType.elementType)
                 )
         else:
             raise PySparkTypeError(
```
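Taken together, these changes mean the client now omits `element_type` whenever the receiver can re-derive it from the first element, and only falls back to the explicit type for empty or non-inferable arrays. A minimal sketch of the round trip, assuming a PySpark build that includes this commit (these are internal Spark Connect APIs, not a stable surface):

```python
# Sketch only: exercises internal Spark Connect APIs from this commit.
from pyspark.sql.connect.expressions import LiteralExpression

lit = LiteralExpression._from_value([3, -3]).to_plan(None).literal

# No explicit element type on the wire anymore...
assert not lit.array.HasField("element_type")
# ...the elements themselves carry enough to reconstruct it.
assert [e.integer for e in lit.array.elements] == [3, -3]

# Round-tripping still recovers the Python value via _infer_type_from_literal.
assert LiteralExpression._to_value(lit) == [3, -3]
```

Inferring from the first element is sound here because array literals are homogeneous; anything `_infer_type_from_literal` declines to type falls through to the explicit `element_type` path in `to_plan`.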

python/pyspark/sql/connect/proto/expressions_pb2.pyi

Lines changed: 15 additions & 3 deletions
```diff
@@ -475,7 +475,11 @@ class Expression(google.protobuf.message.Message):
         ELEMENT_TYPE_FIELD_NUMBER: builtins.int
         ELEMENTS_FIELD_NUMBER: builtins.int
         @property
-        def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+            """(Optional) The element type of the array. Only needs to be set when the elements
+            are empty or the element type is not inferable, since Spark 4.1+ supports
+            inferring the element type from the elements.
+            """
         @property
         def elements(
             self,
@@ -506,9 +510,17 @@ class Expression(google.protobuf.message.Message):
         KEYS_FIELD_NUMBER: builtins.int
         VALUES_FIELD_NUMBER: builtins.int
         @property
-        def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+            """(Optional) The key type of the map. Only needs to be set when the keys
+            are empty or the key type is not inferable, since Spark 4.1+ supports
+            inferring the key type from the keys.
+            """
         @property
-        def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+            """(Optional) The value type of the map. Only needs to be set when the values
+            are empty or the value type is not inferable, since Spark 4.1+ supports
+            inferring the value type from the values.
+            """
         @property
         def keys(
             self,
```
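The stub change is documentation-only, but it describes real wire behavior: a reader must now tolerate an absent `element_type`. A hedged sketch of that reader-side path, building the proto by hand (same dev-build assumption as above; `proto` is the `pyspark.sql.connect.proto` package):

```python
# Sketch: mimic what a Spark 4.1+ client may now send over the wire.
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.expressions import LiteralExpression

lit = proto.Expression.Literal()
lit.array.elements.add().integer = 1
lit.array.elements.add().integer = 2

# element_type was never set...
assert not lit.array.HasField("element_type")
# ...yet _to_value can still decode the array by inferring IntegerType
# from the first element instead of requiring element_type.
assert LiteralExpression._to_value(lit) == [1, 2]
```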

python/pyspark/sql/tests/connect/test_connect_plan.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -979,21 +979,20 @@ def test_literal_expression_with_arrays(self):
         self.assertEqual(l0.array.elements[2].string, "z")
 
         l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
-        self.assertTrue(l1.array.element_type.HasField("integer"))
+        self.assertFalse(l1.array.element_type.HasField("integer"))
         self.assertEqual(len(l1.array.elements), 2)
         self.assertEqual(l1.array.elements[0].integer, 3)
         self.assertEqual(l1.array.elements[1].integer, -3)
 
         l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
-        self.assertTrue(l2.array.element_type.HasField("double"))
+        self.assertFalse(l2.array.element_type.HasField("double"))
         self.assertEqual(len(l2.array.elements), 3)
         self.assertTrue(math.isnan(l2.array.elements[0].double))
         self.assertEqual(l2.array.elements[1].double, -3.0)
         self.assertEqual(l2.array.elements[2].double, 0.0)
 
         l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
-        self.assertTrue(l3.array.element_type.HasField("array"))
-        self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
+        self.assertFalse(l3.array.element_type.HasField("array"))
         self.assertEqual(len(l3.array.elements), 2)
         self.assertEqual(len(l3.array.elements[0].array.elements), 2)
         self.assertEqual(len(l3.array.elements[1].array.elements), 3)
@@ -1003,8 +1002,7 @@ def test_literal_expression_with_arrays(self):
             .to_plan(None)
             .literal
         )
-        self.assertTrue(l4.array.element_type.HasField("array"))
-        self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
+        self.assertFalse(l4.array.element_type.HasField("array"))
         self.assertEqual(len(l4.array.elements), 3)
         self.assertEqual(len(l4.array.elements[0].array.elements), 2)
         self.assertEqual(len(l4.array.elements[1].array.elements), 2)
```
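The updated tests cover the inferable cases, where `element_type` is now left unset. The empty-array fallback is the complement: with no element to infer from, `to_plan` must still serialize the type explicitly. A small sketch of that branch, under the same dev-build assumptions:

```python
# Sketch: the empty-array fallback in to_plan.
from pyspark.sql.connect.expressions import LiteralExpression
from pyspark.sql.types import ArrayType, IntegerType

empty = LiteralExpression([], ArrayType(IntegerType())).to_plan(None).literal
assert len(empty.array.elements) == 0

# len(self._value) == 0, so to_plan writes the explicit element type.
assert empty.array.HasField("element_type")
assert empty.array.element_type.HasField("integer")
```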

sql/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 9 additions & 0 deletions
```diff
@@ -215,12 +215,21 @@ message Expression {
   }
 
   message Array {
+    // (Optional) The element type of the array. Only needs to be set when the elements
+    // are empty or the element type is not inferable, since Spark 4.1+ supports
+    // inferring the element type from the elements.
     DataType element_type = 1;
     repeated Literal elements = 2;
   }
 
   message Map {
+    // (Optional) The key type of the map. Only needs to be set when the keys
+    // are empty or the key type is not inferable, since Spark 4.1+ supports
+    // inferring the key type from the keys.
     DataType key_type = 1;
+    // (Optional) The value type of the map. Only needs to be set when the values
+    // are empty or the value type is not inferable, since Spark 4.1+ supports
+    // inferring the value type from the values.
     DataType value_type = 2;
     repeated Literal keys = 3;
     repeated Literal values = 4;
```