Skip to content

[SPARK-52448][CONNECT] Add simplified Struct Expression.Literal #51561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 60 additions & 56 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

32 changes: 28 additions & 4 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -554,27 +554,51 @@ class Expression(google.protobuf.message.Message):

STRUCT_TYPE_FIELD_NUMBER: builtins.int
ELEMENTS_FIELD_NUMBER: builtins.int
DATA_TYPE_STRUCT_FIELD_NUMBER: builtins.int
@property
def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
def struct_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
"""(Deprecated) The type of the struct.

This field is deprecated since Spark 4.1+ because using DataType as the type of a struct
is ambiguous. This field should only be set if the data_type_struct field is not set.
Use data_type_struct field instead.
"""
@property
def elements(
self,
) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
global___Expression.Literal
]: ...
]:
"""(Required) The literal values that make up the struct elements."""
@property
def data_type_struct(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Struct:
"""The type of the struct.

Whether data_type_struct.fields.data_type should be set depends on
whether each field's type can be inferred from the elements field.
"""
def __init__(
self,
*,
struct_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
data_type_struct: pyspark.sql.connect.proto.types_pb2.DataType.Struct | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["struct_type", b"struct_type"]
self,
field_name: typing_extensions.Literal[
"data_type_struct", b"data_type_struct", "struct_type", b"struct_type"
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"elements", b"elements", "struct_type", b"struct_type"
"data_type_struct",
b"data_type_struct",
"elements",
b"elements",
"struct_type",
b"struct_type",
],
) -> None: ...

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,24 @@ class ColumnNodeToProtoConverterSuite extends ConnectFunSuite {
Literal((12.0, "north", 60.0, "west"), Option(dataType)),
expr { b =>
val builder = b.getLiteralBuilder.getStructBuilder
builder.getStructTypeBuilder.getStructBuilder
.addFields(structField("_1", ProtoDataTypes.DoubleType))
.addFields(structField("_2", stringTypeWithCollation))
.addFields(structField("_3", ProtoDataTypes.DoubleType))
.addFields(structField("_4", stringTypeWithCollation))
builder.addElements(proto.Expression.Literal.newBuilder().setDouble(12.0))
builder.addElements(proto.Expression.Literal.newBuilder().setString("north"))
builder.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0))
builder.addElements(proto.Expression.Literal.newBuilder().setString("west"))
builder
.addElements(proto.Expression.Literal.newBuilder().setDouble(12.0).build())
builder
.addElements(proto.Expression.Literal.newBuilder().setString("north").build())
builder
.addElements(proto.Expression.Literal.newBuilder().setDouble(60.0).build())
builder
.addElements(proto.Expression.Literal.newBuilder().setString("west").build())
builder.setDataTypeStruct(
proto.DataType.Struct
.newBuilder()
.addFields(
proto.DataType.StructField.newBuilder().setName("_1").setNullable(true).build())
.addFields(structField("_2", stringTypeWithCollation))
.addFields(
proto.DataType.StructField.newBuilder().setName("_3").setNullable(true).build())
.addFields(structField("_4", stringTypeWithCollation))
.build())
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,21 @@ message Expression {
}

message Struct {
// (Deprecated) The type of the struct.
//
// This field is deprecated since Spark 4.1+ because using DataType as the type of a struct
// is ambiguous. This field should only be set if the data_type_struct field is not set.
// Use data_type_struct field instead.
DataType struct_type = 1 [deprecated = true];

// (Required) The literal values that make up the struct elements.
// NOTE(review): presumably elements must align 1:1 (same count, same order) with
// data_type_struct.fields (or struct_type's fields in the legacy form) — confirm
// against the converter implementation.
repeated Literal elements = 2;

// The type of the struct.
//
// Whether data_type_struct.fields.data_type should be set depends on
// whether each field's type can be inferred from the elements field.
DataType.Struct data_type_struct = 3;
}

message SpecializedArray {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,36 @@ object LiteralValueProtoConverter {
}

def structBuilder(scalaValue: Any, structType: StructType) = {
val sb = builder.getStructBuilder.setStructType(toConnectProtoType(structType))
val dataTypes = structType.fields.map(_.dataType)
val sb = builder.getStructBuilder
val fields = structType.fields

scalaValue match {
case p: Product =>
val iter = p.productIterator
val dataTypeStruct = proto.DataType.Struct.newBuilder()
var idx = 0
while (idx < structType.size) {
sb.addElements(toLiteralProto(iter.next(), dataTypes(idx)))
val field = fields(idx)
val literalProto = toLiteralProto(iter.next(), field.dataType)
sb.addElements(literalProto)

val fieldBuilder = dataTypeStruct
.addFieldsBuilder()
.setName(field.name)
.setNullable(field.nullable)

if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) {
fieldBuilder.setDataType(toConnectProtoType(field.dataType))
}

// Set metadata if available
if (field.metadata != Metadata.empty) {
fieldBuilder.setMetadata(field.metadata.json)
}

idx += 1
}
sb.setDataTypeStruct(dataTypeStruct.build())
case other =>
throw new IllegalArgumentException(s"literal $other not supported (yet).")
}
Expand Down Expand Up @@ -300,54 +319,101 @@ object LiteralValueProtoConverter {
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
toCatalystArray(literal.getArray)

case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
toCatalystStruct(literal.getStruct)._1

case other =>
throw new UnsupportedOperationException(
s"Unsupported Literal Type: ${other.getNumber} (${other.name})")
}
}

private def getConverter(dataType: proto.DataType): proto.Expression.Literal => Any = {
if (dataType.hasShort) { v =>
v.getShort.toShort
} else if (dataType.hasInteger) { v =>
v.getInteger
} else if (dataType.hasLong) { v =>
v.getLong
} else if (dataType.hasDouble) { v =>
v.getDouble
} else if (dataType.hasByte) { v =>
v.getByte.toByte
} else if (dataType.hasFloat) { v =>
v.getFloat
} else if (dataType.hasBoolean) { v =>
v.getBoolean
} else if (dataType.hasString) { v =>
v.getString
} else if (dataType.hasBinary) { v =>
v.getBinary.toByteArray
} else if (dataType.hasDate) { v =>
v.getDate
} else if (dataType.hasTimestamp) { v =>
v.getTimestamp
} else if (dataType.hasTimestampNtz) { v =>
v.getTimestampNtz
} else if (dataType.hasDayTimeInterval) { v =>
v.getDayTimeInterval
} else if (dataType.hasYearMonthInterval) { v =>
v.getYearMonthInterval
} else if (dataType.hasDecimal) { v =>
Decimal(v.getDecimal.getValue)
} else if (dataType.hasCalendarInterval) { v =>
val interval = v.getCalendarInterval
new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds)
} else if (dataType.hasArray) { v =>
toCatalystArray(v.getArray)
} else if (dataType.hasMap) { v =>
toCatalystMap(v.getMap)
} else if (dataType.hasStruct) { v =>
toCatalystStruct(v.getStruct)
} else {
throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
/**
 * Builds a function that extracts the Catalyst value of the given proto data type
 * from an [[proto.Expression.Literal]].
 *
 * @param dataType the expected proto data type of the literal
 * @param inferDataType when true, struct literals are converted with type inference and
 *                      the result is wrapped in [[LiteralValueWithDataType]] so the caller
 *                      also receives the inferred struct type
 * @throws InvalidPlanInput for data-type kinds with no conversion rule
 */
private def getConverter(
    dataType: proto.DataType,
    inferDataType: Boolean = false): proto.Expression.Literal => Any = {
  import proto.DataType.KindCase
  dataType.getKindCase match {
    case KindCase.SHORT => lit => lit.getShort.toShort
    case KindCase.INTEGER => lit => lit.getInteger
    case KindCase.LONG => lit => lit.getLong
    case KindCase.DOUBLE => lit => lit.getDouble
    case KindCase.BYTE => lit => lit.getByte.toByte
    case KindCase.FLOAT => lit => lit.getFloat
    case KindCase.BOOLEAN => lit => lit.getBoolean
    case KindCase.STRING => lit => lit.getString
    case KindCase.BINARY => lit => lit.getBinary.toByteArray
    case KindCase.DATE => lit => lit.getDate
    case KindCase.TIMESTAMP => lit => lit.getTimestamp
    case KindCase.TIMESTAMP_NTZ => lit => lit.getTimestampNtz
    case KindCase.DAY_TIME_INTERVAL => lit => lit.getDayTimeInterval
    case KindCase.YEAR_MONTH_INTERVAL => lit => lit.getYearMonthInterval
    case KindCase.DECIMAL => lit => Decimal(lit.getDecimal.getValue)
    case KindCase.CALENDAR_INTERVAL =>
      lit => {
        val ci = lit.getCalendarInterval
        new CalendarInterval(ci.getMonths, ci.getDays, ci.getMicroseconds)
      }
    case KindCase.ARRAY => lit => toCatalystArray(lit.getArray)
    case KindCase.MAP => lit => toCatalystMap(lit.getMap)
    case KindCase.STRUCT if inferDataType =>
      // Let toCatalystStruct infer the struct type from the literal and report it back.
      lit => {
        val (value, inferredStructType) = toCatalystStruct(lit.getStruct, None)
        LiteralValueWithDataType(
          value,
          proto.DataType.newBuilder.setStruct(inferredStructType).build())
      }
    case KindCase.STRUCT =>
      // The struct type is fully specified by `dataType`; only the value is needed.
      lit => toCatalystStruct(lit.getStruct, Some(dataType.getStruct))._1
    case _ =>
      throw InvalidPlanInput(s"Unsupported Literal Type: $dataType)")
  }
}

/**
 * Attempts to derive the proto data type of a literal from the literal alone.
 *
 * Null literals carry their type explicitly. Other kinds are inferable only when the
 * type takes no extra parameters; e.g. day-time intervals carry start_field/end_field
 * that a bare literal cannot convey, so they (and other unlisted kinds) yield None.
 *
 * @return the inferred proto data type, or None when inference is not supported
 */
private def getInferredDataType(literal: proto.Expression.Literal): Option[proto.DataType] = {
  if (literal.hasNull) {
    Some(literal.getNull)
  } else {
    import proto.Expression.Literal.LiteralTypeCase
    val builder = proto.DataType.newBuilder()
    val populated = literal.getLiteralTypeCase match {
      case LiteralTypeCase.BINARY =>
        Some(builder.setBinary(proto.DataType.Binary.newBuilder.build()))
      case LiteralTypeCase.BOOLEAN =>
        Some(builder.setBoolean(proto.DataType.Boolean.newBuilder.build()))
      case LiteralTypeCase.BYTE =>
        Some(builder.setByte(proto.DataType.Byte.newBuilder.build()))
      case LiteralTypeCase.SHORT =>
        Some(builder.setShort(proto.DataType.Short.newBuilder.build()))
      case LiteralTypeCase.INTEGER =>
        Some(builder.setInteger(proto.DataType.Integer.newBuilder.build()))
      case LiteralTypeCase.LONG =>
        Some(builder.setLong(proto.DataType.Long.newBuilder.build()))
      case LiteralTypeCase.FLOAT =>
        Some(builder.setFloat(proto.DataType.Float.newBuilder.build()))
      case LiteralTypeCase.DOUBLE =>
        Some(builder.setDouble(proto.DataType.Double.newBuilder.build()))
      case LiteralTypeCase.DATE =>
        Some(builder.setDate(proto.DataType.Date.newBuilder.build()))
      case LiteralTypeCase.TIMESTAMP =>
        Some(builder.setTimestamp(proto.DataType.Timestamp.newBuilder.build()))
      case LiteralTypeCase.TIMESTAMP_NTZ =>
        Some(builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder.build()))
      case LiteralTypeCase.CALENDAR_INTERVAL =>
        Some(builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build()))
      case LiteralTypeCase.STRUCT =>
        // The types of the fields are inferred from the field literals themselves.
        // NOTE(review): this reads the deprecated struct_type field of the struct
        // literal; for literals that only set data_type_struct this produces an
        // empty struct type — confirm this path is only hit for legacy literals.
        Some(builder.setStruct(literal.getStruct.getStructType.getStruct))
      case _ =>
        // Not all data types support inferring the data type from the literal at the
        // moment, e.g. DayTimeInterval carries start_field/end_field information that
        // cannot be recovered from the literal value alone.
        None
    }
    populated.map(_.build())
  }
}

/**
 * Same as [[getInferredDataType]], but treats a literal whose type cannot be inferred
 * as an invalid plan instead of returning None.
 *
 * @throws InvalidPlanInput when the literal's type cannot be inferred
 */
private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = {
  getInferredDataType(literal) match {
    case Some(dataType) => dataType
    case None =>
      throw InvalidPlanInput(
        s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}")
  }
}

Expand Down Expand Up @@ -386,7 +452,9 @@ object LiteralValueProtoConverter {
makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType))
}

def toCatalystStruct(struct: proto.Expression.Literal.Struct): Any = {
def toCatalystStruct(
struct: proto.Expression.Literal.Struct,
structTypeOpt: Option[proto.DataType.Struct] = None): (Any, proto.DataType.Struct) = {
def toTuple[A <: Object](data: Seq[A]): Product = {
try {
val tupleClass = SparkClassUtils.classForName(s"scala.Tuple${data.length}")
Expand All @@ -397,16 +465,78 @@ object LiteralValueProtoConverter {
}
}

val elements = struct.getElementsList.asScala
val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
val structData = elements
.zip(dataTypes)
.map { case (element, dataType) =>
getConverter(dataType)(element)
if (struct.hasDataTypeStruct) {
// The new way to define and convert structs.
val (structData, structType) = if (structTypeOpt.isDefined) {
val structFields = structTypeOpt.get.getFieldsList.asScala
val structData =
struct.getElementsList.asScala.zip(structFields).map { case (element, structField) =>
getConverter(structField.getDataType)(element)
}
(structData, structTypeOpt.get)
} else {
def protoStructField(
name: String,
dataType: proto.DataType,
nullable: Boolean,
metadata: Option[String]): proto.DataType.StructField = {
val builder = proto.DataType.StructField
.newBuilder()
.setName(name)
.setDataType(dataType)
.setNullable(nullable)
metadata.foreach(builder.setMetadata)
builder.build()
}

val dataTypeFields = struct.getDataTypeStruct.getFieldsList.asScala

val structDataAndFields = struct.getElementsList.asScala.zip(dataTypeFields).map {
case (element, dataTypeField) =>
if (dataTypeField.hasDataType) {
(getConverter(dataTypeField.getDataType)(element), dataTypeField)
} else {
val outerDataType = getInferredDataTypeOrThrow(element)
val (value, dataType) =
getConverter(outerDataType, inferDataType = true)(element) match {
case LiteralValueWithDataType(value, dataType) => (value, dataType)
case value => (value, outerDataType)
}
(
value,
protoStructField(
dataTypeField.getName,
dataType,
dataTypeField.getNullable,
if (dataTypeField.hasMetadata) Some(dataTypeField.getMetadata) else None))
}
}

val structType = proto.DataType.Struct
.newBuilder()
.addAllFields(structDataAndFields.map(_._2).asJava)
.build()

(structDataAndFields.map(_._1), structType)
}
.asInstanceOf[scala.collection.Seq[Object]]
.toSeq
(toTuple(structData.toSeq.asInstanceOf[Seq[Object]]), structType)
} else if (struct.hasStructType) {
// For backward compatibility, we still support the old way to define and convert structs.
val elements = struct.getElementsList.asScala
val dataTypes = struct.getStructType.getStruct.getFieldsList.asScala.map(_.getDataType)
val structData = elements
.zip(dataTypes)
.map { case (element, dataType) =>
getConverter(dataType)(element)
}
.asInstanceOf[scala.collection.Seq[Object]]
.toSeq

toTuple(structData)
(toTuple(structData), struct.getStructType.getStruct)
} else {
throw InvalidPlanInput("Data type information is missing in the struct literal.")
}
}

private case class LiteralValueWithDataType(value: Any, dataType: proto.DataType)
}
Loading