Skip to content

Commit fc694d4

Browse files
committed
Fixes and tests
1 parent 1ce2e29 commit fc694d4

File tree

11 files changed

+106
-388
lines changed

11 files changed

+106
-388
lines changed

python/pyspark/sql/connect/expressions.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,57 @@ def _infer_type(cls, value: Any) -> DataType:
377377
def _from_value(cls, value: Any) -> "LiteralExpression":
378378
return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value))
379379

380+
@classmethod
381+
def _infer_type_from_literal(cls, literal: "proto.Expression.Literal") -> DataType:
382+
if literal.HasField("null"):
383+
return NullType()
384+
elif literal.HasField("binary"):
385+
return BinaryType()
386+
elif literal.HasField("boolean"):
387+
return BooleanType()
388+
elif literal.HasField("byte"):
389+
return ByteType()
390+
elif literal.HasField("short"):
391+
return ShortType()
392+
elif literal.HasField("integer"):
393+
return IntegerType()
394+
elif literal.HasField("long"):
395+
return LongType()
396+
elif literal.HasField("float"):
397+
return FloatType()
398+
elif literal.HasField("double"):
399+
return DoubleType()
400+
elif literal.HasField("decimal"):
401+
return DecimalType()
402+
elif literal.HasField("string"):
403+
return StringType()
404+
elif literal.HasField("date"):
405+
return DateType()
406+
elif literal.HasField("timestamp"):
407+
return TimestampType()
408+
elif literal.HasField("timestamp_ntz"):
409+
return TimestampNTZType()
410+
elif literal.HasField("day_time_interval"):
411+
return DayTimeIntervalType()
412+
elif literal.HasField("array"):
413+
if len(literal.array.elements) == 0:
414+
if literal.array.HasField("element_type"):
415+
return ArrayType(
416+
proto_schema_to_pyspark_data_type(literal.array.element_type), True
417+
)
418+
raise PySparkTypeError(
419+
errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
420+
messageParameters={},
421+
)
422+
return ArrayType(
423+
LiteralExpression._infer_type_from_literal(literal.array.elements[0]), True
424+
)
425+
426+
raise PySparkTypeError(
427+
errorClass="UNSUPPORTED_LITERAL",
428+
messageParameters={"literal": str(literal)},
429+
)
430+
380431
@classmethod
381432
def _to_value(
382433
cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
@@ -426,26 +477,19 @@ def _to_value(
426477
assert dataType is None or isinstance(dataType, DayTimeIntervalType)
427478
return DayTimeIntervalType().fromInternal(literal.day_time_interval)
428479
elif literal.HasField("array"):
429-
elements = literal.array.elements
430-
result = []
431-
if dataType is not None:
432-
assert isinstance(dataType, ArrayType)
433-
elementType = dataType.elementType
434-
elif literal.array.HasField("element_type"):
480+
if literal.array.HasField("element_type"):
435481
elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
436-
elif len(elements) > 0:
437-
result.append(LiteralExpression._to_value(elements[0], None))
438-
elements = elements[1:]
439-
elementType = LiteralExpression._infer_type(result[0])
482+
if dataType is not None:
483+
assert isinstance(dataType, ArrayType)
484+
assert elementType == dataType.elementType
485+
elif len(literal.array.elements) > 0:
486+
elementType = LiteralExpression._infer_type_from_literal(literal.array.elements[0])
440487
else:
441488
raise PySparkTypeError(
442489
errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
443490
messageParameters={},
444491
)
445-
446-
for element in elements:
447-
result.append(LiteralExpression._to_value(element, elementType))
448-
return result
492+
return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]
449493

450494
raise PySparkTypeError(
451495
errorClass="UNSUPPORTED_LITERAL",

python/pyspark/sql/tests/connect/test_connect_plan.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -972,28 +972,27 @@ def test_column_expressions(self):
972972

973973
def test_literal_expression_with_arrays(self):
974974
l0 = LiteralExpression._from_value(["x", "y", "z"]).to_plan(None).literal
975-
self.assertTrue(l0.array.element_type.HasField("string"))
975+
self.assertFalse(l0.array.element_type.HasField("string"))
976976
self.assertEqual(len(l0.array.elements), 3)
977977
self.assertEqual(l0.array.elements[0].string, "x")
978978
self.assertEqual(l0.array.elements[1].string, "y")
979979
self.assertEqual(l0.array.elements[2].string, "z")
980980

981981
l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
982-
self.assertTrue(l1.array.element_type.HasField("integer"))
982+
self.assertFalse(l1.array.element_type.HasField("integer"))
983983
self.assertEqual(len(l1.array.elements), 2)
984984
self.assertEqual(l1.array.elements[0].integer, 3)
985985
self.assertEqual(l1.array.elements[1].integer, -3)
986986

987987
l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
988-
self.assertTrue(l2.array.element_type.HasField("double"))
988+
self.assertFalse(l2.array.element_type.HasField("double"))
989989
self.assertEqual(len(l2.array.elements), 3)
990990
self.assertTrue(math.isnan(l2.array.elements[0].double))
991991
self.assertEqual(l2.array.elements[1].double, -3.0)
992992
self.assertEqual(l2.array.elements[2].double, 0.0)
993993

994994
l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
995-
self.assertTrue(l3.array.element_type.HasField("array"))
996-
self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
995+
self.assertFalse(l3.array.element_type.HasField("array"))
997996
self.assertEqual(len(l3.array.elements), 2)
998997
self.assertEqual(len(l3.array.elements[0].array.elements), 2)
999998
self.assertEqual(len(l3.array.elements[1].array.elements), 3)
@@ -1003,8 +1002,7 @@ def test_literal_expression_with_arrays(self):
10031002
.to_plan(None)
10041003
.literal
10051004
)
1006-
self.assertTrue(l4.array.element_type.HasField("array"))
1007-
self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
1005+
self.assertFalse(l4.array.element_type.HasField("array"))
10081006
self.assertEqual(len(l4.array.elements), 3)
10091007
self.assertEqual(len(l4.array.elements[0].array.elements), 2)
10101008
self.assertEqual(len(l4.array.elements[1].array.elements), 2)
@@ -1033,6 +1031,8 @@ def test_literal_to_any_conversion(self):
10331031
]:
10341032
lit = LiteralExpression._from_value(value)
10351033
proto_lit = lit.to_plan(None).literal
1034+
if proto_lit.HasField("array"):
1035+
self.assertFalse(proto_lit.array.HasField("element_type"))
10361036
value2 = LiteralExpression._to_value(proto_lit)
10371037
self.assertEqual(value, value2)
10381038

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,6 @@ object LiteralValueProtoConverter {
138138

139139
def mapBuilder(scalaValue: Any, keyType: DataType, valueType: DataType) = {
140140
val mb = builder.getMapBuilder
141-
.setKeyType(toConnectProtoType(keyType))
142-
.setValueType(toConnectProtoType(valueType))
143141

144142
scalaValue match {
145143
case map: scala.collection.Map[_, _] =>

sql/connect/common/src/test/resources/query-tests/queries/function_lit.json

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,6 @@
358358
}, {
359359
"literal": {
360360
"array": {
361-
"elementType": {
362-
"integer": {
363-
}
364-
},
365361
"elements": [{
366362
"integer": 8
367363
}, {
-4 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)