From 81af0688ad7cdaf893ef9febd7374efc4625e299 Mon Sep 17 00:00:00 2001 From: Yihong He Date: Thu, 28 Aug 2025 13:47:18 +0200 Subject: [PATCH 1/2] [SPARK-52449] Make datatypes for Expression.Literal.Map/Expression.Literal.Array optional --- .../common/LiteralValueProtoConverter.scala | 93 ++++++++----- .../query-tests/queries/function_lit.json | 4 - .../queries/function_lit.proto.bin | Bin 5391 -> 5387 bytes .../queries/function_lit_array.json | 122 ----------------- .../queries/function_lit_array.proto.bin | Bin 5127 -> 4989 bytes .../queries/function_typedLit.json | 126 ------------------ .../queries/function_typedLit.proto.bin | Bin 9356 -> 9201 bytes .../apache/spark/sql/connect/ml/MLUtils.scala | 4 +- .../spark/sql/connect/ml/Serializer.scala | 48 ++----- ...LiteralExpressionProtoConverterSuite.scala | 91 ++++++++++++- 10 files changed, 158 insertions(+), 330 deletions(-) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 870a452e85ecd..05713a9d9fd18 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -39,6 +39,45 @@ import org.apache.spark.util.SparkClassUtils object LiteralValueProtoConverter { + private def setArrayTypeAfterAddingElements( + ab: proto.Expression.Literal.Array.Builder, + elementType: DataType, + containsNull: Boolean, + useDeprecatedDataTypeFields: Boolean): Unit = { + if (useDeprecatedDataTypeFields) { + ab.setElementType(toConnectProtoType(elementType)) + } else { + val dataTypeBuilder = proto.DataType.Array.newBuilder() + if (ab.getElementsCount == 0 || getInferredDataType(ab.getElements(0)).isEmpty) { + dataTypeBuilder.setElementType(toConnectProtoType(elementType)) + } + dataTypeBuilder.setContainsNull(containsNull) + ab.setDataType(dataTypeBuilder.build()) + } + } + + private def setMapTypeAfterAddingKeysAndValues( + mb: proto.Expression.Literal.Map.Builder, + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean, + useDeprecatedDataTypeFields: Boolean): Unit = { + if (useDeprecatedDataTypeFields) { + mb.setKeyType(toConnectProtoType(keyType)) + mb.setValueType(toConnectProtoType(valueType)) + } else { + val dataTypeBuilder = proto.DataType.Map.newBuilder() + if (mb.getKeysCount == 0 || getInferredDataType(mb.getKeys(0)).isEmpty) { + dataTypeBuilder.setKeyType(toConnectProtoType(keyType)) + } + if (mb.getValuesCount == 0 || getInferredDataType(mb.getValues(0)).isEmpty) { + dataTypeBuilder.setValueType(toConnectProtoType(valueType)) + } + dataTypeBuilder.setValueContainsNull(valueContainsNull) + mb.setDataType(dataTypeBuilder.build()) + } + } + @scala.annotation.tailrec private def toLiteralProtoBuilderInternal( literal: Any, @@ -58,17 +97,12 @@ object LiteralValueProtoConverter { def arrayBuilder(array: Array[_]) = { val ab = builder.getArrayBuilder - if (options.useDeprecatedDataTypeFields) { - ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) - } else { - ab.setDataType( - proto.DataType.Array - .newBuilder() - .setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) - .setContainsNull(true) - .build()) - } array.foreach(x => ab.addElements(toLiteralProtoWithOptions(x, None, options))) + setArrayTypeAfterAddingElements( + ab, + toDataType(array.getClass.getComponentType), + containsNull = true, + options.useDeprecatedDataTypeFields) ab } @@ -122,16 +156,6 @@ object LiteralValueProtoConverter { def arrayBuilder(scalaValue: Any, elementType: DataType, containsNull: Boolean) = { val ab = builder.getArrayBuilder - if (options.useDeprecatedDataTypeFields) { - ab.setElementType(toConnectProtoType(elementType)) - } else { - ab.setDataType( - proto.DataType.Array - .newBuilder() - .setElementType(toConnectProtoType(elementType)) - .setContainsNull(containsNull) - .build()) - } scalaValue match { case a: Array[_] => a.foreach(item => @@ -142,7 +166,11 @@ object LiteralValueProtoConverter { case other => throw new IllegalArgumentException(s"literal $other not supported (yet).") } - + setArrayTypeAfterAddingElements( + ab, + elementType, + containsNull, + options.useDeprecatedDataTypeFields) ab } @@ -152,19 +180,6 @@ object LiteralValueProtoConverter { valueType: DataType, valueContainsNull: Boolean) = { val mb = builder.getMapBuilder - if (options.useDeprecatedDataTypeFields) { - mb.setKeyType(toConnectProtoType(keyType)) - mb.setValueType(toConnectProtoType(valueType)) - } else { - mb.setDataType( - proto.DataType.Map - .newBuilder() - .setKeyType(toConnectProtoType(keyType)) - .setValueType(toConnectProtoType(valueType)) - .setValueContainsNull(valueContainsNull) - .build()) - } - scalaValue match { case map: scala.collection.Map[_, _] => map.foreach { case (k, v) => @@ -174,7 +189,12 @@ object LiteralValueProtoConverter { case other => throw new IllegalArgumentException(s"literal $other not supported (yet).") } - + setMapTypeAfterAddingKeysAndValues( + mb, + keyType, + valueType, + valueContainsNull, + options.useDeprecatedDataTypeFields) mb } @@ -414,6 +434,9 @@ object LiteralValueProtoConverter { case proto.Expression.Literal.LiteralTypeCase.ARRAY => toCatalystArray(literal.getArray) + case proto.Expression.Literal.LiteralTypeCase.MAP => + toCatalystMap(literal.getMap) + case proto.Expression.Literal.LiteralTypeCase.STRUCT => toCatalystStruct(literal.getStruct) diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json index cedf7572a1fd3..a899c9f410aad 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json @@ -364,10 +364,6 @@ "integer": 6 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin index 5d30f4fca159b1d3cd36b30a5403cd434c0a0700..26c7b3a7dc02326aa7f6882bc0453a13c15d1f4d 100644 GIT binary patch delta 44 xcmeCz>egc8ViI7KYT3vp!6ST#k&Ay5BaaZ10S6GXNdaXhD{^ky{G3OF0|33)2*CgV delta 42 vcmeCy>epi9ViI7K>e$F8!6SN%kxO_Jqks^T0S6GXNde`I7&hzk+~EKKspJQd diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json index 53b1a7b3947f9..dc73d05d6c53b 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json @@ -49,10 +49,6 @@ "integer": 1 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } @@ -62,10 +58,6 @@ "integer": 2 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } @@ -75,24 +67,11 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "containsNull": true } } @@ -125,24 +104,11 @@ "integer": 1 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "containsNull": true } } @@ -154,24 +120,11 @@ "integer": 2 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "containsNull": true } } @@ -183,43 +136,16 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "containsNull": true } } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, - "containsNull": true - } - }, "containsNull": true } } @@ -250,10 +176,6 @@ "boolean": false }], "dataType": { - "elementType": { - "boolean": { - } - }, "containsNull": true } } @@ -307,10 +229,6 @@ "short": 9874 }], "dataType": { - "elementType": { - "short": { - } - }, "containsNull": true } } @@ -343,10 +261,6 @@ "integer": -8726533 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } @@ -379,10 +293,6 @@ "long": "7834609328726533" }], "dataType": { - "elementType": { - "long": { - } - }, "containsNull": true } } @@ -415,10 +325,6 @@ "double": 2.0 }], "dataType": { - "elementType": { - "double": { - } - }, "containsNull": true } } @@ -451,10 +357,6 @@ "float": -0.9 }], "dataType": { - "elementType": { - "float": { - } - }, "containsNull": true } } @@ -664,10 +566,6 @@ "date": 18546 }], "dataType": { - "elementType": { - "date": { - } - }, "containsNull": true } } @@ -698,10 +596,6 @@ "timestamp": "1677155519809000" }], "dataType": { - "elementType": { - "timestamp": { - } - }, "containsNull": true } } @@ -732,10 +626,6 @@ "timestamp": "23456000" }], "dataType": { - "elementType": { - "timestamp": { - } - }, "containsNull": true } } @@ -766,10 +656,6 @@ "timestampNtz": "1677188160000000" }], "dataType": { - "elementType": { - "timestampNtz": { - } - }, "containsNull": true } } @@ -800,10 +686,6 @@ "date": 19417 }], "dataType": { - "elementType": { - "date": { - } - }, "containsNull": true } } @@ -914,10 +796,6 @@ } }], "dataType": { - "elementType": { - "calendarInterval": { - } - }, "containsNull": true } } diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin index 8cb965dd25a0b5d3474c6b4c7d241616226ea3dc..66877f54fef2f3ad42b159c3c545acc5fe92437e 100644 GIT binary patch delta 464 zcmZWkyGjE=6!lEp%z7i59rM_Y!50_;g60Qo1(E#=8}TR7X}XPD1X5Xvg^)07Cn8uy z5yaRiR$^-t3vEnrcV<_D7IPnG&bjB_gX*xtiDgmxtTg$1?G>~xu_iJSreHOc!mp%Z z<4r1RO3d#X=R4n^ZHXl!pGF)f&l@J(nJ|0{K2-+&4d zb0Sss+%PnwaO7$Ne*sk`R{e&t0=3vLvY-BoxI97bL1hWY-!KbW6^zWi$zl(bmk`3J zW8fi5A6|Sf?}HX4%Kqkzqb1HQs3?&W=|h}e4vZ+*`>l*#Y&qr&@=$;3bgs^ZOx4Gu ux4W>`tZ~Q3mU4V<#g0M4I>FKdXhC9LOtq0VSvF_y3TlIgW!~;-X7B?v33HJE delta 589 zcmeyX)~>j8Ns1H}Gw#-^;|+$+(HJl~G7}6Qcsm1h|nX8h|z+D}q~tPZ4&jfnE{e5?I8@ zi|if&Mv#j{k1=uyZ(Vv=A4Vg``Yr5HdmlV>yUmO0DFCAEoBT!>j?f-jJm2qGr= zf)r_i6|vuED1V#|`X#zOViol9guQ76ggNmzz zv7zPr^vkRNG6@0MFS;RYka1ODm9pm;xuiESN(ixaFs|79krhNYT$;{}P{DqkkxOM0 wqvGUjen~cVF6JgiutsDjT%Hbg!p-SWCu|N7h+<-gx?%E0AqkMaF0j7M0F(}aYXATM diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json index 66bf31d670f9f..b43528bcb7711 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json @@ -409,10 +409,6 @@ "integer": 6 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, @@ -710,10 +706,6 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, @@ -745,10 +737,6 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, @@ -787,10 +775,6 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } @@ -887,10 +871,6 @@ "integer": 1 }], "dataType": { - "elementType": { - "integer": { - } - }, "containsNull": true } } @@ -925,14 +905,6 @@ } }], "dataType": { - "keyType": { - "integer": { - } - }, - "valueType": { - "integer": { - } - }, "valueContainsNull": true } } @@ -967,14 +939,6 @@ } }], "dataType": { - "keyType": { - "integer": { - } - }, - "valueType": { - "integer": { - } - }, "valueContainsNull": true } } @@ -1009,14 +973,6 @@ } }], "dataType": { - "keyType": { - "integer": { - } - }, - "valueType": { - "integer": { - } - }, "valueContainsNull": true } } @@ -1051,10 +1007,6 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, { @@ -1067,10 +1019,6 @@ "integer": 6 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, { @@ -1083,22 +1031,10 @@ "integer": 9 }], "dataType": { - "elementType": { - "integer": { - } - } } } }], "dataType": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - } - } - }, "containsNull": true } } @@ -1140,10 +1076,6 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } @@ -1164,10 +1096,6 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } @@ -1188,28 +1116,11 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } }], "dataType": { - "elementType": { - "map": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueType": { - "integer": { - } - } - } - }, "containsNull": true } } @@ -1256,10 +1167,6 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } @@ -1280,32 +1187,11 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } }], "dataType": { - "keyType": { - "integer": { - } - }, - "valueType": { - "map": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueType": { - "integer": { - } - } - } - }, "valueContainsNull": true } } @@ -1340,10 +1226,6 @@ "integer": 3 }], "dataType": { - "elementType": { - "integer": { - } - } } } }, { @@ -1363,10 +1245,6 @@ "string": { "collation": "UTF8_BINARY" } - }, - "valueType": { - "integer": { - } } } } @@ -1387,10 +1265,6 @@ "string": "b" }], "dataType": { - "keyType": { - "integer": { - } - }, "valueType": { "string": { "collation": "UTF8_BINARY" diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin index b3ebe8a79e3ecb357df55f76ec791cf6ebbe0c65..e69435ae88faddfcfd1136de61e1a698f8e297f1 100644 GIT binary patch delta 636 zcmeD2{OHck#U#Ke^~!xC`#oO4ql{dE zG4cV$8G)Dyh?%7rCU0cij$4NP0V9{jE=JYKg$lw#f?WJjyh7ZeA#N7&PM&^_L6IQM zn?t2`F!CK{H`2-z5>WyLftK*(h04N{XDWXN E0Qppc7XSbN delta 841 zcmez9-s8#6#U#Ke)#|yC{T}b+Q{r4~970S6Y*LdQrT0v>lV+Qo$Fy^E4u2en$XP}% z(M^m(Kv_m0W&&bnuqvDq22U8dbapXnN-<|KCW440C7@m)X3`Sn;*a7L;tmaQvxs-{ z^m7b~6k;-Bn7om3`({_E9gIT98M*j3G4g=yl49cmN(oHfz_>%^A|sdNE=DmakTEPE zVJ$8Y3DPS8RA4pvu52Ck6h-thafNPT3=+cdHrQ^MFbmL?tU$~rg(A!j6aob$Cq!6^ zn~QxBBUlC4FZEBExYjUkVqC!}L{b=Gs%8d;C<~4d#Z=7-RLurdZHO%bq-3}xHZfw4 z9EhtAGjVk>?qY0X1jPXmGbxeedQ8>$T(6}7j$I@t@uMafP^9mu|H;I4fbkIH9!6M< sAkrr!Mo4m^Q1l_j2q93C71G (literal.getBoolean.asInstanceOf[Object], classOf[Boolean]) case proto.Expression.Literal.LiteralTypeCase.ARRAY => - val array = literal.getArray - array.getElementType.getKindCase match { + val catalystArray = LiteralValueProtoConverter.toCatalystArray(literal.getArray) + val arrayType = LiteralValueProtoConverter.getProtoArrayType(literal.getArray) + arrayType.getElementType.getKindCase match { case proto.DataType.KindCase.DOUBLE => - (parseDoubleArray(array), classOf[Array[Double]]) + (MLUtils.reconcileArray(classOf[Double], catalystArray), classOf[Array[Double]]) case proto.DataType.KindCase.STRING => - (parseStringArray(array), classOf[Array[String]]) + (MLUtils.reconcileArray(classOf[String], catalystArray), classOf[Array[String]]) case proto.DataType.KindCase.ARRAY => - array.getElementType.getArray.getElementType.getKindCase match { + arrayType.getElementType.getArray.getElementType.getKindCase match { case proto.DataType.KindCase.STRING => - (parseStringArrayArray(array), classOf[Array[Array[String]]]) + ( + MLUtils.reconcileArray(classOf[Array[String]], catalystArray), + classOf[Array[Array[String]]]) case _ => - throw MlUnsupportedException(s"Unsupported inner array $array") + throw MlUnsupportedException(s"Unsupported inner array ${literal.getArray}") } case _ => throw MlUnsupportedException(s"Unsupported array $literal") @@ -193,37 +196,6 @@ private[ml] object Serializer { } } - private def parseDoubleArray(array: proto.Expression.Literal.Array): Array[Double] = { - val values = new Array[Double](array.getElementsCount) - var i = 0 - while (i < array.getElementsCount) { - values(i) = array.getElements(i).getDouble - i += 1 - } - values - } - - private def parseStringArray(array: proto.Expression.Literal.Array): Array[String] = { - val values = new Array[String](array.getElementsCount) - var i = 0 - while (i < array.getElementsCount) { - values(i) = array.getElements(i).getString - i += 1 - } - values - } - - private def parseStringArrayArray( - array: proto.Expression.Literal.Array): Array[Array[String]] = { - val values = new Array[Array[String]](array.getElementsCount) - var i = 0 - while (i < array.getElementsCount) { - values(i) = parseStringArray(array.getElements(i).getArray) - i += 1 - } - values - } - /** * Serialize an instance of "Params" which could be estimator/model/evaluator ... * @param instance diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala index 0af181e4be1a7..808d2f3281f76 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.planner import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.ToLiteralProtoOptions import org.apache.spark.sql.connect.planner.LiteralExpressionProtoConverter @@ -73,9 +74,25 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i StructField("a", IntegerType), StructField( "b", - StructType( - Seq(StructField("c", IntegerType), StructField("d", IntegerType)))))))).zipWithIndex - .foreach { case ((v, t), idx) => + StructType(Seq(StructField("c", IntegerType), StructField("d", IntegerType))))))), + (Array(true, false, true), ArrayType(BooleanType)), + (Array(1.toByte, 2.toByte, 3.toByte), ArrayType(ByteType)), + (Array(1.toShort, 2.toShort, 3.toShort), ArrayType(ShortType)), + (Array(1, 2, 3), ArrayType(IntegerType)), + (Array(1L, 2L, 3L), ArrayType(LongType)), + (Array(1.1d, 2.1d, 3.1d), ArrayType(DoubleType)), + (Array(1.1f, 2.1f, 3.1f), ArrayType(FloatType)), + (Array(Array[Int](), Array(1, 2, 3), Array(4, 5, 6)), ArrayType(ArrayType(IntegerType))), + (Array(Array(1, 2, 3), Array(4, 5, 6), Array[Int]()), ArrayType(ArrayType(IntegerType))), + ( + Array(Array(Array(Array(Array(Array(1, 2, 3)))))), + ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType))))))), + (Array(Map(1 -> 2)), ArrayType(MapType(IntegerType, IntegerType))), + (Map[String, String]("1" -> "2", "3" -> "4"), MapType(StringType, StringType)), + (Map[String, Boolean]("1" -> true, "2" -> false), MapType(StringType, BooleanType)), + (Map[Int, Int](), MapType(IntegerType, IntegerType)), + (Map(1 -> 2, 3 -> 4, 5 -> 6), MapType(IntegerType, IntegerType))).zipWithIndex.foreach { + case ((v, t), idx) => test(s"complex proto value and catalyst value conversion #$idx") { assertResult(v)( LiteralValueProtoConverter.toCatalystValue( @@ -93,7 +110,7 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i Some(t), ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)))) } - } + } test("backward compatibility for array literal proto") { // Test the old way of defining arrays with elementType field and elements @@ -259,4 +276,70 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i assert(!structTypeProto.getFieldsList.get(1).getNullable) assert(!structTypeProto.getFieldsList.get(1).hasMetadata) } + + test("element type of array literal is set for an empty array") { + val literalProto = + toLiteralProto(Array[Int](), ArrayType(IntegerType)) + assert(literalProto.getArray.getDataType.hasElementType) + } + + test("element type of array literal is set for a non-empty array with non-inferable type") { + val literalProto = toLiteralProto(Array[String]("1", "2", "3"), ArrayType(StringType)) + assert(literalProto.getArray.getDataType.hasElementType) + } + + test("element type of array literal is not set for a non-empty array with inferable type") { + val literalProto = + toLiteralProto(Array(1, 2, 3), ArrayType(IntegerType)) + assert(!literalProto.getArray.getDataType.hasElementType) + } + + test("key and value type of map literal are set for an empty map") { + val literalProto = toLiteralProto(Map[Int, Int](), MapType(IntegerType, IntegerType)) + assert(literalProto.getMap.getDataType.hasKeyType) + assert(literalProto.getMap.getDataType.hasValueType) + } + + test("key type of map literal is set for a non-empty map with non-inferable key type") { + val literalProto = toLiteralProto( + Map[String, Int]("1" -> 1, "2" -> 2, "3" -> 3), + MapType(StringType, IntegerType)) + assert(literalProto.getMap.getDataType.hasKeyType) + assert(!literalProto.getMap.getDataType.hasValueType) + } + + test("value type of map literal is set for a non-empty map with non-inferable value type") { + val literalProto = toLiteralProto( + Map[Int, String](1 -> "1", 2 -> "2", 3 -> "3"), + MapType(IntegerType, StringType)) + assert(!literalProto.getMap.getDataType.hasKeyType) + assert(literalProto.getMap.getDataType.hasValueType) + } + + test("key and value type of map literal are not set for a non-empty map with inferable types") { + val literalProto = + toLiteralProto(Map(1 -> 2, 3 -> 4, 5 -> 6), MapType(IntegerType, IntegerType)) + assert(!literalProto.getMap.getDataType.hasKeyType) + assert(!literalProto.getMap.getDataType.hasValueType) + } + + test("an invalid array literal") { + val literalProto = proto.Expression.Literal + .newBuilder() + .setArray(proto.Expression.Literal.Array.newBuilder()) + .build() + intercept[InvalidPlanInput] { + LiteralValueProtoConverter.toCatalystValue(literalProto) + } + } + + test("an invalid map literal") { + val literalProto = proto.Expression.Literal + .newBuilder() + .setMap(proto.Expression.Literal.Map.newBuilder()) + .build() + intercept[InvalidPlanInput] { + LiteralValueProtoConverter.toCatalystValue(literalProto) + } + } } From c50229cff58499cdd261f33f6eb7ee228d3a6914 Mon Sep 17 00:00:00 2001 From: Yihong He Date: Thu, 28 Aug 2025 17:43:40 +0200 Subject: [PATCH 2/2] Optimize further --- .../sql/connect/proto/expressions_pb2.pyi | 25 ++- .../spark/sql/PlanGenerationTestSuite.scala | 6 + .../protobuf/spark/connect/expressions.proto | 25 ++- .../common/LiteralValueProtoConverter.scala | 126 ++++++----- .../queries/function_lit_array.json | 30 +-- .../queries/function_lit_array.proto.bin | Bin 4989 -> 4965 bytes .../queries/function_typedLit.json | 201 ++++++++++++++++-- .../queries/function_typedLit.proto.bin | Bin 9201 -> 9795 bytes ...LiteralExpressionProtoConverterSuite.scala | 98 +++------ 9 files changed, 333 insertions(+), 178 deletions(-) diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi index 508a11a01c85e..8dfac59f46544 100644 --- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi +++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi @@ -479,8 +479,7 @@ class Expression(google.protobuf.message.Message): def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: """(Deprecated) The element type of the array. - This field is deprecated since Spark 4.1+ and should only be set - if the data_type field is not set. Use data_type field instead. + This field is deprecated since Spark 4.1+. Use data_type field instead. """ @property def elements( @@ -491,11 +490,14 @@ class Expression(google.protobuf.message.Message): """The literal values that make up the array elements.""" @property def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array: - """The type of the array. + """The type of the array. You don't need to set this field if the type information is not needed. If the element type can be inferred from the first element of the elements field, - then you don't need to set data_type.element_type to save space. On the other hand, - redundant type information is also acceptable. + then you don't need to set data_type.element_type to save space. + For inferring the data_type.element_type, only the first element needs to + contain the type information. + + On the other hand, redundant type information is also acceptable. """ def __init__( self, @@ -534,8 +536,7 @@ class Expression(google.protobuf.message.Message): def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: """(Deprecated) The key type of the map. - This field is deprecated since Spark 4.1+ and should only be set - if the data_type field is not set. Use data_type field instead. + This field is deprecated since Spark 4.1+. Use data_type field instead. """ @property def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: @@ -560,10 +561,13 @@ class Expression(google.protobuf.message.Message): """The literal values that make up the map.""" @property def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map: - """The type of the map. + """The type of the map. You don't need to set this field if the type information is not needed. If the key/value types can be inferred from the first element of the keys/values fields, then you don't need to set data_type.key_type/data_type.value_type to save space. + For inferring the data_type.key_type/data_type.value_type, only the first element needs to + contain the type information. + On the other hand, redundant type information is also acceptable. """ def __init__( @@ -608,8 +612,7 @@ class Expression(google.protobuf.message.Message): """(Deprecated) The type of the struct. This field is deprecated since Spark 4.1+ because using DataType as the type of a struct - is ambiguous. This field should only be set if the data_type_struct field is not set. - Use data_type_struct field instead. + is ambiguous. Use data_type_struct field instead. """ @property def elements( @@ -620,7 +623,7 @@ class Expression(google.protobuf.message.Message): """(Required) The literal values that make up the struct elements.""" @property def data_type_struct(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Struct: - """The type of the struct. + """The type of the struct. You don't need to set this field if the type information is not needed. Whether data_type_struct.fields.data_type should be set depends on whether each field's type can be inferred from the elements field. diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index b760828a1e99c..ec79ad601b9d7 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -3414,11 +3414,17 @@ class PlanGenerationTestSuite fn.typedlit[collection.immutable.Map[Int, Option[Int]]]( collection.immutable.Map(1 -> None)), fn.typedLit(Seq(Seq(1, 2, 3), Seq(4, 5, 6), Seq(7, 8, 9))), + fn.typedLit(Seq((1, "2", Seq("3", "4")), (5, "6", Seq.empty[String]))), fn.typedLit( Seq( mutable.LinkedHashMap("a" -> 1, "b" -> 2), mutable.LinkedHashMap("a" -> 3, "b" -> 4), mutable.LinkedHashMap("a" -> 5, "b" -> 6))), + fn.typedLit( + Seq( + mutable.LinkedHashMap("a" -> Seq("1", "2"), "b" -> Seq("3", "4")), + mutable.LinkedHashMap("a" -> Seq("5", "6"), "b" -> Seq("7", "8")), + mutable.LinkedHashMap("a" -> Seq.empty[String], "b" -> Seq.empty[String]))), fn.typedLit( mutable.LinkedHashMap( 1 -> mutable.LinkedHashMap("a" -> 1, "b" -> 2), diff --git a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto index 913622b91a284..162e5a323a6fb 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto @@ -217,26 +217,27 @@ message Expression { message Array { // (Deprecated) The element type of the array. // - // This field is deprecated since Spark 4.1+ and should only be set - // if the data_type field is not set. Use data_type field instead. + // This field is deprecated since Spark 4.1+. Use data_type field instead. DataType element_type = 1 [deprecated = true]; // The literal values that make up the array elements. repeated Literal elements = 2; - // The type of the array. + // The type of the array. You don't need to set this field if the type information is not needed. // // If the element type can be inferred from the first element of the elements field, - // then you don't need to set data_type.element_type to save space. On the other hand, - // redundant type information is also acceptable. + // then you don't need to set data_type.element_type to save space. + // For inferring the data_type.element_type, only the first element needs to + // contain the type information. + // + // On the other hand, redundant type information is also acceptable. DataType.Array data_type = 3; } message Map { // (Deprecated) The key type of the map. // - // This field is deprecated since Spark 4.1+ and should only be set - // if the data_type field is not set. Use data_type field instead. + // This field is deprecated since Spark 4.1+. Use data_type field instead. DataType key_type = 1 [deprecated = true]; // (Deprecated) The value type of the map. @@ -251,10 +252,13 @@ message Expression { // The literal values that make up the map. repeated Literal values = 4; - // The type of the map. + // The type of the map. You don't need to set this field if the type information is not needed. // // If the key/value types can be inferred from the first element of the keys/values fields, // then you don't need to set data_type.key_type/data_type.value_type to save space. + // For inferring the data_type.key_type/data_type.value_type, only the first element needs to + // contain the type information. + // // On the other hand, redundant type information is also acceptable. DataType.Map data_type = 5; } @@ -263,14 +267,13 @@ message Expression { // (Deprecated) The type of the struct. // // This field is deprecated since Spark 4.1+ because using DataType as the type of a struct - // is ambiguous. This field should only be set if the data_type_struct field is not set. - // Use data_type_struct field instead. + // is ambiguous. Use data_type_struct field instead. DataType struct_type = 1 [deprecated = true]; // (Required) The literal values that make up the struct elements. repeated Literal elements = 2; - // The type of the struct. + // The type of the struct. You don't need to set this field if the type information is not needed. // // Whether data_type_struct.fields.data_type should be set depends on // whether each field's type can be inferred from the elements field. diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 05713a9d9fd18..7ca1290ea4c2d 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -43,10 +43,11 @@ object LiteralValueProtoConverter { ab: proto.Expression.Literal.Array.Builder, elementType: DataType, containsNull: Boolean, - useDeprecatedDataTypeFields: Boolean): Unit = { + useDeprecatedDataTypeFields: Boolean, + needElementType: Boolean): Unit = { if (useDeprecatedDataTypeFields) { ab.setElementType(toConnectProtoType(elementType)) - } else { + } else if (needElementType) { val dataTypeBuilder = proto.DataType.Array.newBuilder() if (ab.getElementsCount == 0 || getInferredDataType(ab.getElements(0)).isEmpty) { dataTypeBuilder.setElementType(toConnectProtoType(elementType)) @@ -61,11 +62,12 @@ object LiteralValueProtoConverter { keyType: DataType, valueType: DataType, valueContainsNull: Boolean, - useDeprecatedDataTypeFields: Boolean): Unit = { + useDeprecatedDataTypeFields: Boolean, + needKeyAndValueType: Boolean): Unit = { if (useDeprecatedDataTypeFields) { mb.setKeyType(toConnectProtoType(keyType)) mb.setValueType(toConnectProtoType(valueType)) - } else { + } else if (needKeyAndValueType) { val dataTypeBuilder = proto.DataType.Map.newBuilder() if (mb.getKeysCount == 0 || getInferredDataType(mb.getKeys(0)).isEmpty) { dataTypeBuilder.setKeyType(toConnectProtoType(keyType)) @@ -78,10 +80,10 @@ object LiteralValueProtoConverter { } } - @scala.annotation.tailrec private def toLiteralProtoBuilderInternal( literal: Any, - options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { + options: ToLiteralProtoOptions, + needDataType: Boolean): proto.Expression.Literal.Builder = { val builder = proto.Expression.Literal.newBuilder() def decimalBuilder(precision: Int, scale: Int, value: String) = { @@ -97,12 +99,17 @@ object LiteralValueProtoConverter { def arrayBuilder(array: Array[_]) = { val ab = builder.getArrayBuilder - array.foreach(x => ab.addElements(toLiteralProtoWithOptions(x, None, options))) + var needElementType = needDataType + array.foreach { x => + ab.addElements(toLiteralProtoBuilderInternal(x, options, needElementType).build()) + needElementType = false + } setArrayTypeAfterAddingElements( ab, toDataType(array.getClass.getComponentType), containsNull = true, - options.useDeprecatedDataTypeFields) + options.useDeprecatedDataTypeFields, + needDataType) ab } @@ -122,8 +129,9 @@ object LiteralValueProtoConverter { case v: Char => builder.setString(v.toString) case v: Array[Char] => builder.setString(String.valueOf(v)) case v: Array[Byte] => builder.setBinary(ByteString.copyFrom(v)) - case v: mutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.array, options) - case v: immutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.unsafeArray, options) + case v: mutable.ArraySeq[_] => toLiteralProtoBuilderInternal(v.array, options, needDataType) + case v: immutable.ArraySeq[_] => + toLiteralProtoBuilderInternal(v.unsafeArray, options, needDataType) case v: LocalDate => builder.setDate(v.toEpochDay.toInt) case v: Decimal => builder.setDecimal(decimalBuilder(Math.max(v.precision, v.scale), v.scale, v.toString)) @@ -147,22 +155,29 @@ object LiteralValueProtoConverter { } } - @scala.annotation.tailrec private def toLiteralProtoBuilderInternal( literal: Any, dataType: DataType, - options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { + options: ToLiteralProtoOptions, + needDataType: Boolean): proto.Expression.Literal.Builder = { val builder = proto.Expression.Literal.newBuilder() def arrayBuilder(scalaValue: Any, elementType: DataType, containsNull: Boolean) = { val ab = builder.getArrayBuilder + var needElementType = needDataType scalaValue match { case a: Array[_] => - a.foreach(item => - ab.addElements(toLiteralProtoWithOptions(item, Some(elementType), options))) + a.foreach { item => + ab.addElements( + toLiteralProtoBuilderInternal(item, elementType, options, needElementType).build()) + needElementType = false + } case s: scala.collection.Seq[_] => - s.foreach(item => - ab.addElements(toLiteralProtoWithOptions(item, Some(elementType), options))) + s.foreach { item => + ab.addElements( + toLiteralProtoBuilderInternal(item, elementType, options, needElementType).build()) + needElementType = false + } case other => throw new IllegalArgumentException(s"literal $other not supported (yet).") } @@ -170,7 +185,8 @@ object LiteralValueProtoConverter { ab, elementType, containsNull, - options.useDeprecatedDataTypeFields) + options.useDeprecatedDataTypeFields, + needDataType) ab } @@ -180,11 +196,15 @@ object LiteralValueProtoConverter { valueType: DataType, valueContainsNull: Boolean) = { val mb = builder.getMapBuilder + var needKeyAndValueType = needDataType scalaValue match { case map: scala.collection.Map[_, _] => map.foreach { case (k, v) => - mb.addKeys(toLiteralProtoWithOptions(k, Some(keyType), options)) - mb.addValues(toLiteralProtoWithOptions(v, Some(valueType), options)) + mb.addKeys( + toLiteralProtoBuilderInternal(k, keyType, options, needKeyAndValueType).build()) + mb.addValues( + toLiteralProtoBuilderInternal(v, valueType, options, needKeyAndValueType).build()) + needKeyAndValueType = false } case other => throw new IllegalArgumentException(s"literal $other not supported (yet).") @@ -194,7 +214,8 @@ object LiteralValueProtoConverter { keyType, valueType, valueContainsNull, - options.useDeprecatedDataTypeFields) + options.useDeprecatedDataTypeFields, + needDataType) mb } @@ -209,37 +230,42 @@ object LiteralValueProtoConverter { if (options.useDeprecatedDataTypeFields) { while (idx < structType.size) { val field = fields(idx) - val literalProto = - toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options) + // For backward compatibility, we need the data type for each field. + val literalProto = toLiteralProtoBuilderInternal( + iter.next(), + field.dataType, + options, + needDataType = true).build() sb.addElements(literalProto) idx += 1 } sb.setStructType(toConnectProtoType(structType)) } else { - val dataTypeStruct = proto.DataType.Struct.newBuilder() while (idx < structType.size) { val field = fields(idx) val literalProto = - toLiteralProtoWithOptions(iter.next(), Some(field.dataType), options) + toLiteralProtoBuilderInternal(iter.next(), field.dataType, options, needDataType) + .build() sb.addElements(literalProto) - val fieldBuilder = dataTypeStruct - .addFieldsBuilder() - .setName(field.name) - .setNullable(field.nullable) + if (needDataType) { + val fieldBuilder = sb.getDataTypeStructBuilder + .addFieldsBuilder() + .setName(field.name) + .setNullable(field.nullable) - if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) { - fieldBuilder.setDataType(toConnectProtoType(field.dataType)) - } + if (LiteralValueProtoConverter.getInferredDataType(literalProto).isEmpty) { + fieldBuilder.setDataType(toConnectProtoType(field.dataType)) + } - // Set metadata if available - if (field.metadata != Metadata.empty) { - fieldBuilder.setMetadata(field.metadata.json) + // Set metadata if available + if (field.metadata != Metadata.empty) { + fieldBuilder.setMetadata(field.metadata.json) + } } idx += 1 } - sb.setDataTypeStruct(dataTypeStruct.build()) } case other => throw new IllegalArgumentException(s"literal $other not supported (yet).") @@ -250,11 +276,11 @@ object LiteralValueProtoConverter { (literal, dataType) match { case (v: mutable.ArraySeq[_], ArrayType(_, _)) => - toLiteralProtoBuilderInternal(v.array, dataType, options) + toLiteralProtoBuilderInternal(v.array, dataType, options, needDataType) case (v: immutable.ArraySeq[_], ArrayType(_, _)) => - toLiteralProtoBuilderInternal(v.unsafeArray, dataType, options) + toLiteralProtoBuilderInternal(v.unsafeArray, dataType, options, needDataType) case (v: Array[Byte], ArrayType(_, _)) => - toLiteralProtoBuilderInternal(v, options) + toLiteralProtoBuilderInternal(v, options, needDataType) case (v, ArrayType(elementType, containsNull)) => builder.setArray(arrayBuilder(v, elementType, containsNull)) case (v, MapType(keyType, valueType, valueContainsNull)) => @@ -263,7 +289,7 @@ object LiteralValueProtoConverter { builder.setStruct(structBuilder(v, structType)) case (v: Option[_], _: DataType) => if (v.isDefined) { - toLiteralProtoBuilderInternal(v.get, options) + toLiteralProtoBuilderInternal(v.get, options, needDataType) } else { builder.setNull(toConnectProtoType(dataType)) } @@ -272,7 +298,7 @@ object LiteralValueProtoConverter { builder.getTimeBuilder .setNano(SparkDateTimeUtils.localTimeToNanos(v)) .setPrecision(timeType.precision)) - case _ => toLiteralProtoBuilderInternal(literal, options) + case _ => toLiteralProtoBuilderInternal(literal, options, needDataType) } } @@ -286,7 +312,8 @@ object LiteralValueProtoConverter { def toLiteralProtoBuilder(literal: Any): proto.Expression.Literal.Builder = { toLiteralProtoBuilderInternal( literal, - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true), + needDataType = true) } def toLiteralProtoBuilder( @@ -295,7 +322,8 @@ object LiteralValueProtoConverter { toLiteralProtoBuilderInternal( literal, dataType, - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true), + needDataType = true) } def toLiteralProtoBuilderWithOptions( @@ -304,9 +332,9 @@ object LiteralValueProtoConverter { options: ToLiteralProtoOptions): proto.Expression.Literal.Builder = { dataTypeOpt match { case Some(dataType) => - toLiteralProtoBuilderInternal(literal, dataType, options) + toLiteralProtoBuilderInternal(literal, dataType, options, needDataType = true) case None => - toLiteralProtoBuilderInternal(literal, options) + toLiteralProtoBuilderInternal(literal, options, needDataType = true) } } @@ -328,13 +356,15 @@ object LiteralValueProtoConverter { def toLiteralProto(literal: Any): proto.Expression.Literal = toLiteralProtoBuilderInternal( literal, - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)).build() + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true), + needDataType = true).build() def toLiteralProto(literal: Any, dataType: DataType): proto.Expression.Literal = toLiteralProtoBuilderInternal( literal, dataType, - ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)).build() + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true), + needDataType = true).build() def toLiteralProtoWithOptions( literal: Any, @@ -342,9 +372,9 @@ object LiteralValueProtoConverter { options: ToLiteralProtoOptions): proto.Expression.Literal = { dataTypeOpt match { case Some(dataType) => - toLiteralProtoBuilderInternal(literal, dataType, options).build() + toLiteralProtoBuilderInternal(literal, dataType, options, needDataType = true).build() case None => - toLiteralProtoBuilderInternal(literal, options).build() + toLiteralProtoBuilderInternal(literal, options, needDataType = true).build() } } diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json index dc73d05d6c53b..153478ce75bb5 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json @@ -56,19 +56,13 @@ "array": { "elements": [{ "integer": 2 - }], - "dataType": { - "containsNull": true - } + }] } }, { "array": { "elements": [{ "integer": 3 - }], - "dataType": { - "containsNull": true - } + }] } }], "dataType": { @@ -118,15 +112,9 @@ "array": { "elements": [{ "integer": 2 - }], - "dataType": { - "containsNull": true - } + }] } - }], - "dataType": { - "containsNull": true - } + }] } }, { "array": { @@ -134,15 +122,9 @@ "array": { "elements": [{ "integer": 3 - }], - "dataType": { - "containsNull": true - } + }] } - }], - "dataType": { - "containsNull": true - } + }] } }], "dataType": { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin index 66877f54fef2f3ad42b159c3c545acc5fe92437e..d9edb4100b0d614dbb0b7065a2e9d87ee67f1f2c 100644 GIT binary patch delta 106 zcmeyX_Ee3Hi%Eb{>fS`Qw`w;Txzsi>DhqLMV&o8FGGLSf$_TM5I j*xxX6S#Dx9n|zN^T$*PSBPUFjNeGz-R=fEc({cd--&7X% delta 146 zcmaE=_E(LKi%Eb{>f1!Nx4I7*x%4(MY7231V&o8FGGLSf%D^~G$Q))MXR;ydc8@=d mT%Ma4-GoFoF$%&|g0+EFBa4FdV~Bzb0;vZXws{}ZasdGR+ZoFM diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json index b43528bcb7711..bdf75c68ce6c1 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json @@ -1017,9 +1017,7 @@ "integer": 5 }, { "integer": 6 - }], - "dataType": { - } + }] } }, { "array": { @@ -1029,10 +1027,85 @@ "integer": 8 }, { "integer": 9 + }] + } + }], + "dataType": { + "containsNull": true + } + } + }, + "common": { + "origin": { + "jvmOrigin": { + "stackTrace": [{ + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.functions$", + "methodName": "typedLit", + "fileName": "functions.scala" + }, { + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", + "methodName": "~~trimmed~anonfun~~", + "fileName": "PlanGenerationTestSuite.scala" + }] + } + } + } + }, { + "literal": { + "array": { + "elements": [{ + "struct": { + "elements": [{ + "integer": 1 + }, { + "string": "2" + }, { + "array": { + "elements": [{ + "string": "3" + }, { + "string": "4" + }], + "dataType": { + "elementType": { + "string": { + "collation": "UTF8_BINARY" + } + }, + "containsNull": true + } + } }], - "dataType": { + "dataTypeStruct": { + "fields": [{ + "name": "_1" + }, { + "name": "_2", + "dataType": { + "string": { + "collation": "UTF8_BINARY" + } + }, + "nullable": true + }, { + "name": "_3", + "nullable": true + }] } } + }, { + "struct": { + "elements": [{ + "integer": 5 + }, { + "string": "6" + }, { + "array": { + } + }] + } }], "dataType": { "containsNull": true @@ -1090,14 +1163,7 @@ "integer": 3 }, { "integer": 4 - }], - "dataType": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - } - } + }] } }, { "map": { @@ -1110,15 +1176,115 @@ "integer": 5 }, { "integer": 6 + }] + } + }], + "dataType": { + "containsNull": true + } + } + }, + "common": { + "origin": { + "jvmOrigin": { + "stackTrace": [{ + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.functions$", + "methodName": "typedLit", + "fileName": "functions.scala" + }, { + "classLoaderName": "app", + "declaringClass": "org.apache.spark.sql.PlanGenerationTestSuite", + "methodName": "~~trimmed~anonfun~~", + "fileName": "PlanGenerationTestSuite.scala" + }] + } + } + } + }, { + "literal": { + "array": { + "elements": [{ + "map": { + "keys": [{ + "string": "a" + }, { + "string": "b" + }], + "values": [{ + "array": { + "elements": [{ + "string": "1" + }, { + "string": "2" + }], + "dataType": { + "elementType": { + "string": { + "collation": "UTF8_BINARY" + } + }, + "containsNull": true + } + } + }, { + "array": { + "elements": [{ + "string": "3" + }, { + "string": "4" + }] + } }], "dataType": { "keyType": { "string": { "collation": "UTF8_BINARY" } - } + }, + "valueContainsNull": true } } + }, { + "map": { + "keys": [{ + "string": "a" + }, { + "string": "b" + }], + "values": [{ + "array": { + "elements": [{ + "string": "5" + }, { + "string": "6" + }] + } + }, { + "array": { + "elements": [{ + "string": "7" + }, { + "string": "8" + }] + } + }] + } + }, { + "map": { + "keys": [{ + "string": "a" + }, { + "string": "b" + }], + "values": [{ + "array": { + } + }, { + "array": { + } + }] + } }], "dataType": { "containsNull": true @@ -1181,14 +1347,7 @@ "integer": 3 }, { "integer": 4 - }], - "dataType": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - } - } + }] } }], "dataType": { diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin index e69435ae88faddfcfd1136de61e1a698f8e297f1..a6cacd8f9117e741a2bbe1ee45efed56d9438b67 100644 GIT binary patch delta 504 zcmez9e%Obdi%Eb{YLCxGb_IF+FN|C^n;0#H1UE7A2{9Qk0x=U1GfOcD@o!?}0Ww*D zm=%cGU;^ww4hIl(N-+sc-oUt{ek~JM=_bY^p~ypwVL)+4A?7SbBO#?tjPf9sF^Djc z66WHM;uYc!4RN!Gck=Xe42lFgLY<3+iz(iaOB_fV385>HVB}&0N*GHp3JD$pJC7A) zmYERqCPoIJ<&2XUH!*UppS+!MBl~nFuJ}!iF_Y&hX&Z>{Vib}BDog|sNlHN905OXY zhA=Bom<{HMdrVwQfyOOl6bjwN7=)x9<_|-#x3Ky{i5KV&gkQBV{R;Gg_AW*ZBx7M( zOu=4K(t`2LL3|4#@m-7}NUA|TR)TsP6z1$5Ok4rG7=0&aD;kSocns{{$qQs{Cm)ci K+uWp-&jar delta 253 zcmX@?^UXrLOb_IF+AB0W z2(bgX96-z|HJMj=59fX+t~SO^j7^M_Kd5L2YwTiFm153fOau{0N 1, "b" -> 2), + Some(MapType(StringType, IntegerType, valueContainsNull = false)), + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + assert(!literalProto.getMap.hasDataType) + assert(literalProto.getMap.getKeysList.size == 2) + assert(literalProto.getMap.getValuesList.size == 2) + assert(literalProto.getMap.getKeyType.hasString) + assert(literalProto.getMap.getValueType.hasInteger) - val literalProto = proto.Expression.Literal.newBuilder().setMap(mapProto).build() val literal = LiteralExpressionProtoConverter.toCatalystExpression(literalProto) assert(literal.dataType.isInstanceOf[MapType]) assert(literal.dataType.asInstanceOf[MapType].keyType == StringType) @@ -180,39 +166,25 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i test("backward compatibility for struct literal proto") { // Test the old way of defining structs with structType field and elements - val structTypeProto = proto.DataType.Struct - .newBuilder() - .addFields( - proto.DataType.StructField - .newBuilder() - .setName("a") - .setDataType(proto.DataType - .newBuilder() - .setInteger(proto.DataType.Integer.newBuilder()) - .build()) - .setNullable(true) - .build()) - .addFields( - proto.DataType.StructField - .newBuilder() - .setName("b") - .setDataType(proto.DataType - .newBuilder() - .setString(proto.DataType.String.newBuilder()) - .build()) - .setNullable(false) - .build()) - .build() - - val structProto = proto.Expression.Literal.Struct - .newBuilder() - .setStructType(proto.DataType.newBuilder().setStruct(structTypeProto).build()) - .addElements(LiteralValueProtoConverter.toLiteralProto(1)) - .addElements(LiteralValueProtoConverter.toLiteralProto("test")) - .build() - - val result = LiteralValueProtoConverter.toCatalystStruct(structProto) - val resultType = LiteralValueProtoConverter.getProtoStructType(structProto) + val structProto = LiteralValueProtoConverter.toLiteralProtoWithOptions( + (1, "test"), + Some( + StructType( + Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", StringType, nullable = false)))), + ToLiteralProtoOptions(useDeprecatedDataTypeFields = true)) + assert(!structProto.getStruct.hasDataTypeStruct) + assert(structProto.getStruct.getElementsList.size == 2) + val structTypeProto = structProto.getStruct.getStructType.getStruct + assert(structTypeProto.getFieldsList.size == 2) + assert(structTypeProto.getFieldsList.get(0).getName == "a") + assert(structTypeProto.getFieldsList.get(0).getDataType.hasInteger) + assert(structTypeProto.getFieldsList.get(1).getName == "b") + assert(structTypeProto.getFieldsList.get(1).getDataType.hasString) + + val result = LiteralValueProtoConverter.toCatalystValue(structProto) + val resultType = LiteralValueProtoConverter.getProtoStructType(structProto.getStruct) // Verify the result is a tuple with correct values assert(result.isInstanceOf[Product])