Skip to content

Commit f25c0b8

Browse files
heyihong and haoyangeng-db
authored and committed
[SPARK-52448][CONNECT] Add simplified Struct Expression.Literal
### What changes were proposed in this pull request?

This PR adds a new `data_type_struct` field to the protobuf definition for struct literals in Spark Connect, addressing the ambiguity issues with the existing `struct_type` field. The changes include:

1. **Protobuf Schema Update**: Added a new `data_type_struct` field of type `DataType.Struct` to the `Literal.Struct` message in `expressions.proto`, while marking the existing `struct_type` field as deprecated.
2. **Enhanced Struct Conversion Logic**: Updated `LiteralValueProtoConverter.scala` to:
   - Use the new `data_type_struct` field when available for more precise struct type definition
   - Maintain backward compatibility by still supporting the deprecated `struct_type` field
   - Add proper field metadata handling in struct conversions
   - Improve type inference for struct fields when data types can be inferred from literal values

### Why are the changes needed?

The current Expression.Struct literal is somewhat overcomplicated since it duplicates most of the information its fields already have. This is bulky to send over the wire, and it can be ambiguous.

### Does this PR introduce _any_ user-facing change?

No. This PR maintains backward compatibility with existing struct literal implementations. Existing code using the deprecated `struct_type` field will continue to work without modification.

### How was this patch tested?

`build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Cursor 1.2.4

Closes apache#51561 from heyihong/SPARK-52448.

Authored-by: Yihong He <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 555437a commit f25c0b8

File tree

9 files changed

+520
-250
lines changed

9 files changed

+520
-250
lines changed

python/pyspark/sql/connect/proto/expressions_pb2.py

Lines changed: 60 additions & 56 deletions
Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/expressions_pb2.pyi

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -554,27 +554,51 @@ class Expression(google.protobuf.message.Message):
554554

555555
STRUCT_TYPE_FIELD_NUMBER: builtins.int
556556
ELEMENTS_FIELD_NUMBER: builtins.int
557+
DATA_TYPE_STRUCT_FIELD_NUMBER: builtins.int
557558
@property
558-
def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
559+
def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
560+
"""(Deprecated) The type of the struct.
561+
562+
This field is deprecated since Spark 4.1+ because using DataType as the type of a struct
563+
is ambiguous. This field should only be set if the data_type_struct field is not set.
564+
Use data_type_struct field instead.
565+
"""
559566
@property
560567
def elements(
561568
self,
562569
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
563570
global___Expression.Literal
564-
]: ...
571+
]:
572+
"""(Required) The literal values that make up the struct elements."""
573+
@property
574+
def data_type_struct(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Struct:
575+
"""The type of the struct.
576+
577+
Whether data_type_struct.fields.data_type should be set depends on
578+
whether each field's type can be inferred from the elements field.
579+
"""
565580
def __init__(
566581
self,
567582
*,
568583
struct_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
569584
elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
585+
data_type_struct: pyspark.sql.connect.proto.types_pb2.DataType.Struct | None = ...,
570586
) -> None: ...
571587
def HasField(
572-
self, field_name: typing_extensions.Literal["struct_type", b"struct_type"]
588+
self,
589+
field_name: typing_extensions.Literal[
590+
"data_type_struct", b"data_type_struct", "struct_type", b"struct_type"
591+
],
573592
) -> builtins.bool: ...
574593
def ClearField(
575594
self,
576595
field_name: typing_extensions.Literal[
577-
"elements", b"elements", "struct_type", b"struct_type"
596+
"data_type_struct",
597+
b"data_type_struct",
598+
"elements",
599+
b"elements",
600+
"struct_type",
601+
b"struct_type",
578602
],
579603
) -> None: ...
580604

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ColumnNodeToProtoConverterSuite.scala

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,24 @@ class ColumnNodeToProtoConverterSuite extends ConnectFunSuite {
7979
Literal((12.0, "north", 60.0, "west"), Option(dataType)),
8080
expr { b =>
8181
val builder = b.getLiteralBuilder.getStructBuilder
82-
builder.getStructTypeBuilder.getStructBuilder
83-
.addFields(structField("_1", ProtoDataTypes.DoubleType))
84-
.addFields(structField("_2", stringTypeWithCollation))
85-
.addFields(structField("_3", ProtoDataTypes.DoubleType))
86-
.addFields(structField("_4", stringTypeWithCollation))
87-
builder.addElements(proto.Expression.Literal.newBuilder().setDouble(12.0))
88-
builder.addElements(proto.Expression.Literal.newBuilder().setString("north"))
89-
builder.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0))
90-
builder.addElements(proto.Expression.Literal.newBuilder().setString("west"))
82+
builder
83+
.addElements(proto.Expression.Literal.newBuilder().setDouble(12.0).build())
84+
builder
85+
.addElements(proto.Expression.Literal.newBuilder().setString("north").build())
86+
builder
87+
.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0).build())
88+
builder
89+
.addElements(proto.Expression.Literal.newBuilder().setString("west").build())
90+
builder.setDataTypeStruct(
91+
proto.DataType.Struct
92+
.newBuilder()
93+
.addFields(
94+
proto.DataType.StructField.newBuilder().setName("_1").setNullable(true).build())
95+
.addFields(structField("_2", stringTypeWithCollation))
96+
.addFields(
97+
proto.DataType.StructField.newBuilder().setName("_3").setNullable(true).build())
98+
.addFields(structField("_4", stringTypeWithCollation))
99+
.build())
91100
})
92101
}
93102

sql/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,21 @@ message Expression {
227227
}
228228

229229
message Struct {
230-
DataType struct_type = 1;
230+
// (Deprecated) The type of the struct.
231+
//
232+
// This field is deprecated since Spark 4.1+ because using DataType as the type of a struct
233+
// is ambiguous. This field should only be set if the data_type_struct field is not set.
234+
// Use data_type_struct field instead.
235+
DataType struct_type = 1 [deprecated = true];
236+
237+
// (Required) The literal values that make up the struct elements.
231238
repeated Literal elements = 2;
239+
240+
// The type of the struct.
241+
//
242+
// Whether data_type_struct.fields.data_type should be set depends on
243+
// whether each field's type can be inferred from the elements field.
244+
DataType.Struct data_type_struct = 3;
232245
}
233246

234247
message SpecializedArray {

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala

Lines changed: 185 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -149,17 +149,36 @@ object LiteralValueProtoConverter {
149149
}
150150

151151
def structBuilder(scalaValue: Any, structType: StructType) = {
152-
val sb = builder.getStructBuilder.setStructType(toConnectProtoType(structType))
153-
val dataTypes = structType.fields.map(_.dataType)
152+
val sb = builder.getStructBuilder
153+
val fields = structType.fields
154154

155155
scalaValue match {
156156
case p: Product =>
157157
val iter = p.productIterator
158+
val dataTypeStruct = proto.DataType.Struct.newBuilder()
158159
var idx = 0
159160
while (idx < structType.size) {
160-
sb.addElements(toLiteralProto(iter.next(), dataTypes(idx)))
161+
val field = fields(idx)
162+
val literalProto = toLiteralProto(iter.next(), field.dataType)
163+
sb.addElements(literalProto)
164+
165+
val fieldBuilder = dataTypeStruct
166+
.addFieldsBuilder()
167+
.setName(field.name)
168+
.setNullable(field.nullable)
169+
170+
if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
171+
fieldBuilder.setDataType(toConnectProtoType(field.dataType))
172+
}
173+
174+
// Set metadata if available
175+
if (field.metadata != Metadata.empty) {
176+
fieldBuilder.setMetadata(field.metadata.json)
177+
}
178+
161179
idx += 1
162180
}
181+
sb.setDataTypeStruct(dataTypeStruct.build())
163182
case other =>
164183
throw new IllegalArgumentException(s"literal $other not supported (yet).")
165184
}
@@ -300,54 +319,101 @@ object LiteralValueProtoConverter {
300319
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
301320
toCatalystArray(literal.getArray)
302321

322+
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
323+
toCatalystStruct(literal.getStruct)._1
324+
303325
case other =>
304326
throw new UnsupportedOperationException(
305327
s"Unsupported Literal Type: ${other.getNumber} (${other.name})")
306328
}
307329
}
308330

309-
private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
310-
if (dataType.hasShort) { v =>
311-
v.getShort.toShort
312-
} else if (dataType.hasInteger) { v =>
313-
v.getInteger
314-
} else if (dataType.hasLong) { v =>
315-
v.getLong
316-
} else if (dataType.hasDouble) { v =>
317-
v.getDouble
318-
} else if (dataType.hasByte) { v =>
319-
v.getByte.toByte
320-
} else if (dataType.hasFloat) { v =>
321-
v.getFloat
322-
} else if (dataType.hasBoolean) { v =>
323-
v.getBoolean
324-
} else if (dataType.hasString) { v =>
325-
v.getString
326-
} else if (dataType.hasBinary) { v =>
327-
v.getBinary.toByteArray
328-
} else if (dataType.hasDate) { v =>
329-
v.getDate
330-
} else if (dataType.hasTimestamp) { v =>
331-
v.getTimestamp
332-
} else if (dataType.hasTimestampNtz) { v =>
333-
v.getTimestampNtz
334-
} else if (dataType.hasDayTimeInterval) { v =>
335-
v.getDayTimeInterval
336-
} else if (dataType.hasYearMonthInterval) { v =>
337-
v.getYearMonthInterval
338-
} else if (dataType.hasDecimal) { v =>
339-
Decimal(v.getDecimal.getValue)
340-
} else if (dataType.hasCalendarInterval) { v =>
341-
val interval = v.getCalendarInterval
342-
new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
343-
} else if (dataType.hasArray) { v =>
344-
toCatalystArray(v.getArray)
345-
} else if (dataType.hasMap) { v =>
346-
toCatalystMap(v.getMap)
347-
} else if (dataType.hasStruct) { v =>
348-
toCatalystStruct(v.getStruct)
349-
} else {
350-
throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
331+
private def getConverter(
332+
dataType: proto.DataType,
333+
inferDataType: Boolean = false): proto.Expression.Literal => Any = {
334+
dataType.getKindCase match {
335+
case proto.DataType.KindCase.SHORT => v => v.getShort.toShort
336+
case proto.DataType.KindCase.INTEGER => v => v.getInteger
337+
case proto.DataType.KindCase.LONG => v => v.getLong
338+
case proto.DataType.KindCase.DOUBLE => v => v.getDouble
339+
case proto.DataType.KindCase.BYTE => v => v.getByte.toByte
340+
case proto.DataType.KindCase.FLOAT => v => v.getFloat
341+
case proto.DataType.KindCase.BOOLEAN => v => v.getBoolean
342+
case proto.DataType.KindCase.STRING => v => v.getString
343+
case proto.DataType.KindCase.BINARY => v => v.getBinary.toByteArray
344+
case proto.DataType.KindCase.DATE => v => v.getDate
345+
case proto.DataType.KindCase.TIMESTAMP => v => v.getTimestamp
346+
case proto.DataType.KindCase.TIMESTAMP_NTZ => v => v.getTimestampNtz
347+
case proto.DataType.KindCase.DAY_TIME_INTERVAL => v => v.getDayTimeInterval
348+
case proto.DataType.KindCase.YEAR_MONTH_INTERVAL => v => v.getYearMonthInterval
349+
case proto.DataType.KindCase.DECIMAL => v => Decimal(v.getDecimal.getValue)
350+
case proto.DataType.KindCase.CALENDAR_INTERVAL =>
351+
v =>
352+
val interval = v.getCalendarInterval
353+
new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
354+
case proto.DataType.KindCase.ARRAY => v => toCatalystArray(v.getArray)
355+
case proto.DataType.KindCase.MAP => v => toCatalystMap(v.getMap)
356+
case proto.DataType.KindCase.STRUCT =>
357+
if (inferDataType) { v =>
358+
val (struct, structType) = toCatalystStruct(v.getStruct, None)
359+
LiteralValueWithDataType(
360+
struct,
361+
proto.DataType.newBuilder.setStruct(structType).build())
362+
} else { v =>
363+
toCatalystStruct(v.getStruct, Some(dataType.getStruct))._1
364+
}
365+
case _ =>
366+
throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
367+
}
368+
}
369+
370+
private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = {
371+
if (literal.hasNull) {
372+
return Some(literal.getNull)
373+
}
374+
375+
val builder = proto.DataType.newBuilder()
376+
literal.getLiteralTypeCase match {
377+
case proto.Expression.Literal.LiteralTypeCase.BINARY =>
378+
builder.setBinary(proto.DataType.Binary.newBuilder.build())
379+
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN =>
380+
builder.setBoolean(proto.DataType.Boolean.newBuilder.build())
381+
case proto.Expression.Literal.LiteralTypeCase.BYTE =>
382+
builder.setByte(proto.DataType.Byte.newBuilder.build())
383+
case proto.Expression.Literal.LiteralTypeCase.SHORT =>
384+
builder.setShort(proto.DataType.Short.newBuilder.build())
385+
case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
386+
builder.setInteger(proto.DataType.Integer.newBuilder.build())
387+
case proto.Expression.Literal.LiteralTypeCase.LONG =>
388+
builder.setLong(proto.DataType.Long.newBuilder.build())
389+
case proto.Expression.Literal.LiteralTypeCase.FLOAT =>
390+
builder.setFloat(proto.DataType.Float.newBuilder.build())
391+
case proto.Expression.Literal.LiteralTypeCase.DOUBLE =>
392+
builder.setDouble(proto.DataType.Double.newBuilder.build())
393+
case proto.Expression.Literal.LiteralTypeCase.DATE =>
394+
builder.setDate(proto.DataType.Date.newBuilder.build())
395+
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP =>
396+
builder.setTimestamp(proto.DataType.Timestamp.newBuilder.build())
397+
case proto.Expression.Literal.LiteralTypeCase.TIMESTAMP_NTZ =>
398+
builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder.build())
399+
case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL =>
400+
builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build())
401+
case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
402+
// The type of the fields will be inferred from the literals of the fields in the struct.
403+
builder.setStruct(literal.getStruct.getStructType.getStruct)
404+
case _ =>
405+
// Not all data types support inferring the data type from the literal at the moment.
406+
// e.g. the type of DayTimeInterval contains extra information like start_field and
407+
// end_field and cannot be inferred from the literal.
408+
return None
409+
}
410+
Some(builder.build())
411+
}
412+
413+
private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = {
414+
getInferredDataType(literal).getOrElse {
415+
throw InvalidPlanInput(
416+
s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}")
351417
}
352418
}
353419

@@ -386,7 +452,9 @@ object LiteralValueProtoConverter {
386452
makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
387453
}
388454

389-
def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
455+
def toCatalystStruct(
456+
struct: proto.Expression.Literal.Struct,
457+
structTypeOpt: Option[proto.DataType.Struct] = None): (Any, proto.DataType.Struct) = {
390458
def toTuple[A <: Object](data: Seq[A]): Product = {
391459
try {
392460
val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
@@ -397,16 +465,78 @@ object LiteralValueProtoConverter {
397465
}
398466
}
399467

400-
val elements = struct.getElementsList.asScala
401-
val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
402-
val structData = elements
403-
.zip(dataTypes)
404-
.map { case (element, dataType) =>
405-
getConverter(dataType)(element)
468+
if (struct.hasDataTypeStruct) {
469+
// The new way to define and convert structs.
470+
val (structData, structType) = if (structTypeOpt.isDefined) {
471+
val structFields = structTypeOpt.get.getFieldsList.asScala
472+
val structData =
473+
struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
474+
getConverter(structField.getDataType)(element)
475+
}
476+
(structData, structTypeOpt.get)
477+
} else {
478+
def protoStructField(
479+
name: String,
480+
dataType: proto.DataType,
481+
nullable: Boolean,
482+
metadata: Option[String]): proto.DataType.StructField = {
483+
val builder = proto.DataType.StructField
484+
.newBuilder()
485+
.setName(name)
486+
.setDataType(dataType)
487+
.setNullable(nullable)
488+
metadata.foreach(builder.setMetadata)
489+
builder.build()
490+
}
491+
492+
val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala
493+
494+
val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
495+
case (element, dataTypeField) =>
496+
if (dataTypeField.hasDataType) {
497+
(getConverter(dataTypeField.getDataType)(element), dataTypeField)
498+
} else {
499+
val outerDataType = getInferredDataTypeOrThrow(element)
500+
val (value, dataType) =
501+
getConverter(outerDataType, inferDataType = true)(element) match {
502+
case LiteralValueWithDataType(value, dataType) => (value, dataType)
503+
case value => (value, outerDataType)
504+
}
505+
(
506+
value,
507+
protoStructField(
508+
dataTypeField.getName,
509+
dataType,
510+
dataTypeField.getNullable,
511+
if (dataTypeField.hasMetadata) Some(dataTypeField.getMetadata) else None))
512+
}
513+
}
514+
515+
val structType = proto.DataType.Struct
516+
.newBuilder()
517+
.addAllFields(structDataAndFields.map(_._2).asJava)
518+
.build()
519+
520+
(structDataAndFields.map(_._1), structType)
406521
}
407-
.asInstanceOf[scala.collection.Seq[Object]]
408-
.toSeq
522+
(toTuple(structData.toSeq.asInstanceOf[Seq[Object]]), structType)
523+
} else if (struct.hasStructType) {
524+
// For backward compatibility, we still support the old way to define and convert structs.
525+
val elements = struct.getElementsList.asScala
526+
val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
527+
val structData = elements
528+
.zip(dataTypes)
529+
.map { case (element, dataType) =>
530+
getConverter(dataType)(element)
531+
}
532+
.asInstanceOf[scala.collection.Seq[Object]]
533+
.toSeq
409534

410-
toTuple(structData)
535+
(toTuple(structData), struct.getStructType.getStruct)
536+
} else {
537+
throw InvalidPlanInput("Data type information is missing in the struct literal.")
538+
}
411539
}
540+
541+
private case class LiteralValueWithDataType(value: Any, dataType: proto.DataType)
412542
}

0 commit comments

Comments (0)