Skip to content

Commit 81af068

Browse files
committed
[SPARK-52449] Make datatypes for Expression.Literal.Map/Expression.Literal.Array optional
1 parent 1f1bacc commit 81af068

File tree

10 files changed

+158
-330
lines changed

10 files changed

+158
-330
lines changed

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 58 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,45 @@ import org.apache.spark.util.SparkClassUtils
3939

4040
object LiteralValueProtoConverter {
4141

42+
private def setArrayTypeAfterAddingElements(
43+
ab: proto.Expression.Literal.Array.Builder,
44+
elementType: DataType,
45+
containsNull: Boolean,
46+
useDeprecatedDataTypeFields: Boolean): Unit = {
47+
if (useDeprecatedDataTypeFields) {
48+
ab.setElementType(toConnectProtoType(elementType))
49+
} else {
50+
val dataTypeBuilder = proto.DataType.Array.newBuilder()
51+
if (ab.getElementsCount == 0 || getInferredDataType(ab.getElements(0)).isEmpty) {
52+
dataTypeBuilder.setElementType(toConnectProtoType(elementType))
53+
}
54+
dataTypeBuilder.setContainsNull(containsNull)
55+
ab.setDataType(dataTypeBuilder.build())
56+
}
57+
}
58+
59+
private def setMapTypeAfterAddingKeysAndValues(
60+
mb: proto.Expression.Literal.Map.Builder,
61+
keyType: DataType,
62+
valueType: DataType,
63+
valueContainsNull: Boolean,
64+
useDeprecatedDataTypeFields: Boolean): Unit = {
65+
if (useDeprecatedDataTypeFields) {
66+
mb.setKeyType(toConnectProtoType(keyType))
67+
mb.setValueType(toConnectProtoType(valueType))
68+
} else {
69+
val dataTypeBuilder = proto.DataType.Map.newBuilder()
70+
if (mb.getKeysCount == 0 || getInferredDataType(mb.getKeys(0)).isEmpty) {
71+
dataTypeBuilder.setKeyType(toConnectProtoType(keyType))
72+
}
73+
if (mb.getValuesCount == 0 || getInferredDataType(mb.getValues(0)).isEmpty) {
74+
dataTypeBuilder.setValueType(toConnectProtoType(valueType))
75+
}
76+
dataTypeBuilder.setValueContainsNull(valueContainsNull)
77+
mb.setDataType(dataTypeBuilder.build())
78+
}
79+
}
80+
4281
@scala.annotation.tailrec
4382
private def toLiteralProtoBuilderInternal(
4483
literal: Any,
@@ -58,17 +97,12 @@ object LiteralValueProtoConverter {
5897

5998
def arrayBuilder(array: Array[_]) = {
6099
val ab = builder.getArrayBuilder
61-
if (options.useDeprecatedDataTypeFields) {
62-
ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
63-
} else {
64-
ab.setDataType(
65-
proto.DataType.Array
66-
.newBuilder()
67-
.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
68-
.setContainsNull(true)
69-
.build())
70-
}
71100
array.foreach(x => ab.addElements(toLiteralProtoWithOptions(x, None, options)))
101+
setArrayTypeAfterAddingElements(
102+
ab,
103+
toDataType(array.getClass.getComponentType),
104+
containsNull = true,
105+
options.useDeprecatedDataTypeFields)
72106
ab
73107
}
74108

@@ -122,16 +156,6 @@ object LiteralValueProtoConverter {
122156

123157
def arrayBuilder(scalaValue: Any, elementType: DataType, containsNull: Boolean) = {
124158
val ab = builder.getArrayBuilder
125-
if (options.useDeprecatedDataTypeFields) {
126-
ab.setElementType(toConnectProtoType(elementType))
127-
} else {
128-
ab.setDataType(
129-
proto.DataType.Array
130-
.newBuilder()
131-
.setElementType(toConnectProtoType(elementType))
132-
.setContainsNull(containsNull)
133-
.build())
134-
}
135159
scalaValue match {
136160
case a: Array[_] =>
137161
a.foreach(item =>
@@ -142,7 +166,11 @@ object LiteralValueProtoConverter {
142166
case other =>
143167
throw new IllegalArgumentException(s"literal $other not supported (yet).")
144168
}
145-
169+
setArrayTypeAfterAddingElements(
170+
ab,
171+
elementType,
172+
containsNull,
173+
options.useDeprecatedDataTypeFields)
146174
ab
147175
}
148176

@@ -152,19 +180,6 @@ object LiteralValueProtoConverter {
152180
valueType: DataType,
153181
valueContainsNull: Boolean) = {
154182
val mb = builder.getMapBuilder
155-
if (options.useDeprecatedDataTypeFields) {
156-
mb.setKeyType(toConnectProtoType(keyType))
157-
mb.setValueType(toConnectProtoType(valueType))
158-
} else {
159-
mb.setDataType(
160-
proto.DataType.Map
161-
.newBuilder()
162-
.setKeyType(toConnectProtoType(keyType))
163-
.setValueType(toConnectProtoType(valueType))
164-
.setValueContainsNull(valueContainsNull)
165-
.build())
166-
}
167-
168183
scalaValue match {
169184
case map: scala.collection.Map[_, _] =>
170185
map.foreach { case (k, v) =>
@@ -174,7 +189,12 @@ object LiteralValueProtoConverter {
174189
case other =>
175190
throw new IllegalArgumentException(s"literal $other not supported (yet).")
176191
}
177-
192+
setMapTypeAfterAddingKeysAndValues(
193+
mb,
194+
keyType,
195+
valueType,
196+
valueContainsNull,
197+
options.useDeprecatedDataTypeFields)
178198
mb
179199
}
180200

@@ -414,6 +434,9 @@ object LiteralValueProtoConverter {
414434
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
415435
toCatalystArray(literal.getArray)
416436

437+
case proto.Expression.Literal.LiteralTypeCase.MAP =>
438+
toCatalystMap(literal.getMap)
439+
417440
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
418441
toCatalystStruct(literal.getStruct)
419442

sql/connect/common/src/test/resources/query-tests/queries/function_lit.json

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,10 +364,6 @@
364364
"integer": 6
365365
}],
366366
"dataType": {
367-
"elementType": {
368-
"integer": {
369-
}
370-
},
371367
"containsNull": true
372368
}
373369
}
-4 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)