@@ -39,6 +39,45 @@ import org.apache.spark.util.SparkClassUtils
39
39
40
40
object LiteralValueProtoConverter {
41
41
42
+ private def setArrayTypeAfterAddingElements (
43
+ ab : proto.Expression .Literal .Array .Builder ,
44
+ elementType : DataType ,
45
+ containsNull : Boolean ,
46
+ useDeprecatedDataTypeFields : Boolean ): Unit = {
47
+ if (useDeprecatedDataTypeFields) {
48
+ ab.setElementType(toConnectProtoType(elementType))
49
+ } else {
50
+ val dataTypeBuilder = proto.DataType .Array .newBuilder()
51
+ if (ab.getElementsCount == 0 || getInferredDataType(ab.getElements(0 )).isEmpty) {
52
+ dataTypeBuilder.setElementType(toConnectProtoType(elementType))
53
+ }
54
+ dataTypeBuilder.setContainsNull(containsNull)
55
+ ab.setDataType(dataTypeBuilder.build())
56
+ }
57
+ }
58
+
59
+ private def setMapTypeAfterAddingKeysAndValues (
60
+ mb : proto.Expression .Literal .Map .Builder ,
61
+ keyType : DataType ,
62
+ valueType : DataType ,
63
+ valueContainsNull : Boolean ,
64
+ useDeprecatedDataTypeFields : Boolean ): Unit = {
65
+ if (useDeprecatedDataTypeFields) {
66
+ mb.setKeyType(toConnectProtoType(keyType))
67
+ mb.setValueType(toConnectProtoType(valueType))
68
+ } else {
69
+ val dataTypeBuilder = proto.DataType .Map .newBuilder()
70
+ if (mb.getKeysCount == 0 || getInferredDataType(mb.getKeys(0 )).isEmpty) {
71
+ dataTypeBuilder.setKeyType(toConnectProtoType(keyType))
72
+ }
73
+ if (mb.getValuesCount == 0 || getInferredDataType(mb.getValues(0 )).isEmpty) {
74
+ dataTypeBuilder.setValueType(toConnectProtoType(valueType))
75
+ }
76
+ dataTypeBuilder.setValueContainsNull(valueContainsNull)
77
+ mb.setDataType(dataTypeBuilder.build())
78
+ }
79
+ }
80
+
42
81
@ scala.annotation.tailrec
43
82
private def toLiteralProtoBuilderInternal (
44
83
literal : Any ,
@@ -58,17 +97,12 @@ object LiteralValueProtoConverter {
58
97
59
98
def arrayBuilder (array : Array [_]) = {
60
99
val ab = builder.getArrayBuilder
61
- if (options.useDeprecatedDataTypeFields) {
62
- ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
63
- } else {
64
- ab.setDataType(
65
- proto.DataType .Array
66
- .newBuilder()
67
- .setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType)))
68
- .setContainsNull(true )
69
- .build())
70
- }
71
100
array.foreach(x => ab.addElements(toLiteralProtoWithOptions(x, None , options)))
101
+ setArrayTypeAfterAddingElements(
102
+ ab,
103
+ toDataType(array.getClass.getComponentType),
104
+ containsNull = true ,
105
+ options.useDeprecatedDataTypeFields)
72
106
ab
73
107
}
74
108
@@ -122,16 +156,6 @@ object LiteralValueProtoConverter {
122
156
123
157
def arrayBuilder (scalaValue : Any , elementType : DataType , containsNull : Boolean ) = {
124
158
val ab = builder.getArrayBuilder
125
- if (options.useDeprecatedDataTypeFields) {
126
- ab.setElementType(toConnectProtoType(elementType))
127
- } else {
128
- ab.setDataType(
129
- proto.DataType .Array
130
- .newBuilder()
131
- .setElementType(toConnectProtoType(elementType))
132
- .setContainsNull(containsNull)
133
- .build())
134
- }
135
159
scalaValue match {
136
160
case a : Array [_] =>
137
161
a.foreach(item =>
@@ -142,7 +166,11 @@ object LiteralValueProtoConverter {
142
166
case other =>
143
167
throw new IllegalArgumentException (s " literal $other not supported (yet). " )
144
168
}
145
-
169
+ setArrayTypeAfterAddingElements(
170
+ ab,
171
+ elementType,
172
+ containsNull,
173
+ options.useDeprecatedDataTypeFields)
146
174
ab
147
175
}
148
176
@@ -152,19 +180,6 @@ object LiteralValueProtoConverter {
152
180
valueType : DataType ,
153
181
valueContainsNull : Boolean ) = {
154
182
val mb = builder.getMapBuilder
155
- if (options.useDeprecatedDataTypeFields) {
156
- mb.setKeyType(toConnectProtoType(keyType))
157
- mb.setValueType(toConnectProtoType(valueType))
158
- } else {
159
- mb.setDataType(
160
- proto.DataType .Map
161
- .newBuilder()
162
- .setKeyType(toConnectProtoType(keyType))
163
- .setValueType(toConnectProtoType(valueType))
164
- .setValueContainsNull(valueContainsNull)
165
- .build())
166
- }
167
-
168
183
scalaValue match {
169
184
case map : scala.collection.Map [_, _] =>
170
185
map.foreach { case (k, v) =>
@@ -174,7 +189,12 @@ object LiteralValueProtoConverter {
174
189
case other =>
175
190
throw new IllegalArgumentException (s " literal $other not supported (yet). " )
176
191
}
177
-
192
+ setMapTypeAfterAddingKeysAndValues(
193
+ mb,
194
+ keyType,
195
+ valueType,
196
+ valueContainsNull,
197
+ options.useDeprecatedDataTypeFields)
178
198
mb
179
199
}
180
200
@@ -414,6 +434,9 @@ object LiteralValueProtoConverter {
414
434
case proto.Expression .Literal .LiteralTypeCase .ARRAY =>
415
435
toCatalystArray(literal.getArray)
416
436
437
+ case proto.Expression .Literal .LiteralTypeCase .MAP =>
438
+ toCatalystMap(literal.getMap)
439
+
417
440
case proto.Expression .Literal .LiteralTypeCase .STRUCT =>
418
441
toCatalystStruct(literal.getStruct)
419
442
0 commit comments