@@ -149,17 +149,36 @@ object LiteralValueProtoConverter {
149
149
}
150
150
151
151
def structBuilder (scalaValue : Any , structType : StructType ) = {
152
- val sb = builder.getStructBuilder.setStructType(toConnectProtoType(structType))
153
- val dataTypes = structType.fields.map(_.dataType)
152
+ val sb = builder.getStructBuilder
153
+ val fields = structType.fields
154
154
155
155
scalaValue match {
156
156
case p : Product =>
157
157
val iter = p.productIterator
158
+ val dataTypeStruct = proto.DataType .Struct .newBuilder()
158
159
var idx = 0
159
160
while (idx < structType.size) {
160
- sb.addElements(toLiteralProto(iter.next(), dataTypes(idx)))
161
+ val field = fields(idx)
162
+ val literalProto = toLiteralProto(iter.next(), field.dataType)
163
+ sb.addElements(literalProto)
164
+
165
+ val fieldBuilder = dataTypeStruct
166
+ .addFieldsBuilder()
167
+ .setName(field.name)
168
+ .setNullable(field.nullable)
169
+
170
+ if (LiteralValueProtoConverter .getInferredDataType(literalProto).isEmpty) {
171
+ fieldBuilder.setDataType(toConnectProtoType(field.dataType))
172
+ }
173
+
174
+ // Set metadata if available
175
+ if (field.metadata != Metadata .empty) {
176
+ fieldBuilder.setMetadata(field.metadata.json)
177
+ }
178
+
161
179
idx += 1
162
180
}
181
+ sb.setDataTypeStruct(dataTypeStruct.build())
163
182
case other =>
164
183
throw new IllegalArgumentException (s " literal $other not supported (yet). " )
165
184
}
@@ -300,54 +319,101 @@ object LiteralValueProtoConverter {
300
319
case proto.Expression .Literal .LiteralTypeCase .ARRAY =>
301
320
toCatalystArray(literal.getArray)
302
321
322
+ case proto.Expression .Literal .LiteralTypeCase .STRUCT =>
323
+ toCatalystStruct(literal.getStruct)._1
324
+
303
325
case other =>
304
326
throw new UnsupportedOperationException (
305
327
s " Unsupported Literal Type: ${other.getNumber} ( ${other.name}) " )
306
328
}
307
329
}
308
330
309
- private def getConverter (dataType : proto.DataType ): proto.Expression .Literal => Any = {
310
- if (dataType.hasShort) { v =>
311
- v.getShort.toShort
312
- } else if (dataType.hasInteger) { v =>
313
- v.getInteger
314
- } else if (dataType.hasLong) { v =>
315
- v.getLong
316
- } else if (dataType.hasDouble) { v =>
317
- v.getDouble
318
- } else if (dataType.hasByte) { v =>
319
- v.getByte.toByte
320
- } else if (dataType.hasFloat) { v =>
321
- v.getFloat
322
- } else if (dataType.hasBoolean) { v =>
323
- v.getBoolean
324
- } else if (dataType.hasString) { v =>
325
- v.getString
326
- } else if (dataType.hasBinary) { v =>
327
- v.getBinary.toByteArray
328
- } else if (dataType.hasDate) { v =>
329
- v.getDate
330
- } else if (dataType.hasTimestamp) { v =>
331
- v.getTimestamp
332
- } else if (dataType.hasTimestampNtz) { v =>
333
- v.getTimestampNtz
334
- } else if (dataType.hasDayTimeInterval) { v =>
335
- v.getDayTimeInterval
336
- } else if (dataType.hasYearMonthInterval) { v =>
337
- v.getYearMonthInterval
338
- } else if (dataType.hasDecimal) { v =>
339
- Decimal (v.getDecimal.getValue)
340
- } else if (dataType.hasCalendarInterval) { v =>
341
- val interval = v.getCalendarInterval
342
- new CalendarInterval (interval.getMonths, interval.getDays, interval.getMicroseconds)
343
- } else if (dataType.hasArray) { v =>
344
- toCatalystArray(v.getArray)
345
- } else if (dataType.hasMap) { v =>
346
- toCatalystMap(v.getMap)
347
- } else if (dataType.hasStruct) { v =>
348
- toCatalystStruct(v.getStruct)
349
- } else {
350
- throw InvalidPlanInput (s " Unsupported Literal Type: $dataType) " )
331
+ private def getConverter (
332
+ dataType : proto.DataType ,
333
+ inferDataType : Boolean = false ): proto.Expression .Literal => Any = {
334
+ dataType.getKindCase match {
335
+ case proto.DataType .KindCase .SHORT => v => v.getShort.toShort
336
+ case proto.DataType .KindCase .INTEGER => v => v.getInteger
337
+ case proto.DataType .KindCase .LONG => v => v.getLong
338
+ case proto.DataType .KindCase .DOUBLE => v => v.getDouble
339
+ case proto.DataType .KindCase .BYTE => v => v.getByte.toByte
340
+ case proto.DataType .KindCase .FLOAT => v => v.getFloat
341
+ case proto.DataType .KindCase .BOOLEAN => v => v.getBoolean
342
+ case proto.DataType .KindCase .STRING => v => v.getString
343
+ case proto.DataType .KindCase .BINARY => v => v.getBinary.toByteArray
344
+ case proto.DataType .KindCase .DATE => v => v.getDate
345
+ case proto.DataType .KindCase .TIMESTAMP => v => v.getTimestamp
346
+ case proto.DataType .KindCase .TIMESTAMP_NTZ => v => v.getTimestampNtz
347
+ case proto.DataType .KindCase .DAY_TIME_INTERVAL => v => v.getDayTimeInterval
348
+ case proto.DataType .KindCase .YEAR_MONTH_INTERVAL => v => v.getYearMonthInterval
349
+ case proto.DataType .KindCase .DECIMAL => v => Decimal (v.getDecimal.getValue)
350
+ case proto.DataType .KindCase .CALENDAR_INTERVAL =>
351
+ v =>
352
+ val interval = v.getCalendarInterval
353
+ new CalendarInterval (interval.getMonths, interval.getDays, interval.getMicroseconds)
354
+ case proto.DataType .KindCase .ARRAY => v => toCatalystArray(v.getArray)
355
+ case proto.DataType .KindCase .MAP => v => toCatalystMap(v.getMap)
356
+ case proto.DataType .KindCase .STRUCT =>
357
+ if (inferDataType) { v =>
358
+ val (struct, structType) = toCatalystStruct(v.getStruct, None )
359
+ LiteralValueWithDataType (
360
+ struct,
361
+ proto.DataType .newBuilder.setStruct(structType).build())
362
+ } else { v =>
363
+ toCatalystStruct(v.getStruct, Some (dataType.getStruct))._1
364
+ }
365
+ case _ =>
366
+ throw InvalidPlanInput (s " Unsupported Literal Type: $dataType) " )
367
+ }
368
+ }
369
+
370
+ private def getInferredDataType (literal : proto.Expression .Literal ): Option [proto.DataType ] = {
371
+ if (literal.hasNull) {
372
+ return Some (literal.getNull)
373
+ }
374
+
375
+ val builder = proto.DataType .newBuilder()
376
+ literal.getLiteralTypeCase match {
377
+ case proto.Expression .Literal .LiteralTypeCase .BINARY =>
378
+ builder.setBinary(proto.DataType .Binary .newBuilder.build())
379
+ case proto.Expression .Literal .LiteralTypeCase .BOOLEAN =>
380
+ builder.setBoolean(proto.DataType .Boolean .newBuilder.build())
381
+ case proto.Expression .Literal .LiteralTypeCase .BYTE =>
382
+ builder.setByte(proto.DataType .Byte .newBuilder.build())
383
+ case proto.Expression .Literal .LiteralTypeCase .SHORT =>
384
+ builder.setShort(proto.DataType .Short .newBuilder.build())
385
+ case proto.Expression .Literal .LiteralTypeCase .INTEGER =>
386
+ builder.setInteger(proto.DataType .Integer .newBuilder.build())
387
+ case proto.Expression .Literal .LiteralTypeCase .LONG =>
388
+ builder.setLong(proto.DataType .Long .newBuilder.build())
389
+ case proto.Expression .Literal .LiteralTypeCase .FLOAT =>
390
+ builder.setFloat(proto.DataType .Float .newBuilder.build())
391
+ case proto.Expression .Literal .LiteralTypeCase .DOUBLE =>
392
+ builder.setDouble(proto.DataType .Double .newBuilder.build())
393
+ case proto.Expression .Literal .LiteralTypeCase .DATE =>
394
+ builder.setDate(proto.DataType .Date .newBuilder.build())
395
+ case proto.Expression .Literal .LiteralTypeCase .TIMESTAMP =>
396
+ builder.setTimestamp(proto.DataType .Timestamp .newBuilder.build())
397
+ case proto.Expression .Literal .LiteralTypeCase .TIMESTAMP_NTZ =>
398
+ builder.setTimestampNtz(proto.DataType .TimestampNTZ .newBuilder.build())
399
+ case proto.Expression .Literal .LiteralTypeCase .CALENDAR_INTERVAL =>
400
+ builder.setCalendarInterval(proto.DataType .CalendarInterval .newBuilder.build())
401
+ case proto.Expression .Literal .LiteralTypeCase .STRUCT =>
402
+ // The type of the fields will be inferred from the literals of the fields in the struct.
403
+ builder.setStruct(literal.getStruct.getStructType.getStruct)
404
+ case _ =>
405
+ // Not all data types support inferring the data type from the literal at the moment.
406
+ // e.g. the type of DayTimeInterval contains extra information like start_field and
407
+ // end_field and cannot be inferred from the literal.
408
+ return None
409
+ }
410
+ Some (builder.build())
411
+ }
412
+
413
+ private def getInferredDataTypeOrThrow (literal : proto.Expression .Literal ): proto.DataType = {
414
+ getInferredDataType(literal).getOrElse {
415
+ throw InvalidPlanInput (
416
+ s " Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}" )
351
417
}
352
418
}
353
419
@@ -386,7 +452,9 @@ object LiteralValueProtoConverter {
386
452
makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
387
453
}
388
454
389
- def toCatalystStruct (struct : proto.Expression .Literal .Struct ): Any = {
455
+ def toCatalystStruct (
456
+ struct : proto.Expression .Literal .Struct ,
457
+ structTypeOpt : Option [proto.DataType .Struct ] = None ): (Any , proto.DataType .Struct ) = {
390
458
def toTuple [A <: Object ](data : Seq [A ]): Product = {
391
459
try {
392
460
val tupleClass = SparkClassUtils .classForName(s " scala.Tuple ${data.length}" )
@@ -397,16 +465,78 @@ object LiteralValueProtoConverter {
397
465
}
398
466
}
399
467
400
- val elements = struct.getElementsList.asScala
401
- val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
402
- val structData = elements
403
- .zip(dataTypes)
404
- .map { case (element, dataType) =>
405
- getConverter(dataType)(element)
468
+ if (struct.hasDataTypeStruct) {
469
+ // The new way to define and convert structs.
470
+ val (structData, structType) = if (structTypeOpt.isDefined) {
471
+ val structFields = structTypeOpt.get.getFieldsList.asScala
472
+ val structData =
473
+ struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
474
+ getConverter(structField.getDataType)(element)
475
+ }
476
+ (structData, structTypeOpt.get)
477
+ } else {
478
+ def protoStructField (
479
+ name : String ,
480
+ dataType : proto.DataType ,
481
+ nullable : Boolean ,
482
+ metadata : Option [String ]): proto.DataType .StructField = {
483
+ val builder = proto.DataType .StructField
484
+ .newBuilder()
485
+ .setName(name)
486
+ .setDataType(dataType)
487
+ .setNullable(nullable)
488
+ metadata.foreach(builder.setMetadata)
489
+ builder.build()
490
+ }
491
+
492
+ val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala
493
+
494
+ val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
495
+ case (element, dataTypeField) =>
496
+ if (dataTypeField.hasDataType) {
497
+ (getConverter(dataTypeField.getDataType)(element), dataTypeField)
498
+ } else {
499
+ val outerDataType = getInferredDataTypeOrThrow(element)
500
+ val (value, dataType) =
501
+ getConverter(outerDataType, inferDataType = true )(element) match {
502
+ case LiteralValueWithDataType (value, dataType) => (value, dataType)
503
+ case value => (value, outerDataType)
504
+ }
505
+ (
506
+ value,
507
+ protoStructField(
508
+ dataTypeField.getName,
509
+ dataType,
510
+ dataTypeField.getNullable,
511
+ if (dataTypeField.hasMetadata) Some (dataTypeField.getMetadata) else None ))
512
+ }
513
+ }
514
+
515
+ val structType = proto.DataType .Struct
516
+ .newBuilder()
517
+ .addAllFields(structDataAndFields.map(_._2).asJava)
518
+ .build()
519
+
520
+ (structDataAndFields.map(_._1), structType)
406
521
}
407
- .asInstanceOf [scala.collection.Seq [Object ]]
408
- .toSeq
522
+ (toTuple(structData.toSeq.asInstanceOf [Seq [Object ]]), structType)
523
+ } else if (struct.hasStructType) {
524
+ // For backward compatibility, we still support the old way to define and convert structs.
525
+ val elements = struct.getElementsList.asScala
526
+ val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
527
+ val structData = elements
528
+ .zip(dataTypes)
529
+ .map { case (element, dataType) =>
530
+ getConverter(dataType)(element)
531
+ }
532
+ .asInstanceOf [scala.collection.Seq [Object ]]
533
+ .toSeq
409
534
410
- toTuple(structData)
535
+ (toTuple(structData), struct.getStructType.getStruct)
536
+ } else {
537
+ throw InvalidPlanInput (" Data type information is missing in the struct literal." )
538
+ }
411
539
}
540
+
541
+ private case class LiteralValueWithDataType (value : Any , dataType : proto.DataType )
412
542
}
0 commit comments