Skip to content

Commit 7e5b478

Browse files
committed
Fixes and tests
1 parent 1ce2e29 commit 7e5b478

File tree

3 files changed

+36
-10
lines changed

3 files changed

+36
-10
lines changed

python/pyspark/sql/tests/connect/test_connect_plan.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -972,28 +972,27 @@ def test_column_expressions(self):
972972

973973
def test_literal_expression_with_arrays(self):
974974
l0 = LiteralExpression._from_value(["x", "y", "z"]).to_plan(None).literal
975-
self.assertTrue(l0.array.element_type.HasField("string"))
975+
self.assertFalse(l0.array.element_type.HasField("string"))
976976
self.assertEqual(len(l0.array.elements), 3)
977977
self.assertEqual(l0.array.elements[0].string, "x")
978978
self.assertEqual(l0.array.elements[1].string, "y")
979979
self.assertEqual(l0.array.elements[2].string, "z")
980980

981981
l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
982-
self.assertTrue(l1.array.element_type.HasField("integer"))
982+
self.assertFalse(l1.array.element_type.HasField("integer"))
983983
self.assertEqual(len(l1.array.elements), 2)
984984
self.assertEqual(l1.array.elements[0].integer, 3)
985985
self.assertEqual(l1.array.elements[1].integer, -3)
986986

987987
l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
988-
self.assertTrue(l2.array.element_type.HasField("double"))
988+
self.assertFalse(l2.array.element_type.HasField("double"))
989989
self.assertEqual(len(l2.array.elements), 3)
990990
self.assertTrue(math.isnan(l2.array.elements[0].double))
991991
self.assertEqual(l2.array.elements[1].double, -3.0)
992992
self.assertEqual(l2.array.elements[2].double, 0.0)
993993

994994
l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
995-
self.assertTrue(l3.array.element_type.HasField("array"))
996-
self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
995+
self.assertFalse(l3.array.element_type.HasField("array"))
997996
self.assertEqual(len(l3.array.elements), 2)
998997
self.assertEqual(len(l3.array.elements[0].array.elements), 2)
999998
self.assertEqual(len(l3.array.elements[1].array.elements), 3)
@@ -1003,8 +1002,7 @@ def test_literal_expression_with_arrays(self):
10031002
.to_plan(None)
10041003
.literal
10051004
)
1006-
self.assertTrue(l4.array.element_type.HasField("array"))
1007-
self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
1005+
self.assertFalse(l4.array.element_type.HasField("array"))
10081006
self.assertEqual(len(l4.array.elements), 3)
10091007
self.assertEqual(len(l4.array.elements[0].array.elements), 2)
10101008
self.assertEqual(len(l4.array.elements[1].array.elements), 2)
@@ -1033,6 +1031,8 @@ def test_literal_to_any_conversion(self):
10331031
]:
10341032
lit = LiteralExpression._from_value(value)
10351033
proto_lit = lit.to_plan(None).literal
1034+
if proto_lit.HasField("array"):
1035+
self.assertFalse(proto_lit.array.HasField("element_type"))
10361036
value2 = LiteralExpression._to_value(proto_lit)
10371037
self.assertEqual(value, value2)
10381038

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,6 @@ object LiteralValueProtoConverter {
138138

139139
def mapBuilder(scalaValue: Any, keyType: DataType, valueType: DataType) = {
140140
val mb = builder.getMapBuilder
141-
.setKeyType(toConnectProtoType(keyType))
142-
.setValueType(toConnectProtoType(valueType))
143141

144142
scalaValue match {
145143
case map: scala.collection.Map[_, _] =>

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,35 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i
5858
}
5959
}
6060

61-
test("invalid array literal - empty array") {
61+
test("element type of array literal is set for empty array") {
62+
val literalProto =
63+
LiteralValueProtoConverter.toLiteralProto(Array[Int](), ArrayType(IntegerType))
64+
assert(literalProto.getArray.hasElementType)
65+
}
66+
67+
test("element type of array literal is not set for non-empty array") {
68+
val literalProto =
69+
LiteralValueProtoConverter.toLiteralProto(Array(1, 2, 3), ArrayType(IntegerType))
70+
assert(!literalProto.getArray.hasElementType)
71+
}
72+
73+
test("key and value type of map literal is set for empty map") {
74+
val literalProto = LiteralValueProtoConverter.toLiteralProto(
75+
Map[Int, Int](),
76+
MapType(IntegerType, IntegerType))
77+
assert(literalProto.getMap.hasKeyType)
78+
assert(literalProto.getMap.hasValueType)
79+
}
80+
81+
test("key and value type of map literal is not set for non-empty map") {
82+
val literalProto = LiteralValueProtoConverter.toLiteralProto(
83+
Map(1 -> 2, 3 -> 4, 5 -> 6),
84+
MapType(IntegerType, IntegerType))
85+
assert(!literalProto.getMap.hasKeyType)
86+
assert(!literalProto.getMap.hasValueType)
87+
}
88+
89+
test("invalid array literal") {
6290
val literalProto = proto.Expression.Literal
6391
.newBuilder()
6492
.setArray(proto.Expression.Literal.Array.newBuilder())

0 commit comments

Comments
 (0)