Skip to content

Commit 9ee011c

Browse files
committed
[SPARK-52449] Make datatypes for Expression.Literal.Map/Expression.Literal.Array optional
1 parent d88298a commit 9ee011c

File tree

3 files changed

+215
-77
lines changed

3 files changed

+215
-77
lines changed

sql/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,19 @@ message Expression {
214214
}
215215

216216
message Array {
217-
DataType element_type = 1;
217+
// (Optional) The element type of the array. Only need to set this when the elements is
218+
// empty since now we support infer the element type from the elements.
219+
optional DataType element_type = 1;
218220
repeated Literal elements = 2;
219221
}
220222

221223
message Map {
222-
DataType key_type = 1;
223-
DataType value_type = 2;
224+
// (Optional) The key type of the map. Only need to set this when the keys is
225+
// empty since now we support infer the key type from the keys.
226+
optional DataType key_type = 1;
227+
// (Optional) The value type of the map. Only need to set this when the values is
228+
// empty since now we support infer the value type from the values.
229+
optional DataType value_type = 2;
224230
repeated Literal keys = 3;
225231
repeated Literal values = 4;
226232
}

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 197 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import java.time._
2424

2525
import scala.collection.{immutable, mutable}
2626
import scala.jdk.CollectionConverters._
27-
import scala.reflect.ClassTag
27+
import scala.language.existentials
2828
import scala.reflect.runtime.universe.TypeTag
2929
import scala.util.Try
3030

@@ -288,92 +288,216 @@ object LiteralValueProtoConverter {
288288
SparkIntervalUtils.microsToDuration(literal.getDayTimeInterval)
289289

290290
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
291-
toCatalystArray(literal.getArray)
291+
toCatalystArray(literal.getArray)._1
292292

293293
case other =>
294294
throw new UnsupportedOperationException(
295295
s"Unsupported Literal Type: ${other.getNumber} (${other.name})")
296296
}
297297
}
298298

299-
private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
300-
if (dataType.hasShort) { v =>
301-
v.getShort.toShort
302-
} else if (dataType.hasInteger) { v =>
303-
v.getInteger
304-
} else if (dataType.hasLong) { v =>
305-
v.getLong
306-
} else if (dataType.hasDouble) { v =>
307-
v.getDouble
308-
} else if (dataType.hasByte) { v =>
309-
v.getByte.toByte
310-
} else if (dataType.hasFloat) { v =>
311-
v.getFloat
312-
} else if (dataType.hasBoolean) { v =>
313-
v.getBoolean
314-
} else if (dataType.hasString) { v =>
315-
v.getString
316-
} else if (dataType.hasBinary) { v =>
317-
v.getBinary.toByteArray
318-
} else if (dataType.hasDate) { v =>
319-
v.getDate
320-
} else if (dataType.hasTimestamp) { v =>
321-
v.getTimestamp
322-
} else if (dataType.hasTimestampNtz) { v =>
323-
v.getTimestampNtz
324-
} else if (dataType.hasDayTimeInterval) { v =>
325-
v.getDayTimeInterval
326-
} else if (dataType.hasYearMonthInterval) { v =>
327-
v.getYearMonthInterval
328-
} else if (dataType.hasDecimal) { v =>
329-
Decimal(v.getDecimal.getValue)
330-
} else if (dataType.hasCalendarInterval) { v =>
331-
val interval = v.getCalendarInterval
332-
new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
333-
} else if (dataType.hasArray) { v =>
334-
toCatalystArray(v.getArray)
335-
} else if (dataType.hasMap) { v =>
336-
toCatalystMap(v.getMap)
337-
} else if (dataType.hasStruct) { v =>
338-
toCatalystStruct(v.getStruct)
339-
} else {
340-
throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
299+
private def getConverter(
300+
dataType: proto.DataType,
301+
inferDataType: Boolean = false): proto.Expression.Literal => Any = {
302+
dataType.getKindCase match {
303+
case proto.DataType.KindCase.SHORT => v =>
304+
v.getShort.toShort
305+
case proto.DataType.KindCase.INTEGER => v =>
306+
v.getInteger
307+
case proto.DataType.KindCase.LONG => v =>
308+
v.getLong
309+
case proto.DataType.KindCase.DOUBLE => v =>
310+
v.getDouble
311+
case proto.DataType.KindCase.BYTE => v =>
312+
v.getByte.toByte
313+
case proto.DataType.KindCase.FLOAT => v =>
314+
v.getFloat
315+
case proto.DataType.KindCase.BOOLEAN => v =>
316+
v.getBoolean
317+
case proto.DataType.KindCase.STRING => v =>
318+
v.getString
319+
case proto.DataType.KindCase.BINARY => v =>
320+
v.getBinary.toByteArray
321+
case proto.DataType.KindCase.DATE => v =>
322+
v.getDate
323+
case proto.DataType.KindCase.TIMESTAMP => v =>
324+
v.getTimestamp
325+
case proto.DataType.KindCase.TIMESTAMP_NTZ => v =>
326+
v.getTimestampNtz
327+
case proto.DataType.KindCase.DAY_TIME_INTERVAL => v =>
328+
v.getDayTimeInterval
329+
case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => v =>
330+
v.getYearMonthInterval
331+
case proto.DataType.KindCase.DECIMAL => v =>
332+
Decimal(v.getDecimal.getValue)
333+
case proto.DataType.KindCase.CALENDAR_INTERVAL => v =>
334+
val interval = v.getCalendarInterval
335+
new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
336+
case proto.DataType.KindCase.ARRAY =>
337+
if (inferDataType) {
338+
v => {
339+
val (array, arrayType) = toCatalystArray(v.getArray, None)
340+
LiteralValueWithDataType(array, proto.DataType.newBuilder.setArray(arrayType).build())
341+
}
342+
} else {
343+
v => toCatalystArray(v.getArray, Some(dataType.getArray))._1
344+
}
345+
case proto.DataType.KindCase.MAP =>
346+
if (inferDataType) {
347+
v => {
348+
val (map, mapType) = toCatalystMap(v.getMap, None)
349+
LiteralValueWithDataType(map, proto.DataType.newBuilder.setMap(mapType).build())
350+
}
351+
} else {
352+
v => toCatalystMap(v.getMap, Some(dataType.getMap))._1
353+
}
354+
case proto.DataType.KindCase.STRUCT => v =>
355+
toCatalystStruct(v.getStruct)
356+
case _ =>
357+
throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
341358
}
342359
}
343360

344-
def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = {
345-
def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit
346-
tag: ClassTag[T]): Array[T] = {
347-
val builder = mutable.ArrayBuilder.make[T]
348-
val elementList = array.getElementsList
349-
builder.sizeHint(elementList.size())
350-
val iter = elementList.iterator()
351-
while (iter.hasNext) {
352-
builder += converter(iter.next())
361+
private def getBasicType(literal: proto.Expression.Literal): proto.DataType = {
362+
val builder = proto.DataType.newBuilder()
363+
literal.getLiteralTypeCase match {
364+
case proto.Expression.Literal.LiteralTypeCase.BINARY =>
365+
builder.setBinary(proto.DataType.Binary.newBuilder.build())
366+
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
367+
builder.setBoolean(proto.DataType.Boolean.newBuilder.build())
368+
case proto.Expression.Literal.LiteralTypeCase.BYTE =>
369+
builder.setByte(proto.DataType.Byte.newBuilder.build())
370+
case proto.Expression.Literal.LiteralTypeCase.SHORT =>
371+
builder.setShort(proto.DataType.Short.newBuilder.build())
372+
case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
373+
builder.setInteger(proto.DataType.Integer.newBuilder.build())
374+
case proto.Expression.Literal.LiteralTypeCase.LONG =>
375+
builder.setLong(proto.DataType.Long.newBuilder.build())
376+
case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
377+
builder.setFloat(proto.DataType.Float.newBuilder.build())
378+
case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
379+
builder.setDouble(proto.DataType.Double.newBuilder.build())
380+
case proto.Expression.Literal.LiteralTypeCase.STRING =>
381+
builder.setString(proto.DataType.String.newBuilder.build())
382+
case proto.Expression.Literal.LiteralTypeCase.DATE =>
383+
builder.setDate(proto.DataType.Date.newBuilder.build())
384+
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
385+
builder.setTimestamp(proto.DataType.Timestamp.newBuilder.build())
386+
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
387+
builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder.build())
388+
case proto.Expression.Literal.LiteralTypeCase.YEAR_MONTH_INTERVAL =>
389+
builder.setYearMonthInterval(proto.DataType.YearMonthInterval.newBuilder.build())
390+
case proto.Expression.Literal.LiteralTypeCase.DECIMAL =>
391+
builder.setDecimal(proto.DataType.Decimal.newBuilder.build())
392+
case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
393+
builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build())
394+
case proto.Expression.Literal.LiteralTypeCase.DAY_TIME_INTERVAL =>
395+
builder.setDayTimeInterval(proto.DataType.DayTimeInterval.newBuilder.build())
396+
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
397+
// Element type will be inferred from the elements in the array.
398+
builder.setArray(proto.DataType.Array.newBuilder.build())
399+
case proto.Expression.Literal.LiteralTypeCase.MAP =>
400+
// Key and value types will be inferred from the keys and values in the map.
401+
builder.setMap(proto.DataType.Map.newBuilder.build())
402+
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
403+
builder.setStruct(literal.getStruct.getStructType.getStruct)
404+
case _ => throw InvalidPlanInput(s"Unsupported Literal Type: ${literal.getLiteralTypeCase}")
405+
}
406+
builder.build()
407+
}
408+
409+
def toCatalystArray(
410+
array: proto.Expression.Literal.Array,
411+
arrayTypeOpt: Option[proto.DataType.Array] = None): (Array[_], proto.DataType.Array) = {
412+
def protoArrayType(elementType: proto.DataType): proto.DataType.Array = {
413+
proto.DataType.Array.newBuilder().setElementType(elementType).build()
414+
}
415+
416+
val builder = mutable.ArrayBuilder.make[Any]
417+
builder.sizeHint(array.getElementsList.size())
418+
419+
val iter = array.getElementsList.iterator()
420+
421+
def inferDataType(): proto.DataType.Array = {
422+
if (arrayTypeOpt.isDefined) {
423+
arrayTypeOpt.get
424+
} else if (array.hasElementType) {
425+
protoArrayType(array.getElementType)
426+
} else if (iter.hasNext) {
427+
val firstElement = iter.next()
428+
val basicElementType = getBasicType(firstElement)
429+
val (elem, inferredElementType) =
430+
getConverter(basicElementType, inferDataType = true)(firstElement) match {
431+
case LiteralValueWithDataType(elem, dataType) => (elem, dataType)
432+
case elem => (elem, basicElementType)
433+
}
434+
builder += elem
435+
protoArrayType(inferredElementType)
436+
} else {
437+
throw InvalidPlanInput("Cannot infer element type for an empty array")
353438
}
354-
builder.result()
355439
}
356440

357-
makeArrayData(getConverter(array.getElementType))
441+
val dataType = inferDataType()
442+
val converter = getConverter(dataType.getElementType)
443+
444+
while (iter.hasNext) {
445+
builder += converter(iter.next())
446+
}
447+
builder.result()
448+
449+
(builder.result(), dataType)
358450
}
359451

360-
def toCatalystMap(map: proto.Expression.Literal.Map): mutable.Map[_, _] = {
361-
def makeMapData[K, V](
362-
keyConverter: proto.Expression.Literal => K,
363-
valueConverter: proto.Expression.Literal => V)(implicit
364-
tagK: ClassTag[K],
365-
tagV: ClassTag[V]): mutable.Map[K, V] = {
366-
val builder = mutable.HashMap.empty[K, V]
367-
val keys = map.getKeysList.asScala
368-
val values = map.getValuesList.asScala
369-
builder.sizeHint(keys.size)
370-
keys.zip(values).foreach { case (key, value) =>
371-
builder += ((keyConverter(key), valueConverter(value)))
452+
def toCatalystMap(
453+
map: proto.Expression.Literal.Map,
454+
mapTypeOpt: Option[proto.DataType.Map] = None): (mutable.Map[_, _], proto.DataType.Map) = {
455+
def protoMapType(keyType: proto.DataType, valueType: proto.DataType): proto.DataType.Map = {
456+
proto.DataType.Map.newBuilder().setKeyType(keyType).setValueType(valueType).build()
457+
}
458+
val builder = mutable.HashMap.newBuilder[Any, Any]
459+
val keyValuePairs = map.getKeysList.asScala.zip(map.getValuesList.asScala)
460+
builder.sizeHint(keyValuePairs.size)
461+
462+
val iter = keyValuePairs.iterator
463+
464+
def inferDataType(): proto.DataType.Map = {
465+
if (mapTypeOpt.isDefined) {
466+
mapTypeOpt.get
467+
} else if (map.hasKeyType && map.hasValueType) {
468+
protoMapType(map.getKeyType, map.getValueType)
469+
} else if (iter.hasNext) {
470+
val (key, value) = iter.next()
471+
val basicKeyType = getBasicType(key)
472+
val (convertedKey, inferredKeyType) =
473+
getConverter(basicKeyType, inferDataType = true)(key) match {
474+
case LiteralValueWithDataType(convertedKey, dataType) => (convertedKey, dataType)
475+
case convertedKey => (convertedKey, basicKeyType)
476+
}
477+
val basicValueType = getBasicType(value)
478+
val (convertedValue, inferredValueType) =
479+
getConverter(basicValueType, inferDataType = true)(value) match {
480+
case LiteralValueWithDataType(convertedValue, dataType) => (convertedValue, dataType)
481+
case convertedValue => (convertedValue, basicValueType)
482+
}
483+
builder += ((convertedKey, convertedValue))
484+
protoMapType(inferredKeyType, inferredValueType)
485+
} else {
486+
throw InvalidPlanInput("Cannot infer key and value type for an empty map")
372487
}
373-
builder
374488
}
375489

376-
makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
490+
val dataType = inferDataType()
491+
val keyConverter = getConverter(dataType.getKeyType)
492+
val valueConverter = getConverter(dataType.getValueType)
493+
494+
while (iter.hasNext) {
495+
val (key, value) = iter.next()
496+
builder += ((keyConverter(key), valueConverter(value)))
497+
}
498+
builder.result()
499+
500+
(builder.result(), dataType)
377501
}
378502

379503
def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
@@ -399,4 +523,8 @@ object LiteralValueProtoConverter {
399523

400524
toTuple(structData)
401525
}
526+
527+
private case class LiteralValueWithDataType(
528+
value: Any,
529+
dataType: proto.DataType)
402530
}

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.connect.planner
1919

20+
import scala.language.existentials
21+
2022
import org.apache.spark.connect.proto
2123
import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters}
2224
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter}
@@ -98,16 +100,18 @@ object LiteralExpressionProtoConverter {
98100
expressions.Literal(lit.getDayTimeInterval, DayTimeIntervalType())
99101

100102
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
103+
val (array, arrayType) = LiteralValueProtoConverter.toCatalystArray(lit.getArray)
101104
expressions.Literal.create(
102-
LiteralValueProtoConverter.toCatalystArray(lit.getArray),
103-
ArrayType(DataTypeProtoConverter.toCatalystType(lit.getArray.getElementType)))
105+
array,
106+
ArrayType(DataTypeProtoConverter.toCatalystType(arrayType.getElementType)))
104107

105108
case proto.Expression.Literal.LiteralTypeCase.MAP =>
109+
val (map, mapType) = LiteralValueProtoConverter.toCatalystMap(lit.getMap)
106110
expressions.Literal.create(
107-
LiteralValueProtoConverter.toCatalystMap(lit.getMap),
111+
map,
108112
MapType(
109-
DataTypeProtoConverter.toCatalystType(lit.getMap.getKeyType),
110-
DataTypeProtoConverter.toCatalystType(lit.getMap.getValueType)))
113+
DataTypeProtoConverter.toCatalystType(mapType.getKeyType),
114+
DataTypeProtoConverter.toCatalystType(mapType.getValueType)))
111115

112116
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
113117
val dataType = DataTypeProtoConverter.toCatalystType(lit.getStruct.getStructType)

0 commit comments

Comments
 (0)