Skip to content

Commit 2d17a12

Browse files
committed
[SPARK-52448][CONNECT] Add simplified Struct Expression.Literal
1 parent 169b47f commit 2d17a12

File tree

9 files changed

+520
-250
lines changed

9 files changed

+520
-250
lines changed

python/pyspark/sql/connect/proto/expressions_pb2.py

Lines changed: 60 additions & 56 deletions
Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/expressions_pb2.pyi

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,27 +554,51 @@ class Expression(google.protobuf.message.Message):
554554

555555
STRUCT_TYPE_FIELD_NUMBER: builtins.int
556556
ELEMENTS_FIELD_NUMBER: builtins.int
557+
DATA_TYPE_STRUCT_FIELD_NUMBER: builtins.int
557558
@property
558-
def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
559+
def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
560+
"""(Deprecated) The type of the struct.
561+
562+
This field is deprecated since Spark 4.1+ because using DataType as the type of a struct
563+
is ambiguous. This field should only be set if the data_type_struct field is not set.
564+
Use the data_type_struct field instead.
565+
"""
559566
@property
560567
def elements(
561568
self,
562569
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
563570
global___Expression.Literal
564-
]: ...
571+
]:
572+
"""(Required) The literal values that make up the struct elements."""
573+
@property
574+
def data_type_struct(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Struct:
575+
"""The type of the struct.
576+
577+
Whether data_type_struct.fields.data_type should be set depends on
578+
whether each field's type can be inferred from the elements field.
579+
"""
565580
def __init__(
566581
self,
567582
*,
568583
struct_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
569584
elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
585+
data_type_struct: pyspark.sql.connect.proto.types_pb2.DataType.Struct | None = ...,
570586
) -> None: ...
571587
def HasField(
572-
self, field_name: typing_extensions.Literal["struct_type", b"struct_type"]
588+
self,
589+
field_name: typing_extensions.Literal[
590+
"data_type_struct", b"data_type_struct", "struct_type", b"struct_type"
591+
],
573592
) -> builtins.bool: ...
574593
def ClearField(
575594
self,
576595
field_name: typing_extensions.Literal[
577-
"elements", b"elements", "struct_type", b"struct_type"
596+
"data_type_struct",
597+
b"data_type_struct",
598+
"elements",
599+
b"elements",
600+
"struct_type",
601+
b"struct_type",
578602
],
579603
) -> None: ...
580604

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,24 @@ class ColumnNodeToProtoConverterSuite extends ConnectFunSuite {
7979
Literal((12.0, "north", 60.0, "west"), Option(dataType)),
8080
expr { b =>
8181
val builder = b.getLiteralBuilder.getStructBuilder
82-
builder.getStructTypeBuilder.getStructBuilder
83-
.addFields(structField("_1", ProtoDataTypes.DoubleType))
84-
.addFields(structField("_2", stringTypeWithCollation))
85-
.addFields(structField("_3", ProtoDataTypes.DoubleType))
86-
.addFields(structField("_4", stringTypeWithCollation))
87-
builder.addElements(proto.Expression.Literal.newBuilder().setDouble(12.0))
88-
builder.addElements(proto.Expression.Literal.newBuilder().setString("north"))
89-
builder.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0))
90-
builder.addElements(proto.Expression.Literal.newBuilder().setString("west"))
82+
builder
83+
.addElements(proto.Expression.Literal.newBuilder().setDouble(12.0).build())
84+
builder
85+
.addElements(proto.Expression.Literal.newBuilder().setString("north").build())
86+
builder
87+
.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0).build())
88+
builder
89+
.addElements(proto.Expression.Literal.newBuilder().setString("west").build())
90+
builder.setDataTypeStruct(
91+
proto.DataType.Struct
92+
.newBuilder()
93+
.addFields(
94+
proto.DataType.StructField.newBuilder().setName("_1").setNullable(true).build())
95+
.addFields(structField("_2", stringTypeWithCollation))
96+
.addFields(
97+
proto.DataType.StructField.newBuilder().setName("_3").setNullable(true).build())
98+
.addFields(structField("_4", stringTypeWithCollation))
99+
.build())
91100
})
92101
}
93102

sql/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,21 @@ message Expression {
227227
}
228228

229229
message Struct {
230-
DataType struct_type = 1;
230+
// (Deprecated) The type of the struct.
231+
//
232+
// This field is deprecated since Spark 4.1+ because using DataType as the type of a struct
233+
// is ambiguous. This field should only be set if the data_type_struct field is not set.
234+
// Use the data_type_struct field instead.
235+
DataType struct_type = 1 [deprecated = true];
236+
237+
// (Required) The literal values that make up the struct elements.
231238
repeated Literal elements = 2;
239+
240+
// The type of the struct.
241+
//
242+
// Whether data_type_struct.fields.data_type should be set depends on
243+
// whether each field's type can be inferred from the elements field.
244+
DataType.Struct data_type_struct = 3;
232245
}
233246

234247
message SpecializedArray {

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 185 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,36 @@ object LiteralValueProtoConverter {
149149
}
150150

151151
def structBuilder(scalaValue: Any, structType: StructType) = {
152-
val sb = builder.getStructBuilder.setStructType(toConnectProtoType(structType))
153-
val dataTypes = structType.fields.map(_.dataType)
152+
val sb = builder.getStructBuilder
153+
val fields = structType.fields
154154

155155
scalaValue match {
156156
case p: Product =>
157157
val iter = p.productIterator
158+
val dataTypeStruct = proto.DataType.Struct.newBuilder()
158159
var idx = 0
159160
while (idx < structType.size) {
160-
sb.addElements(toLiteralProto(iter.next(), dataTypes(idx)))
161+
val field = fields(idx)
162+
val literalProto = toLiteralProto(iter.next(), field.dataType)
163+
sb.addElements(literalProto)
164+
165+
val fieldBuilder = dataTypeStruct
166+
.addFieldsBuilder()
167+
.setName(field.name)
168+
.setNullable(field.nullable)
169+
170+
if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
171+
fieldBuilder.setDataType(toConnectProtoType(field.dataType))
172+
}
173+
174+
// Set metadata if available
175+
if (field.metadata != Metadata.empty) {
176+
fieldBuilder.setMetadata(field.metadata.json)
177+
}
178+
161179
idx += 1
162180
}
181+
sb.setDataTypeStruct(dataTypeStruct.build())
163182
case other =>
164183
throw new IllegalArgumentException(s"literal $other not supported (yet).")
165184
}
@@ -300,54 +319,101 @@ object LiteralValueProtoConverter {
300319
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
301320
toCatalystArray(literal.getArray)
302321

322+
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
323+
toCatalystStruct(literal.getStruct)._1
324+
303325
case other =>
304326
throw new UnsupportedOperationException(
305327
s"Unsupported Literal Type: ${other.getNumber} (${other.name})")
306328
}
307329
}
308330

309-
private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
310-
if (dataType.hasShort) { v =>
311-
v.getShort.toShort
312-
} else if (dataType.hasInteger) { v =>
313-
v.getInteger
314-
} else if (dataType.hasLong) { v =>
315-
v.getLong
316-
} else if (dataType.hasDouble) { v =>
317-
v.getDouble
318-
} else if (dataType.hasByte) { v =>
319-
v.getByte.toByte
320-
} else if (dataType.hasFloat) { v =>
321-
v.getFloat
322-
} else if (dataType.hasBoolean) { v =>
323-
v.getBoolean
324-
} else if (dataType.hasString) { v =>
325-
v.getString
326-
} else if (dataType.hasBinary) { v =>
327-
v.getBinary.toByteArray
328-
} else if (dataType.hasDate) { v =>
329-
v.getDate
330-
} else if (dataType.hasTimestamp) { v =>
331-
v.getTimestamp
332-
} else if (dataType.hasTimestampNtz) { v =>
333-
v.getTimestampNtz
334-
} else if (dataType.hasDayTimeInterval) { v =>
335-
v.getDayTimeInterval
336-
} else if (dataType.hasYearMonthInterval) { v =>
337-
v.getYearMonthInterval
338-
} else if (dataType.hasDecimal) { v =>
339-
Decimal(v.getDecimal.getValue)
340-
} else if (dataType.hasCalendarInterval) { v =>
341-
val interval = v.getCalendarInterval
342-
new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
343-
} else if (dataType.hasArray) { v =>
344-
toCatalystArray(v.getArray)
345-
} else if (dataType.hasMap) { v =>
346-
toCatalystMap(v.getMap)
347-
} else if (dataType.hasStruct) { v =>
348-
toCatalystStruct(v.getStruct)
349-
} else {
350-
throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
331+
private def getConverter(
332+
dataType: proto.DataType,
333+
inferDataType: Boolean = false): proto.Expression.Literal => Any = {
334+
dataType.getKindCase match {
335+
case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
336+
case proto.DataType.KindCase.INTEGER => v => v.getInteger
337+
case proto.DataType.KindCase.LONG => v => v.getLong
338+
case proto.DataType.KindCase.DOUBLE => v => v.getDouble
339+
case proto.DataType.KindCase.BYTE => v => v.getByte.toByte
340+
case proto.DataType.KindCase.FLOAT => v => v.getFloat
341+
case proto.DataType.KindCase.BOOLEAN => v => v.getBoolean
342+
case proto.DataType.KindCase.STRING => v => v.getString
343+
case proto.DataType.KindCase.BINARY => v => v.getBinary.toByteArray
344+
case proto.DataType.KindCase.DATE => v => v.getDate
345+
case proto.DataType.KindCase.TIMESTAMP => v => v.getTimestamp
346+
case proto.DataType.KindCase.TIMESTAMP_NTZ => v => v.getTimestampNtz
347+
case proto.DataType.KindCase.DAY_TIME_INTERVAL => v => v.getDayTimeInterval
348+
case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => v => v.getYearMonthInterval
349+
case proto.DataType.KindCase.DECIMAL => v => Decimal(v.getDecimal.getValue)
350+
case proto.DataType.KindCase.CALENDAR_INTERVAL =>
351+
v =>
352+
val interval = v.getCalendarInterval
353+
new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
354+
case proto.DataType.KindCase.ARRAY => v => toCatalystArray(v.getArray)
355+
case proto.DataType.KindCase.MAP => v => toCatalystMap(v.getMap)
356+
case proto.DataType.KindCase.STRUCT =>
357+
if (inferDataType) { v =>
358+
val (struct, structType) = toCatalystStruct(v.getStruct, None)
359+
LiteralValueWithDataType(
360+
struct,
361+
proto.DataType.newBuilder.setStruct(structType).build())
362+
} else { v =>
363+
toCatalystStruct(v.getStruct, Some(dataType.getStruct))._1
364+
}
365+
case _ =>
366+
throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
367+
}
368+
}
369+
370+
private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = {
371+
if (literal.hasNull) {
372+
return Some(literal.getNull)
373+
}
374+
375+
val builder = proto.DataType.newBuilder()
376+
literal.getLiteralTypeCase match {
377+
case proto.Expression.Literal.LiteralTypeCase.BINARY =>
378+
builder.setBinary(proto.DataType.Binary.newBuilder.build())
379+
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
380+
builder.setBoolean(proto.DataType.Boolean.newBuilder.build())
381+
case proto.Expression.Literal.LiteralTypeCase.BYTE =>
382+
builder.setByte(proto.DataType.Byte.newBuilder.build())
383+
case proto.Expression.Literal.LiteralTypeCase.SHORT =>
384+
builder.setShort(proto.DataType.Short.newBuilder.build())
385+
case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
386+
builder.setInteger(proto.DataType.Integer.newBuilder.build())
387+
case proto.Expression.Literal.LiteralTypeCase.LONG =>
388+
builder.setLong(proto.DataType.Long.newBuilder.build())
389+
case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
390+
builder.setFloat(proto.DataType.Float.newBuilder.build())
391+
case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
392+
builder.setDouble(proto.DataType.Double.newBuilder.build())
393+
case proto.Expression.Literal.LiteralTypeCase.DATE =>
394+
builder.setDate(proto.DataType.Date.newBuilder.build())
395+
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
396+
builder.setTimestamp(proto.DataType.Timestamp.newBuilder.build())
397+
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
398+
builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder.build())
399+
case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
400+
builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build())
401+
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
402+
// The type of the fields will be inferred from the literals of the fields in the struct.
403+
builder.setStruct(literal.getStruct.getStructType.getStruct)
404+
case _ =>
405+
// Not all data types support inferring the data type from the literal at the moment.
406+
// e.g. the type of DayTimeInterval contains extra information like start_field and
407+
// end_field and cannot be inferred from the literal.
408+
return None
409+
}
410+
Some(builder.build())
411+
}
412+
413+
private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = {
414+
getInferredDataType(literal).getOrElse {
415+
throw InvalidPlanInput(
416+
s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}")
351417
}
352418
}
353419

@@ -386,7 +452,9 @@ object LiteralValueProtoConverter {
386452
makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
387453
}
388454

389-
def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
455+
def toCatalystStruct(
456+
struct: proto.Expression.Literal.Struct,
457+
structTypeOpt: Option[proto.DataType.Struct] = None): (Any, proto.DataType.Struct) = {
390458
def toTuple[A <: Object](data: Seq[A]): Product = {
391459
try {
392460
val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
@@ -397,16 +465,78 @@ object LiteralValueProtoConverter {
397465
}
398466
}
399467

400-
val elements = struct.getElementsList.asScala
401-
val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
402-
val structData = elements
403-
.zip(dataTypes)
404-
.map { case (element, dataType) =>
405-
getConverter(dataType)(element)
468+
if (struct.hasDataTypeStruct) {
469+
// The new way to define and convert structs.
470+
val (structData, structType) = if (structTypeOpt.isDefined) {
471+
val structFields = structTypeOpt.get.getFieldsList.asScala
472+
val structData =
473+
struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
474+
getConverter(structField.getDataType)(element)
475+
}
476+
(structData, structTypeOpt.get)
477+
} else {
478+
def protoStructField(
479+
name: String,
480+
dataType: proto.DataType,
481+
nullable: Boolean,
482+
metadata: Option[String]): proto.DataType.StructField = {
483+
val builder = proto.DataType.StructField
484+
.newBuilder()
485+
.setName(name)
486+
.setDataType(dataType)
487+
.setNullable(nullable)
488+
metadata.foreach(builder.setMetadata)
489+
builder.build()
490+
}
491+
492+
val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala
493+
494+
val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
495+
case (element, dataTypeField) =>
496+
if (dataTypeField.hasDataType) {
497+
(getConverter(dataTypeField.getDataType)(element), dataTypeField)
498+
} else {
499+
val outerDataType = getInferredDataTypeOrThrow(element)
500+
val (value, dataType) =
501+
getConverter(outerDataType, inferDataType = true)(element) match {
502+
case LiteralValueWithDataType(value, dataType) => (value, dataType)
503+
case value => (value, outerDataType)
504+
}
505+
(
506+
value,
507+
protoStructField(
508+
dataTypeField.getName,
509+
dataType,
510+
dataTypeField.getNullable,
511+
if (dataTypeField.hasMetadata) Some(dataTypeField.getMetadata) else None))
512+
}
513+
}
514+
515+
val structType = proto.DataType.Struct
516+
.newBuilder()
517+
.addAllFields(structDataAndFields.map(_._2).asJava)
518+
.build()
519+
520+
(structDataAndFields.map(_._1), structType)
406521
}
407-
.asInstanceOf[scala.collection.Seq[Object]]
408-
.toSeq
522+
(toTuple(structData.toSeq.asInstanceOf[Seq[Object]]), structType)
523+
} else if (struct.hasStructType) {
524+
// For backward compatibility, we still support the old way to define and convert structs.
525+
val elements = struct.getElementsList.asScala
526+
val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
527+
val structData = elements
528+
.zip(dataTypes)
529+
.map { case (element, dataType) =>
530+
getConverter(dataType)(element)
531+
}
532+
.asInstanceOf[scala.collection.Seq[Object]]
533+
.toSeq
409534

410-
toTuple(structData)
535+
(toTuple(structData), struct.getStructType.getStruct)
536+
} else {
537+
throw InvalidPlanInput("Data type information is missing in the struct literal.")
538+
}
411539
}
540+
541+
private case class LiteralValueWithDataType(value: Any, dataType: proto.DataType)
412542
}

0 commit comments

Comments
 (0)