@@ -24,7 +24,7 @@ import java.time._
24
24
25
25
import scala .collection .{immutable , mutable }
26
26
import scala .jdk .CollectionConverters ._
27
- import scala .reflect . ClassTag
27
+ import scala .language . existentials
28
28
import scala .reflect .runtime .universe .TypeTag
29
29
import scala .util .Try
30
30
@@ -288,92 +288,216 @@ object LiteralValueProtoConverter {
288
288
SparkIntervalUtils .microsToDuration(literal.getDayTimeInterval)
289
289
290
290
case proto.Expression .Literal .LiteralTypeCase .ARRAY =>
291
- toCatalystArray(literal.getArray)
291
+ toCatalystArray(literal.getArray)._1
292
292
293
293
case other =>
294
294
throw new UnsupportedOperationException (
295
295
s " Unsupported Literal Type: ${other.getNumber} ( ${other.name}) " )
296
296
}
297
297
}
298
298
299
- private def getConverter (dataType : proto.DataType ): proto.Expression .Literal => Any = {
300
- if (dataType.hasShort) { v =>
301
- v.getShort.toShort
302
- } else if (dataType.hasInteger) { v =>
303
- v.getInteger
304
- } else if (dataType.hasLong) { v =>
305
- v.getLong
306
- } else if (dataType.hasDouble) { v =>
307
- v.getDouble
308
- } else if (dataType.hasByte) { v =>
309
- v.getByte.toByte
310
- } else if (dataType.hasFloat) { v =>
311
- v.getFloat
312
- } else if (dataType.hasBoolean) { v =>
313
- v.getBoolean
314
- } else if (dataType.hasString) { v =>
315
- v.getString
316
- } else if (dataType.hasBinary) { v =>
317
- v.getBinary.toByteArray
318
- } else if (dataType.hasDate) { v =>
319
- v.getDate
320
- } else if (dataType.hasTimestamp) { v =>
321
- v.getTimestamp
322
- } else if (dataType.hasTimestampNtz) { v =>
323
- v.getTimestampNtz
324
- } else if (dataType.hasDayTimeInterval) { v =>
325
- v.getDayTimeInterval
326
- } else if (dataType.hasYearMonthInterval) { v =>
327
- v.getYearMonthInterval
328
- } else if (dataType.hasDecimal) { v =>
329
- Decimal (v.getDecimal.getValue)
330
- } else if (dataType.hasCalendarInterval) { v =>
331
- val interval = v.getCalendarInterval
332
- new CalendarInterval (interval.getMonths, interval.getDays, interval.getMicroseconds)
333
- } else if (dataType.hasArray) { v =>
334
- toCatalystArray(v.getArray)
335
- } else if (dataType.hasMap) { v =>
336
- toCatalystMap(v.getMap)
337
- } else if (dataType.hasStruct) { v =>
338
- toCatalystStruct(v.getStruct)
339
- } else {
340
- throw InvalidPlanInput (s " Unsupported Literal Type: $dataType) " )
299
+ private def getConverter (
300
+ dataType : proto.DataType ,
301
+ inferDataType : Boolean = false ): proto.Expression .Literal => Any = {
302
+ dataType.getKindCase match {
303
+ case proto.DataType .KindCase .SHORT => v =>
304
+ v.getShort.toShort
305
+ case proto.DataType .KindCase .INTEGER => v =>
306
+ v.getInteger
307
+ case proto.DataType .KindCase .LONG => v =>
308
+ v.getLong
309
+ case proto.DataType .KindCase .DOUBLE => v =>
310
+ v.getDouble
311
+ case proto.DataType .KindCase .BYTE => v =>
312
+ v.getByte.toByte
313
+ case proto.DataType .KindCase .FLOAT => v =>
314
+ v.getFloat
315
+ case proto.DataType .KindCase .BOOLEAN => v =>
316
+ v.getBoolean
317
+ case proto.DataType .KindCase .STRING => v =>
318
+ v.getString
319
+ case proto.DataType .KindCase .BINARY => v =>
320
+ v.getBinary.toByteArray
321
+ case proto.DataType .KindCase .DATE => v =>
322
+ v.getDate
323
+ case proto.DataType .KindCase .TIMESTAMP => v =>
324
+ v.getTimestamp
325
+ case proto.DataType .KindCase .TIMESTAMP_NTZ => v =>
326
+ v.getTimestampNtz
327
+ case proto.DataType .KindCase .DAY_TIME_INTERVAL => v =>
328
+ v.getDayTimeInterval
329
+ case proto.DataType .KindCase .YEAR_MONTH_INTERVAL => v =>
330
+ v.getYearMonthInterval
331
+ case proto.DataType .KindCase .DECIMAL => v =>
332
+ Decimal (v.getDecimal.getValue)
333
+ case proto.DataType .KindCase .CALENDAR_INTERVAL => v =>
334
+ val interval = v.getCalendarInterval
335
+ new CalendarInterval (interval.getMonths, interval.getDays, interval.getMicroseconds)
336
+ case proto.DataType .KindCase .ARRAY =>
337
+ if (inferDataType) {
338
+ v => {
339
+ val (array, arrayType) = toCatalystArray(v.getArray, None )
340
+ LiteralValueWithDataType (array, proto.DataType .newBuilder.setArray(arrayType).build())
341
+ }
342
+ } else {
343
+ v => toCatalystArray(v.getArray, Some (dataType.getArray))._1
344
+ }
345
+ case proto.DataType .KindCase .MAP =>
346
+ if (inferDataType) {
347
+ v => {
348
+ val (map, mapType) = toCatalystMap(v.getMap, None )
349
+ LiteralValueWithDataType (map, proto.DataType .newBuilder.setMap(mapType).build())
350
+ }
351
+ } else {
352
+ v => toCatalystMap(v.getMap, Some (dataType.getMap))._1
353
+ }
354
+ case proto.DataType .KindCase .STRUCT => v =>
355
+ toCatalystStruct(v.getStruct)
356
+ case _ =>
357
+ throw InvalidPlanInput (s " Unsupported Literal Type: $dataType) " )
341
358
}
342
359
}
343
360
344
- def toCatalystArray (array : proto.Expression .Literal .Array ): Array [_] = {
345
- def makeArrayData [T ](converter : proto.Expression .Literal => T )(implicit
346
- tag : ClassTag [T ]): Array [T ] = {
347
- val builder = mutable.ArrayBuilder .make[T ]
348
- val elementList = array.getElementsList
349
- builder.sizeHint(elementList.size())
350
- val iter = elementList.iterator()
351
- while (iter.hasNext) {
352
- builder += converter(iter.next())
361
+ private def getBasicType (literal : proto.Expression .Literal ): proto.DataType = {
362
+ val builder = proto.DataType .newBuilder()
363
+ literal.getLiteralTypeCase match {
364
+ case proto.Expression .Literal .LiteralTypeCase .BINARY =>
365
+ builder.setBinary(proto.DataType .Binary .newBuilder.build())
366
+ case proto.Expression .Literal .LiteralTypeCase .BOOLEAN =>
367
+ builder.setBoolean(proto.DataType .Boolean .newBuilder.build())
368
+ case proto.Expression .Literal .LiteralTypeCase .BYTE =>
369
+ builder.setByte(proto.DataType .Byte .newBuilder.build())
370
+ case proto.Expression .Literal .LiteralTypeCase .SHORT =>
371
+ builder.setShort(proto.DataType .Short .newBuilder.build())
372
+ case proto.Expression .Literal .LiteralTypeCase .INTEGER =>
373
+ builder.setInteger(proto.DataType .Integer .newBuilder.build())
374
+ case proto.Expression .Literal .LiteralTypeCase .LONG =>
375
+ builder.setLong(proto.DataType .Long .newBuilder.build())
376
+ case proto.Expression .Literal .LiteralTypeCase .FLOAT =>
377
+ builder.setFloat(proto.DataType .Float .newBuilder.build())
378
+ case proto.Expression .Literal .LiteralTypeCase .DOUBLE =>
379
+ builder.setDouble(proto.DataType .Double .newBuilder.build())
380
+ case proto.Expression .Literal .LiteralTypeCase .STRING =>
381
+ builder.setString(proto.DataType .String .newBuilder.build())
382
+ case proto.Expression .Literal .LiteralTypeCase .DATE =>
383
+ builder.setDate(proto.DataType .Date .newBuilder.build())
384
+ case proto.Expression .Literal .LiteralTypeCase .TIMESTAMP =>
385
+ builder.setTimestamp(proto.DataType .Timestamp .newBuilder.build())
386
+ case proto.Expression .Literal .LiteralTypeCase .TIMESTAMP_NTZ =>
387
+ builder.setTimestampNtz(proto.DataType .TimestampNTZ .newBuilder.build())
388
+ case proto.Expression .Literal .LiteralTypeCase .YEAR_MONTH_INTERVAL =>
389
+ builder.setYearMonthInterval(proto.DataType .YearMonthInterval .newBuilder.build())
390
+ case proto.Expression .Literal .LiteralTypeCase .DECIMAL =>
391
+ builder.setDecimal(proto.DataType .Decimal .newBuilder.build())
392
+ case proto.Expression .Literal .LiteralTypeCase .CALENDAR_INTERVAL =>
393
+ builder.setCalendarInterval(proto.DataType .CalendarInterval .newBuilder.build())
394
+ case proto.Expression .Literal .LiteralTypeCase .DAY_TIME_INTERVAL =>
395
+ builder.setDayTimeInterval(proto.DataType .DayTimeInterval .newBuilder.build())
396
+ case proto.Expression .Literal .LiteralTypeCase .ARRAY =>
397
+ // Element type will be inferred from the elements in the array.
398
+ builder.setArray(proto.DataType .Array .newBuilder.build())
399
+ case proto.Expression .Literal .LiteralTypeCase .MAP =>
400
+ // Key and value types will be inferred from the keys and values in the map.
401
+ builder.setMap(proto.DataType .Map .newBuilder.build())
402
+ case proto.Expression .Literal .LiteralTypeCase .STRUCT =>
403
+ builder.setStruct(literal.getStruct.getStructType.getStruct)
404
+ case _ => throw InvalidPlanInput (s " Unsupported Literal Type: ${literal.getLiteralTypeCase}" )
405
+ }
406
+ builder.build()
407
+ }
408
+
409
+ def toCatalystArray (
410
+ array : proto.Expression .Literal .Array ,
411
+ arrayTypeOpt : Option [proto.DataType .Array ] = None ): (Array [_], proto.DataType .Array ) = {
412
+ def protoArrayType (elementType : proto.DataType ): proto.DataType .Array = {
413
+ proto.DataType .Array .newBuilder().setElementType(elementType).build()
414
+ }
415
+
416
+ val builder = mutable.ArrayBuilder .make[Any ]
417
+ builder.sizeHint(array.getElementsList.size())
418
+
419
+ val iter = array.getElementsList.iterator()
420
+
421
+ def inferDataType (): proto.DataType .Array = {
422
+ if (arrayTypeOpt.isDefined) {
423
+ arrayTypeOpt.get
424
+ } else if (array.hasElementType) {
425
+ protoArrayType(array.getElementType)
426
+ } else if (iter.hasNext) {
427
+ val firstElement = iter.next()
428
+ val basicElementType = getBasicType(firstElement)
429
+ val (elem, inferredElementType) =
430
+ getConverter(basicElementType, inferDataType = true )(firstElement) match {
431
+ case LiteralValueWithDataType (elem, dataType) => (elem, dataType)
432
+ case elem => (elem, basicElementType)
433
+ }
434
+ builder += elem
435
+ protoArrayType(inferredElementType)
436
+ } else {
437
+ throw InvalidPlanInput (" Cannot infer element type for an empty array" )
353
438
}
354
- builder.result()
355
439
}
356
440
357
- makeArrayData(getConverter(array.getElementType))
441
+ val dataType = inferDataType()
442
+ val converter = getConverter(dataType.getElementType)
443
+
444
+ while (iter.hasNext) {
445
+ builder += converter(iter.next())
446
+ }
447
+ builder.result()
448
+
449
+ (builder.result(), dataType)
358
450
}
359
451
360
- def toCatalystMap (map : proto.Expression .Literal .Map ): mutable.Map [_, _] = {
361
- def makeMapData [K , V ](
362
- keyConverter : proto.Expression .Literal => K ,
363
- valueConverter : proto.Expression .Literal => V )(implicit
364
- tagK : ClassTag [K ],
365
- tagV : ClassTag [V ]): mutable.Map [K , V ] = {
366
- val builder = mutable.HashMap .empty[K , V ]
367
- val keys = map.getKeysList.asScala
368
- val values = map.getValuesList.asScala
369
- builder.sizeHint(keys.size)
370
- keys.zip(values).foreach { case (key, value) =>
371
- builder += ((keyConverter(key), valueConverter(value)))
452
+ def toCatalystMap (
453
+ map : proto.Expression .Literal .Map ,
454
+ mapTypeOpt : Option [proto.DataType .Map ] = None ): (mutable.Map [_, _], proto.DataType .Map ) = {
455
+ def protoMapType (keyType : proto.DataType , valueType : proto.DataType ): proto.DataType .Map = {
456
+ proto.DataType .Map .newBuilder().setKeyType(keyType).setValueType(valueType).build()
457
+ }
458
+ val builder = mutable.HashMap .newBuilder[Any , Any ]
459
+ val keyValuePairs = map.getKeysList.asScala.zip(map.getValuesList.asScala)
460
+ builder.sizeHint(keyValuePairs.size)
461
+
462
+ val iter = keyValuePairs.iterator
463
+
464
+ def inferDataType (): proto.DataType .Map = {
465
+ if (mapTypeOpt.isDefined) {
466
+ mapTypeOpt.get
467
+ } else if (map.hasKeyType && map.hasValueType) {
468
+ protoMapType(map.getKeyType, map.getValueType)
469
+ } else if (iter.hasNext) {
470
+ val (key, value) = iter.next()
471
+ val basicKeyType = getBasicType(key)
472
+ val (convertedKey, inferredKeyType) =
473
+ getConverter(basicKeyType, inferDataType = true )(key) match {
474
+ case LiteralValueWithDataType (convertedKey, dataType) => (convertedKey, dataType)
475
+ case convertedKey => (convertedKey, basicKeyType)
476
+ }
477
+ val basicValueType = getBasicType(value)
478
+ val (convertedValue, inferredValueType) =
479
+ getConverter(basicValueType, inferDataType = true )(value) match {
480
+ case LiteralValueWithDataType (convertedValue, dataType) => (convertedValue, dataType)
481
+ case convertedValue => (convertedValue, basicValueType)
482
+ }
483
+ builder += ((convertedKey, convertedValue))
484
+ protoMapType(inferredKeyType, inferredValueType)
485
+ } else {
486
+ throw InvalidPlanInput (" Cannot infer key and value type for an empty map" )
372
487
}
373
- builder
374
488
}
375
489
376
- makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
490
+ val dataType = inferDataType()
491
+ val keyConverter = getConverter(dataType.getKeyType)
492
+ val valueConverter = getConverter(dataType.getValueType)
493
+
494
+ while (iter.hasNext) {
495
+ val (key, value) = iter.next()
496
+ builder += ((keyConverter(key), valueConverter(value)))
497
+ }
498
+ builder.result()
499
+
500
+ (builder.result(), dataType)
377
501
}
378
502
379
503
def toCatalystStruct (struct : proto.Expression .Literal .Struct ): Any = {
@@ -399,4 +523,8 @@ object LiteralValueProtoConverter {
399
523
400
524
toTuple(structData)
401
525
}
526
+
527
+ private case class LiteralValueWithDataType (
528
+ value : Any ,
529
+ dataType : proto.DataType )
402
530
}
0 commit comments