diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 872770ee22911..c72a6b023c722 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -29,9 +29,9 @@ Optional, ) -import json -import decimal import datetime +import decimal +import json import warnings from threading import Lock @@ -377,6 +377,52 @@ def _infer_type(cls, value: Any) -> DataType: def _from_value(cls, value: Any) -> "LiteralExpression": return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value)) + @classmethod + def _infer_type_from_literal(cls, literal: "proto.Expression.Literal") -> Optional[DataType]: + if literal.HasField("null"): + return NullType() + elif literal.HasField("binary"): + return BinaryType() + elif literal.HasField("boolean"): + return BooleanType() + elif literal.HasField("byte"): + return ByteType() + elif literal.HasField("short"): + return ShortType() + elif literal.HasField("integer"): + return IntegerType() + elif literal.HasField("long"): + return LongType() + elif literal.HasField("float"): + return FloatType() + elif literal.HasField("double"): + return DoubleType() + elif literal.HasField("date"): + return DateType() + elif literal.HasField("timestamp"): + return TimestampType() + elif literal.HasField("timestamp_ntz"): + return TimestampNTZType() + elif literal.HasField("array"): + if literal.array.HasField("element_type"): + return ArrayType( + proto_schema_to_pyspark_data_type(literal.array.element_type), True + ) + element_type = None + if len(literal.array.elements) > 0: + element_type = LiteralExpression._infer_type_from_literal(literal.array.elements[0]) + + if element_type is None: + raise PySparkTypeError( + errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE", + messageParameters={}, + ) + return ArrayType(element_type, True) + # Not all data types support inferring the data type from the literal at the moment. + # e.g. the type of DayTimeInterval contains extra information like start_field and + # end_field and cannot be inferred from the literal. 
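+        # For example, a literal with its `integer` field set is inferred to be
+        # IntegerType() above, while a `day_time_interval` literal falls through
+        # to this None return.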
+        return None
+
     @classmethod
     def _to_value(
         cls, literal: "proto.Expression.Literal", dataType: Optional[DataType] = None
@@ -426,10 +472,20 @@ def _to_value(
             assert dataType is None or isinstance(dataType, DayTimeIntervalType)
             return DayTimeIntervalType().fromInternal(literal.day_time_interval)
         elif literal.HasField("array"):
-            elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
-            if dataType is not None:
-                assert isinstance(dataType, ArrayType)
-                assert elementType == dataType.elementType
+            elementType = None
+            if literal.array.HasField("element_type"):
+                elementType = proto_schema_to_pyspark_data_type(literal.array.element_type)
+                if dataType is not None:
+                    assert isinstance(dataType, ArrayType)
+                    assert elementType == dataType.elementType
+            elif len(literal.array.elements) > 0:
+                elementType = LiteralExpression._infer_type_from_literal(literal.array.elements[0])
+
+            if elementType is None:
+                raise PySparkTypeError(
+                    errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
+                    messageParameters={},
+                )
             return [LiteralExpression._to_value(v, elementType) for v in literal.array.elements]

         raise PySparkTypeError(
@@ -475,11 +531,17 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
         elif isinstance(self._dataType, DayTimeIntervalType):
             expr.literal.day_time_interval = int(self._value)
         elif isinstance(self._dataType, ArrayType):
-            element_type = self._dataType.elementType
-            expr.literal.array.element_type.CopyFrom(pyspark_types_to_proto_types(element_type))
             for v in self._value:
                 expr.literal.array.elements.append(
-                    LiteralExpression(v, element_type).to_plan(session).literal
+                    LiteralExpression(v, self._dataType.elementType).to_plan(session).literal
+                )
+            if (
+                len(self._value) == 0
+                or LiteralExpression._infer_type_from_literal(expr.literal.array.elements[0])
+                is None
+            ):
+                expr.literal.array.element_type.CopyFrom(
+                    pyspark_types_to_proto_types(self._dataType.elementType)
                 )
         else:
             raise PySparkTypeError(
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index ad347fd4bd154..a59cd5c86864d 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -475,7 +475,11 @@ class Expression(google.protobuf.message.Message):
         ELEMENT_TYPE_FIELD_NUMBER: builtins.int
         ELEMENTS_FIELD_NUMBER: builtins.int
         @property
-        def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+            """(Optional) The element type of the array. Only needs to be set when the elements
+            are empty or the element type is not inferable, since Spark 4.1+ supports
+            inferring the element type from the elements.
+            """
         @property
         def elements(
             self,
@@ -506,9 +510,17 @@ class Expression(google.protobuf.message.Message):
         KEYS_FIELD_NUMBER: builtins.int
         VALUES_FIELD_NUMBER: builtins.int
         @property
-        def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+            """(Optional) The key type of the map. Only needs to be set when the keys are
+            empty or the key type is not inferable, since Spark 4.1+ supports
+            inferring the key type from the keys.
+            """
         @property
-        def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+        def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+            """(Optional) The value type of the map. Only needs to be set when the values are
+            empty or the value type is not inferable, since Spark 4.1+ supports
+            inferring the value type from the values.
+            """
         @property
         def keys(
             self,
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index a03cd30c733fb..59f3c9ca6ccdc 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -979,21 +979,20 @@ def test_literal_expression_with_arrays(self):
         self.assertEqual(l0.array.elements[2].string, "z")

         l1 = LiteralExpression._from_value([3, -3]).to_plan(None).literal
-        self.assertTrue(l1.array.element_type.HasField("integer"))
+        self.assertFalse(l1.array.element_type.HasField("integer"))
         self.assertEqual(len(l1.array.elements), 2)
         self.assertEqual(l1.array.elements[0].integer, 3)
         self.assertEqual(l1.array.elements[1].integer, -3)

         l2 = LiteralExpression._from_value([float("nan"), -3.0, 0.0]).to_plan(None).literal
-        self.assertTrue(l2.array.element_type.HasField("double"))
+        self.assertFalse(l2.array.element_type.HasField("double"))
         self.assertEqual(len(l2.array.elements), 3)
         self.assertTrue(math.isnan(l2.array.elements[0].double))
         self.assertEqual(l2.array.elements[1].double, -3.0)
         self.assertEqual(l2.array.elements[2].double, 0.0)

         l3 = LiteralExpression._from_value([[3, 4], [5, 6, 7]]).to_plan(None).literal
-        self.assertTrue(l3.array.element_type.HasField("array"))
-        self.assertTrue(l3.array.element_type.array.element_type.HasField("integer"))
+        self.assertFalse(l3.array.element_type.HasField("array"))
         self.assertEqual(len(l3.array.elements), 2)
         self.assertEqual(len(l3.array.elements[0].array.elements), 2)
         self.assertEqual(len(l3.array.elements[1].array.elements), 3)
@@ -1003,8 +1002,7 @@ def test_literal_expression_with_arrays(self):
             .to_plan(None)
             .literal
         )
-        self.assertTrue(l4.array.element_type.HasField("array"))
-        self.assertTrue(l4.array.element_type.array.element_type.HasField("double"))
+        self.assertFalse(l4.array.element_type.HasField("array"))
         self.assertEqual(len(l4.array.elements), 3)
         self.assertEqual(len(l4.array.elements[0].array.elements), 2)
         self.assertEqual(len(l4.array.elements[1].array.elements), 2)
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 3ae6cb8dba9b5..32fcb56c1e4fb 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -215,12 +215,21 @@ message Expression {
     }

     message Array {
+      // (Optional) The element type of the array. Only needs to be set when the elements
+      // are empty or the element type is not inferable, since Spark 4.1+ supports
+      // inferring the element type from the elements.
       DataType element_type = 1;
       repeated Literal elements = 2;
     }

     message Map {
+      // (Optional) The key type of the map. Only needs to be set when the keys are
+      // empty or the key type is not inferable, since Spark 4.1+ supports
+      // inferring the key type from the keys.
       DataType key_type = 1;
+      // (Optional) The value type of the map. Only needs to be set when the values are
+      // empty or the value type is not inferable, since Spark 4.1+ supports
+      // inferring the value type from the values.
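+      // e.g. a map literal with keys [1] and values [2] needs neither key_type nor
+      // value_type set, since both are inferable from the first key/value pair.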
DataType value_type = 2; repeated Literal keys = 3; repeated Literal values = 4; diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala index 4567cc10c81c8..a35dd2e4ecd82 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/LiteralValueProtoConverter.scala @@ -24,7 +24,7 @@ import java.time._ import scala.collection.{immutable, mutable} import scala.jdk.CollectionConverters._ -import scala.reflect.ClassTag +import scala.language.existentials import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -63,8 +63,10 @@ object LiteralValueProtoConverter { def arrayBuilder(array: Array[_]) = { val ab = builder.getArrayBuilder - .setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) array.foreach(x => ab.addElements(toLiteralProto(x))) + if (ab.getElementsCount == 0 || getInferredDataType(ab.getElementsList.get(0)).isEmpty) { + ab.setElementType(toConnectProtoType(toDataType(array.getClass.getComponentType))) + } ab } @@ -116,7 +118,7 @@ object LiteralValueProtoConverter { val builder = proto.Expression.Literal.newBuilder() def arrayBuilder(scalaValue: Any, elementType: DataType) = { - val ab = builder.getArrayBuilder.setElementType(toConnectProtoType(elementType)) + val ab = builder.getArrayBuilder scalaValue match { case a: Array[_] => @@ -127,13 +129,15 @@ object LiteralValueProtoConverter { throw new IllegalArgumentException(s"literal $other not supported (yet).") } + if (ab.getElementsCount == 0 || getInferredDataType(ab.getElementsList.get(0)).isEmpty) { + ab.setElementType(toConnectProtoType(elementType)) + } + ab } def mapBuilder(scalaValue: Any, keyType: DataType, valueType: DataType) = { val mb = builder.getMapBuilder - .setKeyType(toConnectProtoType(keyType)) - .setValueType(toConnectProtoType(valueType)) scalaValue match { case map: scala.collection.Map[_, _] => @@ -145,6 +149,14 @@ object LiteralValueProtoConverter { throw new IllegalArgumentException(s"literal $other not supported (yet).") } + if (mb.getKeysCount == 0 || getInferredDataType(mb.getKeysList.get(0)).isEmpty) { + mb.setKeyType(toConnectProtoType(keyType)) + } + + if (mb.getValuesCount == 0 || getInferredDataType(mb.getValuesList.get(0)).isEmpty) { + mb.setValueType(toConnectProtoType(valueType)) + } + mb } @@ -317,7 +329,10 @@ object LiteralValueProtoConverter { SparkIntervalUtils.microsToDuration(literal.getDayTimeInterval) case proto.Expression.Literal.LiteralTypeCase.ARRAY => - toCatalystArray(literal.getArray) + toCatalystArray(literal.getArray)._1 + + case proto.Expression.Literal.LiteralTypeCase.MAP => + toCatalystMap(literal.getMap)._1 case proto.Expression.Literal.LiteralTypeCase.STRUCT => toCatalystStruct(literal.getStruct)._1 @@ -351,8 +366,24 @@ object LiteralValueProtoConverter { v => val interval = v.getCalendarInterval new CalendarInterval(interval.getMonths, interval.getDays, interval.getMicroseconds) - case proto.DataType.KindCase.ARRAY => v => toCatalystArray(v.getArray) - case proto.DataType.KindCase.MAP => v => toCatalystMap(v.getMap) + case proto.DataType.KindCase.ARRAY => + if (inferDataType) { v => + { + val (array, arrayType) = toCatalystArray(v.getArray, None) + LiteralValueWithDataType(array, proto.DataType.newBuilder.setArray(arrayType).build()) + } + } else 
{ v => + toCatalystArray(v.getArray, Some(dataType.getArray))._1 + } + case proto.DataType.KindCase.MAP => + if (inferDataType) { v => + { + val (map, mapType) = toCatalystMap(v.getMap, None) + LiteralValueWithDataType(map, proto.DataType.newBuilder.setMap(mapType).build()) + } + } else { v => + toCatalystMap(v.getMap, Some(dataType.getMap))._1 + } case proto.DataType.KindCase.STRUCT => if (inferDataType) { v => val (struct, structType) = toCatalystStruct(v.getStruct, None) @@ -398,9 +429,15 @@ object LiteralValueProtoConverter { builder.setTimestampNtz(proto.DataType.TimestampNTZ.newBuilder.build()) case proto.Expression.Literal.LiteralTypeCase.CALENDAR_INTERVAL => builder.setCalendarInterval(proto.DataType.CalendarInterval.newBuilder.build()) + case proto.Expression.Literal.LiteralTypeCase.ARRAY => + // Element type will be inferred from the elements in the array. + builder.setArray(proto.DataType.Array.newBuilder.build()) + case proto.Expression.Literal.LiteralTypeCase.MAP => + // Key and value types will be inferred from the keys and values in the map. + builder.setMap(proto.DataType.Map.newBuilder.build()) case proto.Expression.Literal.LiteralTypeCase.STRUCT => // The type of the fields will be inferred from the literals of the fields in the struct. - builder.setStruct(literal.getStruct.getStructType.getStruct) + builder.setStruct(proto.DataType.Struct.newBuilder.build()) case _ => // Not all data types support inferring the data type from the literal at the moment. // e.g. the type of DayTimeInterval contains extra information like start_field and @@ -412,44 +449,108 @@ object LiteralValueProtoConverter { private def getInferredDataTypeOrThrow(literal: proto.Expression.Literal): proto.DataType = { getInferredDataType(literal).getOrElse { - throw InvalidPlanInput( - s"Unsupported Literal type for data type inference: ${literal.getLiteralTypeCase}") + throw InvalidPlanInput(s"Unsupported Literal Type: ${literal.getLiteralTypeCase}") } } - def toCatalystArray(array: proto.Expression.Literal.Array): Array[_] = { - def makeArrayData[T](converter: proto.Expression.Literal => T)(implicit - tag: ClassTag[T]): Array[T] = { - val builder = mutable.ArrayBuilder.make[T] - val elementList = array.getElementsList - builder.sizeHint(elementList.size()) - val iter = elementList.iterator() - while (iter.hasNext) { - builder += converter(iter.next()) + def toCatalystArray( + array: proto.Expression.Literal.Array, + arrayTypeOpt: Option[proto.DataType.Array] = None): (Array[_], proto.DataType.Array) = { + def protoArrayType(elementType: proto.DataType): proto.DataType.Array = { + proto.DataType.Array.newBuilder().setElementType(elementType).build() + } + + val builder = mutable.ArrayBuilder.make[Any] + builder.sizeHint(array.getElementsList.size()) + + val iter = array.getElementsList.iterator() + + def inferDataTypeFromFirstElement(): proto.DataType.Array = { + if (arrayTypeOpt.isDefined) { + arrayTypeOpt.get + } else if (array.hasElementType) { + protoArrayType(array.getElementType) + } else if (iter.hasNext) { + val firstElement = iter.next() + val outerElementType = getInferredDataTypeOrThrow(firstElement) + val (elem, inferredElementType) = + getConverter(outerElementType, inferDataType = true)(firstElement) match { + case LiteralValueWithDataType(elem, dataType) => (elem, dataType) + case elem => (elem, outerElementType) + } + builder += elem + protoArrayType(inferredElementType) + } else { + throw InvalidPlanInput("Cannot infer element type for an empty array") } - builder.result() } - 
makeArrayData(getConverter(array.getElementType)) + val dataType = inferDataTypeFromFirstElement() + val converter = getConverter(dataType.getElementType) + + while (iter.hasNext) { + builder += converter(iter.next()) + } + + (builder.result(), dataType) } - def toCatalystMap(map: proto.Expression.Literal.Map): mutable.Map[_, _] = { - def makeMapData[K, V]( - keyConverter: proto.Expression.Literal => K, - valueConverter: proto.Expression.Literal => V)(implicit - tagK: ClassTag[K], - tagV: ClassTag[V]): mutable.Map[K, V] = { - val builder = mutable.HashMap.empty[K, V] - val keys = map.getKeysList.asScala - val values = map.getValuesList.asScala - builder.sizeHint(keys.size) - keys.zip(values).foreach { case (key, value) => - builder += ((keyConverter(key), valueConverter(value))) + def toCatalystMap( + map: proto.Expression.Literal.Map, + mapTypeOpt: Option[proto.DataType.Map] = None): (mutable.Map[_, _], proto.DataType.Map) = { + def protoMapType(keyType: proto.DataType, valueType: proto.DataType): proto.DataType.Map = { + proto.DataType.Map.newBuilder().setKeyType(keyType).setValueType(valueType).build() + } + val builder = mutable.HashMap.newBuilder[Any, Any] + val keyValuePairs = map.getKeysList.asScala.zip(map.getValuesList.asScala) + builder.sizeHint(keyValuePairs.size) + + val iter = keyValuePairs.iterator + + def inferDataTypeFromFirstPair(): proto.DataType.Map = { + if (mapTypeOpt.isDefined) { + mapTypeOpt.get + } else if (map.hasKeyType && map.hasValueType) { + protoMapType(map.getKeyType, map.getValueType) + } else if (iter.hasNext) { + val (key, value) = iter.next() + val (outerKeyType, inferKeyType) = if (map.hasKeyType) { + (map.getKeyType, false) + } else { + (getInferredDataTypeOrThrow(key), true) + } + val (catalystKey, inferredKeyType) = + getConverter(outerKeyType, inferDataType = inferKeyType)(key) match { + case LiteralValueWithDataType(catalystKey, dataType) => (catalystKey, dataType) + case catalystKey => (catalystKey, outerKeyType) + } + val (outerValueType, inferValueType) = if (map.hasValueType) { + (map.getValueType, false) + } else { + (getInferredDataTypeOrThrow(value), true) + } + val (catalystValue, inferredValueType) = + getConverter(outerValueType, inferDataType = inferValueType)(value) match { + case LiteralValueWithDataType(catalystValue, dataType) => (catalystValue, dataType) + case catalystValue => (catalystValue, outerValueType) + } + builder += ((catalystKey, catalystValue)) + protoMapType(inferredKeyType, inferredValueType) + } else { + throw InvalidPlanInput("Cannot infer key and value type for an empty map") } - builder } - makeMapData(getConverter(map.getKeyType), getConverter(map.getValueType)) + val dataType = inferDataTypeFromFirstPair() + val keyConverter = getConverter(dataType.getKeyType) + val valueConverter = getConverter(dataType.getValueType) + + while (iter.hasNext) { + val (key, value) = iter.next() + builder += ((keyConverter(key), valueConverter(value))) + } + + (builder.result(), dataType) } def toCatalystStruct( diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json index 57a17148abe44..13f7d042876a2 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.json @@ -358,10 +358,6 @@ }, { "literal": { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 8 }, { diff --git 
a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin index 6e4ec0e2ffef0..14540646a5067 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_lit.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json index 337b3366649f7..7299c7afb9b30 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.json @@ -40,41 +40,20 @@ }, { "literal": { "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "elements": [{ "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 1 }] } }, { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 2 }] } }, { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 3 }] @@ -102,37 +81,10 @@ }, { "literal": { "array": { - "elementType": { - "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, - "containsNull": true - } - }, "elements": [{ "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "elements": [{ "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 1 }] @@ -141,21 +93,8 @@ } }, { "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "elements": [{ "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 2 }] @@ -164,21 +103,8 @@ } }, { "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - }, - "containsNull": true - } - }, "elements": [{ "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 3 }] @@ -208,10 +134,6 @@ }, { "literal": { "array": { - "elementType": { - "boolean": { - } - }, "elements": [{ "boolean": true }, { @@ -260,10 +182,6 @@ }, { "literal": { "array": { - "elementType": { - "short": { - } - }, "elements": [{ "short": 9872 }, { @@ -293,10 +211,6 @@ }, { "literal": { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": -8726532 }, { @@ -326,10 +240,6 @@ }, { "literal": { "array": { - "elementType": { - "long": { - } - }, "elements": [{ "long": "7834609328726531" }, { @@ -359,10 +269,6 @@ }, { "literal": { "array": { - "elementType": { - "double": { - } - }, "elements": [{ "double": 2.718281828459045 }, { @@ -392,10 +298,6 @@ }, { "literal": { "array": { - "elementType": { - "float": { - } - }, "elements": [{ "float": -0.8 }, { @@ -592,10 +494,6 @@ }, { "literal": { "array": { - "elementType": { - "date": { - } - }, "elements": [{ "date": 18545 }, { @@ -623,10 +521,6 @@ }, { "literal": { "array": { - "elementType": { - "timestamp": { - } - }, "elements": [{ "timestamp": "1677155519808000" }, { @@ -654,10 +548,6 @@ }, { "literal": { "array": { - "elementType": { - "timestamp": { - } - }, "elements": [{ "timestamp": "12345000" }, { @@ -685,10 +575,6 @@ }, { "literal": { "array": { - "elementType": { - "timestampNtz": { - } - }, "elements": [{ "timestampNtz": "1677184560000000" }, { @@ -716,10 +602,6 @@ 
}, { "literal": { "array": { - "elementType": { - "date": { - } - }, "elements": [{ "date": 19411 }, { @@ -813,10 +695,6 @@ }, { "literal": { "array": { - "elementType": { - "calendarInterval": { - } - }, "elements": [{ "calendarInterval": { "months": 2, diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin index 320da10258180..0fa47f7d99a9b 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_lit_array.proto.bin differ diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json index bd9d6bb3c8bb7..2d54949b35ffa 100644 --- a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json +++ b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.json @@ -403,10 +403,6 @@ }, { "literal": { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 8 }, { @@ -700,10 +696,6 @@ }, { "literal": { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 1 }, { @@ -733,10 +725,6 @@ }, { "literal": { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 1 }, { @@ -771,10 +759,6 @@ "collation": "UTF8_BINARY" } }, - "valueType": { - "integer": { - } - }, "keys": [{ "string": "a" }, { @@ -875,10 +859,6 @@ }, { "literal": { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 1 }] @@ -904,14 +884,6 @@ }, { "literal": { "map": { - "keyType": { - "integer": { - } - }, - "valueType": { - "integer": { - } - }, "keys": [{ "integer": 1 }], @@ -943,14 +915,6 @@ }, { "literal": { "map": { - "keyType": { - "integer": { - } - }, - "valueType": { - "integer": { - } - }, "keys": [{ "integer": 1 }], @@ -982,14 +946,6 @@ }, { "literal": { "map": { - "keyType": { - "integer": { - } - }, - "valueType": { - "integer": { - } - }, "keys": [{ "integer": 1 }], @@ -1021,20 +977,8 @@ }, { "literal": { "array": { - "elementType": { - "array": { - "elementType": { - "integer": { - } - } - } - }, "elements": [{ "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 1 }, { @@ -1045,10 +989,6 @@ } }, { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 4 }, { @@ -1059,10 +999,6 @@ } }, { "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 7 }, { @@ -1094,19 +1030,6 @@ }, { "literal": { "array": { - "elementType": { - "map": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueType": { - "integer": { - } - } - } - }, "elements": [{ "map": { "keyType": { @@ -1114,10 +1037,6 @@ "collation": "UTF8_BINARY" } }, - "valueType": { - "integer": { - } - }, "keys": [{ "string": "a" }, { @@ -1136,10 +1055,6 @@ "collation": "UTF8_BINARY" } }, - "valueType": { - "integer": { - } - }, "keys": [{ "string": "a" }, { @@ -1158,10 +1073,6 @@ "collation": "UTF8_BINARY" } }, - "valueType": { - "integer": { - } - }, "keys": [{ "string": "a" }, { @@ -1196,23 +1107,6 @@ }, { "literal": { "map": { - "keyType": { - "integer": { - } - }, - "valueType": { - "map": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueType": { - "integer": { - } - } - } - }, "keys": [{ "integer": 1 }, { @@ -1225,10 +1119,6 @@ "collation": 
"UTF8_BINARY" } }, - "valueType": { - "integer": { - } - }, "keys": [{ "string": "a" }, { @@ -1247,10 +1137,6 @@ "collation": "UTF8_BINARY" } }, - "valueType": { - "integer": { - } - }, "keys": [{ "string": "a" }, { @@ -1287,10 +1173,6 @@ "struct": { "elements": [{ "array": { - "elementType": { - "integer": { - } - }, "elements": [{ "integer": 1 }, { @@ -1306,10 +1188,6 @@ "collation": "UTF8_BINARY" } }, - "valueType": { - "integer": { - } - }, "keys": [{ "string": "a" }, { @@ -1327,10 +1205,6 @@ "string": "a" }, { "map": { - "keyType": { - "integer": { - } - }, "valueType": { "string": { "collation": "UTF8_BINARY" @@ -1359,20 +1233,6 @@ "nullable": true }, { "name": "_2", - "dataType": { - "map": { - "keyType": { - "integer": { - } - }, - "valueType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueContainsNull": true - } - }, "nullable": true }] } @@ -1381,30 +1241,9 @@ "dataTypeStruct": { "fields": [{ "name": "_1", - "dataType": { - "array": { - "elementType": { - "integer": { - } - } - } - }, "nullable": true }, { "name": "_2", - "dataType": { - "map": { - "keyType": { - "string": { - "collation": "UTF8_BINARY" - } - }, - "valueType": { - "integer": { - } - } - } - }, "nullable": true }, { "name": "_3", diff --git a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin index da3a4a946d210..a3a7363f74e8a 100644 Binary files a/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin and b/sql/connect/common/src/test/resources/query-tests/queries/function_typedLit.proto.bin differ diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala index 0ab9105637291..66ae7a7dda0a4 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala @@ -187,7 +187,7 @@ private[ml] object MLUtils { * @return * the reconciled array */ - private def reconcileArray(elementType: Class[_], array: Array[_]): Array[_] = { + private[ml] def reconcileArray(elementType: Class[_], array: Array[_]): Array[_] = { if (elementType == classOf[Byte]) { array.map(_.asInstanceOf[Byte]) } else if (elementType == classOf[Short]) { @@ -204,6 +204,8 @@ private[ml] object MLUtils { array.map(_.asInstanceOf[String]) } else if (elementType.isArray && elementType.getComponentType == classOf[Double]) { array.map(_.asInstanceOf[Array[_]].map(_.asInstanceOf[Double])) + } else if (elementType.isArray && elementType.getComponentType == classOf[String]) { + array.map(_.asInstanceOf[Array[_]].map(_.asInstanceOf[String])) } else { throw MlUnsupportedException( s"array element type unsupported, " + diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala index df07dd42bc427..5015fb988e566 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connect.ml +import scala.language.existentials + import org.apache.spark.connect.proto import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param.Params @@ -165,18 +167,21 @@ private[ml] object Serializer { 
case proto.Expression.Literal.LiteralTypeCase.BOOLEAN => (literal.getBoolean.asInstanceOf[Object], classOf[Boolean]) case proto.Expression.Literal.LiteralTypeCase.ARRAY => - val array = literal.getArray - array.getElementType.getKindCase match { + val (catalystArray, arrayType) = + LiteralValueProtoConverter.toCatalystArray(literal.getArray) + arrayType.getElementType.getKindCase match { case proto.DataType.KindCase.DOUBLE => - (parseDoubleArray(array), classOf[Array[Double]]) + (MLUtils.reconcileArray(classOf[Double], catalystArray), classOf[Array[Double]]) case proto.DataType.KindCase.STRING => - (parseStringArray(array), classOf[Array[String]]) + (MLUtils.reconcileArray(classOf[String], catalystArray), classOf[Array[String]]) case proto.DataType.KindCase.ARRAY => - array.getElementType.getArray.getElementType.getKindCase match { + arrayType.getElementType.getArray.getElementType.getKindCase match { case proto.DataType.KindCase.STRING => - (parseStringArrayArray(array), classOf[Array[Array[String]]]) + ( + MLUtils.reconcileArray(classOf[Array[String]], catalystArray), + classOf[Array[Array[String]]]) case _ => - throw MlUnsupportedException(s"Unsupported inner array $array") + throw MlUnsupportedException(s"Unsupported inner array ${literal.getArray}") } case _ => throw MlUnsupportedException(s"Unsupported array $literal") @@ -193,37 +198,6 @@ private[ml] object Serializer { } } - private def parseDoubleArray(array: proto.Expression.Literal.Array): Array[Double] = { - val values = new Array[Double](array.getElementsCount) - var i = 0 - while (i < array.getElementsCount) { - values(i) = array.getElements(i).getDouble - i += 1 - } - values - } - - private def parseStringArray(array: proto.Expression.Literal.Array): Array[String] = { - val values = new Array[String](array.getElementsCount) - var i = 0 - while (i < array.getElementsCount) { - values(i) = array.getElements(i).getString - i += 1 - } - values - } - - private def parseStringArrayArray( - array: proto.Expression.Literal.Array): Array[Array[String]] = { - val values = new Array[Array[String]](array.getElementsCount) - var i = 0 - while (i < array.getElementsCount) { - values(i) = parseStringArray(array.getElements(i).getArray) - i += 1 - } - values - } - /** * Serialize an instance of "Params" which could be estimator/model/evaluator ... 
* @param instance diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala index 10f046a57da92..63ac4fc03a6a8 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connect.planner +import scala.language.existentials + import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter} @@ -105,16 +107,18 @@ object LiteralExpressionProtoConverter { expressions.Literal(lit.getTime.getNano, TimeType(precision)) case proto.Expression.Literal.LiteralTypeCase.ARRAY => + val (array, arrayType) = LiteralValueProtoConverter.toCatalystArray(lit.getArray) expressions.Literal.create( - LiteralValueProtoConverter.toCatalystArray(lit.getArray), - ArrayType(DataTypeProtoConverter.toCatalystType(lit.getArray.getElementType))) + array, + ArrayType(DataTypeProtoConverter.toCatalystType(arrayType.getElementType))) case proto.Expression.Literal.LiteralTypeCase.MAP => + val (map, mapType) = LiteralValueProtoConverter.toCatalystMap(lit.getMap) expressions.Literal.create( - LiteralValueProtoConverter.toCatalystMap(lit.getMap), + map, MapType( - DataTypeProtoConverter.toCatalystType(lit.getMap.getKeyType), - DataTypeProtoConverter.toCatalystType(lit.getMap.getValueType))) + DataTypeProtoConverter.toCatalystType(mapType.getKeyType), + DataTypeProtoConverter.toCatalystType(mapType.getValueType))) case proto.Expression.Literal.LiteralTypeCase.STRUCT => val (structData, structType) = LiteralValueProtoConverter.toCatalystStruct(lit.getStruct) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala index 559984e47cf8b..5aeffdabc5d2c 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/LiteralExpressionProtoConverterSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.planner import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter import org.apache.spark.sql.types._ @@ -56,15 +57,31 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i StructField("a", IntegerType), StructField( "b", - StructType( - Seq(StructField("c", IntegerType), StructField("d", IntegerType)))))))).zipWithIndex - .foreach { case ((v, t), idx) => + StructType(Seq(StructField("c", IntegerType), StructField("d", IntegerType))))))), + (Array(true, false, true), ArrayType(BooleanType)), + (Array(1.toByte, 2.toByte, 3.toByte), ArrayType(ByteType)), + (Array(1.toShort, 2.toShort, 3.toShort), ArrayType(ShortType)), + (Array(1, 2, 3), ArrayType(IntegerType)), + (Array(1L, 2L, 3L), ArrayType(LongType)), + (Array(1.1d, 2.1d, 3.1d), 
ArrayType(DoubleType)), + (Array(1.1f, 2.1f, 3.1f), ArrayType(FloatType)), + (Array(Array[Int](), Array(1, 2, 3), Array(4, 5, 6)), ArrayType(ArrayType(IntegerType))), + (Array(Array(1, 2, 3), Array(4, 5, 6), Array[Int]()), ArrayType(ArrayType(IntegerType))), + ( + Array(Array(Array(Array(Array(Array(1, 2, 3)))))), + ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType))))))), + (Array(Map(1 -> 2)), ArrayType(MapType(IntegerType, IntegerType))), + (Map[String, String]("1" -> "2", "3" -> "4"), MapType(StringType, StringType)), + (Map[String, Boolean]("1" -> true, "2" -> false), MapType(StringType, BooleanType)), + (Map[Int, Int](), MapType(IntegerType, IntegerType)), + (Map(1 -> 2, 3 -> 4, 5 -> 6), MapType(IntegerType, IntegerType))).zipWithIndex.foreach { + case ((v, t), idx) => test(s"complex proto value and catalyst value conversion #$idx") { assertResult(v)( LiteralValueProtoConverter.toCatalystValue( LiteralValueProtoConverter.toLiteralProto(v, t))) } - } + } test("backward compatibility for struct literal proto") { // Test the old way of defining structs with structType field and elements @@ -163,4 +180,75 @@ class LiteralExpressionProtoConverterSuite extends AnyFunSuite { // scalastyle:i assert(!structTypeProto.getFieldsList.get(1).getNullable) assert(!structTypeProto.getFieldsList.get(1).hasMetadata) } + + test("element type of array literal is set for an empty array") { + val literalProto = + LiteralValueProtoConverter.toLiteralProto(Array[Int](), ArrayType(IntegerType)) + assert(literalProto.getArray.hasElementType) + } + + test("element type of array literal is set for a non-empty array with non-inferable type") { + val literalProto = LiteralValueProtoConverter.toLiteralProto( + Array[String]("1", "2", "3"), + ArrayType(StringType)) + assert(literalProto.getArray.hasElementType) + } + + test("element type of array literal is not set for a non-empty array with inferable type") { + val literalProto = + LiteralValueProtoConverter.toLiteralProto(Array(1, 2, 3), ArrayType(IntegerType)) + assert(!literalProto.getArray.hasElementType) + } + + test("key and value type of map literal are set for an empty map") { + val literalProto = LiteralValueProtoConverter.toLiteralProto( + Map[Int, Int](), + MapType(IntegerType, IntegerType)) + assert(literalProto.getMap.hasKeyType) + assert(literalProto.getMap.hasValueType) + } + + test("key type of map literal is set for a non-empty map with non-inferable key type") { + val literalProto = LiteralValueProtoConverter.toLiteralProto( + Map[String, Int]("1" -> 1, "2" -> 2, "3" -> 3), + MapType(StringType, IntegerType)) + assert(literalProto.getMap.hasKeyType) + assert(!literalProto.getMap.hasValueType) + } + + test("value type of map literal is set for a non-empty map with non-inferable value type") { + val literalProto = LiteralValueProtoConverter.toLiteralProto( + Map[Int, String](1 -> "1", 2 -> "2", 3 -> "3"), + MapType(IntegerType, StringType)) + assert(!literalProto.getMap.hasKeyType) + assert(literalProto.getMap.hasValueType) + } + + test("key and value type of map literal are not set for a non-empty map with inferable types") { + val literalProto = LiteralValueProtoConverter.toLiteralProto( + Map(1 -> 2, 3 -> 4, 5 -> 6), + MapType(IntegerType, IntegerType)) + assert(!literalProto.getMap.hasKeyType) + assert(!literalProto.getMap.hasValueType) + } + + test("an invalid array literal") { + val literalProto = proto.Expression.Literal + .newBuilder() + .setArray(proto.Expression.Literal.Array.newBuilder()) + .build() + 
intercept[InvalidPlanInput] { + LiteralValueProtoConverter.toCatalystValue(literalProto) + } + } + + test("an invalid map literal") { + val literalProto = proto.Expression.Literal + .newBuilder() + .setMap(proto.Expression.Literal.Map.newBuilder()) + .build() + intercept[InvalidPlanInput] { + LiteralValueProtoConverter.toCatalystValue(literalProto) + } + } }
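
A minimal sketch of the new client-side behavior, for reviewers who want to try the
patch locally (all names below come from the diff; `to_plan(None)` mirrors how
test_connect_plan.py drives the planner):

    from pyspark.sql.connect.expressions import LiteralExpression
    from pyspark.sql.types import ArrayType, IntegerType

    # Non-empty array: the element type is inferable from the first element,
    # so to_plan() leaves array.element_type unset on the wire.
    lit = LiteralExpression._from_value([3, -3]).to_plan(None).literal
    assert not lit.array.HasField("element_type")

    # Round trip: _to_value() infers the element type from the first element
    # when element_type is absent.
    assert LiteralExpression._to_value(lit) == [3, -3]

    # Empty array: nothing to infer from, so element_type is still sent.
    empty = LiteralExpression([], ArrayType(IntegerType())).to_plan(None).literal
    assert empty.array.HasField("element_type")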