diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 0a94270dd89f3..dce1befe58240 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -114,6 +114,18 @@
     ],
     "sqlState" : "22004"
   },
+  "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH" : {
+    "message" : [
+      "Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size <size1> and <size2>."
+    ],
+    "sqlState" : "42846"
+  },
+  "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH" : {
+    "message" : [
+      "Combining approx_top_k sketches of different types is not allowed. Found sketches of type <type1> and <type2>."
+    ],
+    "sqlState" : "42846"
+  },
   "ARITHMETIC_OVERFLOW" : {
     "message" : [
       "<message>. If necessary set <config> to \"false\" to bypass this error."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 76c3b1d80b294..89a8e8a9a6a6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -529,6 +529,7 @@ object FunctionRegistry {
     expression[HllUnionAgg]("hll_union_agg"),
     expression[ApproxTopK]("approx_top_k"),
     expression[ApproxTopKAccumulate]("approx_top_k_accumulate"),
+    expression[ApproxTopKCombine]("approx_top_k_combine"),

     // string functions
     expression[Ascii]("ascii"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
index 3c9440764a9a1..ed95a5265c2dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -76,35 +76,12 @@ case class ApproxTopKEstimate(state: Expression, k: Expression)

   override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)

-  private def checkStateFieldAndType(state: Expression): TypeCheckResult = {
-    val stateStructType = state.dataType.asInstanceOf[StructType]
-    if (stateStructType.length != 3) {
-      return TypeCheckFailure("State must be a struct with 3 fields. " +
-        "Expected struct: struct<sketch:binary,itemDataType:any,maxItemsTracked:int>. " +
-        "Got: " + state.dataType.simpleString)
-    }
-
-    if (stateStructType.head.dataType != BinaryType) {
-      TypeCheckFailure("State struct must have the first field to be binary. " +
-        "Got: " + stateStructType.head.dataType.simpleString)
-    } else if (!ApproxTopK.isDataTypeSupported(itemDataType)) {
-      TypeCheckFailure("State struct must have the second field to be a supported data type. " +
-        "Got: " + itemDataType.simpleString)
-    } else if (stateStructType(2).dataType != IntegerType) {
-      TypeCheckFailure("State struct must have the third field to be int. " +
-        "Got: " + stateStructType(2).dataType.simpleString)
-    } else {
-      TypeCheckSuccess
-    }
-  }
-
-
   override def checkInputDataTypes(): TypeCheckResult = {
     val defaultCheck = super.checkInputDataTypes()
     if (defaultCheck.isFailure) {
       defaultCheck
     } else {
-      val stateCheck = checkStateFieldAndType(state)
+      val stateCheck = ApproxTopK.checkStateFieldAndType(state)
       if (stateCheck.isFailure) {
         stateCheck
       } else if (!k.foldable) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
index cefe0a14dee56..1989c190aa609 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala
@@ -177,6 +177,8 @@ object ApproxTopK {
   val DEFAULT_K: Int = 5
   val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000
   private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000
+  val VOID_MAX_ITEMS_TRACKED: Int = -1
+  val SKETCH_SIZE_PLACEHOLDER: Int = 8

   def checkExpressionNotNull(expr: Expression, exprName: String): Unit = {
     if (expr == null || expr.eval() == null) {
@@ -318,7 +320,72 @@ object ApproxTopK {
     StructType(
       StructField("sketch", BinaryType, nullable = false) ::
         StructField("itemDataType", itemDataType) ::
-        StructField("maxItemsTracked", IntegerType, nullable = false) :: Nil)
+        StructField("maxItemsTracked", IntegerType, nullable = false) ::
+        StructField("typeCode", BinaryType, nullable = false) :: Nil)
+
+  def dataTypeToBytes(dataType: DataType): Array[Byte] = {
+    dataType match {
+      case _: BooleanType => Array(0, 0, 0)
+      case _: ByteType => Array(1, 0, 0)
+      case _: ShortType => Array(2, 0, 0)
+      case _: IntegerType => Array(3, 0, 0)
+      case _: LongType => Array(4, 0, 0)
+      case _: FloatType => Array(5, 0, 0)
+      case _: DoubleType => Array(6, 0, 0)
+      case _: DateType => Array(7, 0, 0)
+      case _: TimestampType => Array(8, 0, 0)
+      case _: TimestampNTZType => Array(9, 0, 0)
+      case _: StringType => Array(10, 0, 0)
+      case dt: DecimalType => Array(11, dt.precision.toByte, dt.scale.toByte)
+    }
+  }
+
+  def bytesToDataType(bytes: Array[Byte]): DataType = {
+    bytes(0) match {
+      case 0 => BooleanType
+      case 1 => ByteType
+      case 2 => ShortType
+      case 3 => IntegerType
+      case 4 => LongType
+      case 5 => FloatType
+      case 6 => DoubleType
+      case 7 => DateType
+      case 8 => TimestampType
+      case 9 => TimestampNTZType
+      case 10 => StringType
+      case 11 => DecimalType(bytes(1).toInt, bytes(2).toInt)
+    }
+  }
" + + "Got: " + fieldType2.simpleString) + } else if (fieldType3 != IntegerType) { + TypeCheckFailure("State struct must have the third field to be int. " + + "Got: " + fieldType3.simpleString) + } else if (fieldType4 != BinaryType) { + TypeCheckFailure("State struct must have the fourth field to be binary. " + + "Got: " + fieldType4.simpleString) + } else { + TypeCheckSuccess + } + } } /** @@ -328,7 +395,11 @@ object ApproxTopK { * * The output of this function is a struct containing the sketch in binary format, * a null object indicating the type of items in the sketch, - * and the maximum number of items tracked by the sketch. + * the maximum number of items tracked by the sketch, + * and a binary typeCode encoding the data type of the items in the sketch. + * + * The null object is used in approx_top_k_estimate, + * while the typeCode is used in approx_top_k_combine. * * @param expr the child expression to accumulate items from * @param maxItemsTracked the maximum number of items to track in the sketch @@ -410,7 +481,8 @@ case class ApproxTopKAccumulate( override def eval(buffer: ItemsSketch[Any]): Any = { val sketchBytes = serialize(buffer) - InternalRow.apply(sketchBytes, null, maxItemsTrackedVal) + val typeCode = ApproxTopK.dataTypeToBytes(itemDataType) + InternalRow.apply(sketchBytes, null, maxItemsTrackedVal, typeCode) } override def serialize(buffer: ItemsSketch[Any]): Array[Byte] = @@ -435,3 +507,203 @@ case class ApproxTopKAccumulate( override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate") } + +class CombineInternal[T]( + sketch: ItemsSketch[T], + var itemDataType: DataType, + var maxItemsTracked: Int) { + def getSketch: ItemsSketch[T] = sketch + + def getItemDataType: DataType = itemDataType + + def setItemDataType(dataType: DataType): Unit = { + if (this.itemDataType == null) { + this.itemDataType = dataType + } else if (this.itemDataType != dataType) { + throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(this.itemDataType, dataType) + } + } + + def getMaxItemsTracked: Int = maxItemsTracked + + def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked +} + +/** + * An aggregate function that combines multiple sketches into a single sketch. + * + * @param state the expression containing the sketches to combine + * @param maxItemsTracked the maximum number of items to track in the sketch + * @param mutableAggBufferOffset the offset for mutable aggregation buffer + * @param inputAggBufferOffset the offset for input aggregation buffer + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single sketch. + `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000. 
+ """, + examples = """ + Examples: + > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr)); + [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] + """, + group = "agg_funcs", + since = "4.1.0") +// scalastyle:on line.size.limit +case class ApproxTopKCombine( + state: Expression, + maxItemsTracked: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[CombineInternal[Any]] + with ImplicitCastInputTypes + with BinaryLike[Expression] { + + def this(child: Expression, maxItemsTracked: Expression) = { + this(child, maxItemsTracked, 0, 0) + ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked") + ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal) + } + + def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked)) + + def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0) + + private lazy val uncheckedItemDataType: DataType = + state.dataType.asInstanceOf[StructType](1).dataType + private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int] + private lazy val combineSizeSpecified: Boolean = + maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED + + override def left: Expression = state + + override def right: Expression = maxItemsTracked + + override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else { + val stateCheck = ApproxTopK.checkStateFieldAndType(state) + if (stateCheck.isFailure) { + stateCheck + } else if (!maxItemsTracked.foldable) { + TypeCheckFailure("Number of items tracked must be a constant literal") + } else { + TypeCheckSuccess + } + } + } + + override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType) + + override def createAggregationBuffer(): CombineInternal[Any] = { + if (combineSizeSpecified) { + val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal) + new CombineInternal[Any]( + new ItemsSketch[Any](maxMapSize), + null, + maxItemsTrackedVal) + } else { + new CombineInternal[Any]( + new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER), + null, + ApproxTopK.VOID_MAX_ITEMS_TRACKED) + } + } + + override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = { + val inputState = state.eval(input).asInstanceOf[InternalRow] + val inputSketchBytes = inputState.getBinary(0) + val inputMaxItemsTracked = inputState.getInt(2) + val typeCode = inputState.getBinary(3) + val actualItemDataType = ApproxTopK.bytesToDataType(typeCode) + buffer.setItemDataType(actualItemDataType) + val inputSketch = ItemsSketch.getInstance( + Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType)) + buffer.getSketch.merge(inputSketch) + if (!combineSizeSpecified) { + buffer.setMaxItemsTracked(inputMaxItemsTracked) + } + buffer + } + + override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any]) + : CombineInternal[Any] = { + if (!combineSizeSpecified) { + // check size + if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) { + // If buffer is a placeholder sketch, set it to the 
+
+  override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
+  : CombineInternal[Any] = {
+    if (!combineSizeSpecified) {
+      // check size
+      if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+        // If buffer is a placeholder sketch, set it to the input sketch's max items tracked
+        buffer.setMaxItemsTracked(input.getMaxItemsTracked)
+      }
+      if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
+        throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
+          buffer.getMaxItemsTracked,
+          input.getMaxItemsTracked
+        )
+      }
+    }
+    // check item data type
+    if (buffer.getItemDataType != null && input.getItemDataType != null &&
+      buffer.getItemDataType != input.getItemDataType) {
+      throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
+        buffer.getItemDataType,
+        input.getItemDataType
+      )
+    } else if (buffer.getItemDataType == null) {
+      // If buffer is a placeholder sketch, set it to the input sketch's item data type
+      buffer.setItemDataType(input.getItemDataType)
+    }
+    buffer.getSketch.merge(input.getSketch)
+    buffer
+  }
+
+  override def eval(buffer: CombineInternal[Any]): Any = {
+    val sketchBytes =
+      buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val maxItemsTracked = buffer.getMaxItemsTracked
+    val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+    InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
+  }
+
+  override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
+    val sketchBytes = buffer.getSketch.toByteArray(
+      ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val itemDataTypeBytes = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+    val maxItemsTracked = buffer.getMaxItemsTracked
+    // Write maxItemsTracked as 4 big-endian bytes, then the 3-byte type code,
+    // then the sketch bytes. A single byte cannot hold maxItemsTracked: the
+    // default 10000 would overflow, so the full Int is preserved here.
+    val byteArray = new Array[Byte](sketchBytes.length + 7)
+    byteArray(0) = (maxItemsTracked >>> 24).toByte
+    byteArray(1) = (maxItemsTracked >>> 16).toByte
+    byteArray(2) = (maxItemsTracked >>> 8).toByte
+    byteArray(3) = maxItemsTracked.toByte
+    System.arraycopy(itemDataTypeBytes, 0, byteArray, 4, itemDataTypeBytes.length)
+    System.arraycopy(sketchBytes, 0, byteArray, 7, sketchBytes.length)
+    byteArray
+  }
+
+  override def deserialize(buffer: Array[Byte]): CombineInternal[Any] = {
+    val maxItemsTracked = ((buffer(0) & 0xff) << 24) | ((buffer(1) & 0xff) << 16) |
+      ((buffer(2) & 0xff) << 8) | (buffer(3) & 0xff)
+    val itemDataTypeBytes = buffer.slice(4, 7)
+    val actualItemDataType = ApproxTopK.bytesToDataType(itemDataTypeBytes)
+    val sketchBytes = buffer.slice(7, buffer.length)
+    val sketch = ItemsSketch.getInstance(
+      Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(actualItemDataType))
+    new CombineInternal[Any](sketch, actualItemDataType, maxItemsTracked)
+  }
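+
+  // Illustrative buffer layout produced by serialize, assuming the fixed-width
+  // header above: for maxItemsTracked = 10000 and IntegerType items the first
+  // seven bytes are [0, 0, 39, 16, 3, 0, 0], followed by the sketch image that
+  // deserialize hands back to ItemsSketch.getInstance.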
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): Expression =
+    copy(state = newLeft, maxItemsTracked = newRight)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def nullable: Boolean = false
+
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_combine")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 680d7b4b172db..bdd252c245620 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -152,13 +152,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       s: UTF8String,
       fmt: UTF8String,
       hint: String): SparkIllegalArgumentException = {
-      new SparkIllegalArgumentException(
-        errorClass = "CONVERSION_INVALID_INPUT",
-        messageParameters = Map(
-          "str" -> toSQLValue(s, StringType),
-          "fmt" -> toSQLValue(fmt, StringType),
-          "targetType" -> toSQLType(to),
-          "suggestion" -> toSQLId(hint)))
+    new SparkIllegalArgumentException(
+      errorClass = "CONVERSION_INVALID_INPUT",
+      messageParameters = Map(
+        "str" -> toSQLValue(s, StringType),
+        "fmt" -> toSQLValue(fmt, StringType),
+        "targetType" -> toSQLType(to),
+        "suggestion" -> toSQLId(hint)))
   }

   def cannotCastFromNullTypeError(to: DataType): Throwable = {
@@ -249,7 +249,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
   }

   def invalidBitmapPositionError(bitPosition: Long,
-    bitmapNumBytes: Long): ArrayIndexOutOfBoundsException = {
+      bitmapNumBytes: Long): ArrayIndexOutOfBoundsException = {
     new SparkArrayIndexOutOfBoundsException(
       errorClass = "INVALID_BITMAP_POSITION",
       messageParameters = Map(
@@ -346,7 +346,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       errorClass = "UNSUPPORTED_FEATURE.LITERAL_TYPE",
       messageParameters = Map(
         "value" -> v.toString,
-          "type" -> v.getClass.toString))
+        "type" -> v.getClass.toString))
   }

   def pivotColumnUnsupportedError(v: Any, expr: Expression): RuntimeException = {
@@ -668,7 +668,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
   }

   def withoutSuggestionIntervalArithmeticOverflowError(
-    context: QueryContext): SparkArithmeticException = {
+      context: QueryContext): SparkArithmeticException = {
     new SparkArithmeticException(
       errorClass = "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
       messageParameters = Map(),
@@ -728,7 +728,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
         "windowFunc" -> windowFuncList.map(toSQLStmt(_)).mkString(","),
         "columnName" -> columnNameList.map(toSQLId(_)).mkString(","),
         "windowSpec" -> windowSpecList.map(toSQLStmt(_)).mkString(",")),
-       origin = origin)
+      origin = origin)
   }

   def multiplePathsSpecifiedError(allPaths: Seq[String]): SparkIllegalArgumentException = {
@@ -1421,7 +1421,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
   }

   def arrayFunctionWithElementsExceedLimitError(
-    prettyName: String, numberOfElements: Long): SparkRuntimeException = {
+      prettyName: String, numberOfElements: Long): SparkRuntimeException = {
     new SparkRuntimeException(
       errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION",
       messageParameters = Map(
@@ -1432,7 +1432,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
   }

   def createArrayWithElementsExceedLimitError(
-    prettyName: String, count: Any): SparkRuntimeException = {
+      prettyName: String, count: Any): SparkRuntimeException = {
     new SparkRuntimeException(
       errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
       messageParameters = Map(
@@ -1547,7 +1547,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       errorClass = "_LEGACY_ERROR_TEMP_2176",
       messageParameters = Map(
         "numElements" -> numElements.toString(),
-        "maxRoundedArrayLength"-> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
         "additionalErrorMessage" -> additionalErrorMessage))
   }

@@ -1742,7 +1742,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
   }

   def stateNotDefinedOrAlreadyRemovedError(): Throwable = {
-      new NoSuchElementException("State is either not defined or has already been removed")
+    new NoSuchElementException("State is either not defined or has already been removed")
   }

   def statefulOperatorNotMatchInStateMetadataError(
@@ -1750,6 +1750,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
     opsInCurBatchSeq: Map[Long, String]): SparkRuntimeException = {
     def formatPairString(pair: (Long, String)): String =
       s"(OperatorId: ${pair._1} -> OperatorName: ${pair._2})"
+
     new SparkRuntimeException(
       errorClass = s"STREAMING_STATEFUL_OPERATOR_NOT_MATCH_IN_STATE_METADATA",
       messageParameters = Map(
@@ -2645,7 +2646,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       acquiredThreadInfo: String,
       timeWaitedMs: Long,
       stackTraceOutput: String): Throwable = {
-    new SparkException (
+    new SparkException(
       errorClass = "CANNOT_LOAD_STATE_STORE.UNRELEASED_THREAD_ERROR",
       messageParameters = Map(
         "loggingId" -> loggingId,
@@ -2658,7 +2659,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
   }

   def cannotReadCheckpoint(expectedVersion: String, actualVersion: String): Throwable = {
-    new SparkException (
+    new SparkException(
       errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_CHECKPOINT",
       messageParameters = Map(
         "expectedVersion" -> expectedVersion,
@@ -2667,7 +2668,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
   }

   def cannotFindBaseSnapshotCheckpoint(lineage: String): Throwable = {
-    new SparkException (
+    new SparkException(
       errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_FIND_BASE_SNAPSHOT_CHECKPOINT",
       messageParameters = Map("lineage" -> lineage),
@@ -2814,6 +2815,22 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
         "limit" -> toSQLValue(limit, IntegerType)))
   }

+  def approxTopKSketchSizeNotMatch(size1: Int, size2: Int): Throwable = {
+    new SparkRuntimeException(
+      errorClass = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+      messageParameters = Map(
+        "size1" -> toSQLValue(size1, IntegerType),
+        "size2" -> toSQLValue(size2, IntegerType)))
+  }
+
+  def approxTopKSketchTypeNotMatch(type1: DataType, type2: DataType): Throwable = {
+    new SparkRuntimeException(
+      errorClass = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+      messageParameters = Map(
+        "type1" -> toSQLType(type1),
+        "type2" -> toSQLType(type2)))
+  }
+
   def mergeCardinalityViolationError(): SparkRuntimeException = {
     new SparkRuntimeException(
       errorClass = "MERGE_CARDINALITY_VIOLATION",
@@ -2840,7 +2857,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       val errorParmsMutable = collection.mutable.Map[String, String]()
       errorParms.foreach(StringType, StringType, { case (key, value) =>
         errorParmsMutable += (key.toString ->
-          (if (value == null) { "null" } else { value.toString } ))
+          (if (value == null) {
+            "null"
+          } else {
+            value.toString
+          }))
       })
       errorParmsMutable.toMap
     } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala
index 2b339003abd4c..ffa0dabca9a97 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala
@@ -137,18 +137,14 @@ class ApproxTopKSuite extends SparkFunSuite {
   // ApproxTopKEstimate tests
   /////////////////////////////

-  val stateStructType: StructType = StructType(Seq(
-    StructField("sketch", BinaryType),
-    StructField("itemDataType", IntegerType),
-    StructField("maxItemsTracked", IntegerType)
-  ))
-
   test("SPARK-52588: invalid estimate if k are not foldable") {
     val badEstimate = ApproxTopKEstimate(
       state = BoundReference(0, StructType(Seq(
         StructField("sketch", BinaryType),
         StructField("itemDataType", IntegerType),
-        StructField("maxItemsTracked", IntegerType))), nullable = false),
+        StructField("maxItemsTracked", IntegerType),
+        StructField("typeCode", BinaryType)
+      )), nullable = false),
       k = Sum(BoundReference(1, IntegerType, nullable = true))
     )
     assert(badEstimate.checkInputDataTypes().isFailure)
@@ -184,11 +180,11 @@ class ApproxTopKSuite extends SparkFunSuite {
     )
   }

-  test("SPARK-52588: invalid estimate if state struct length is not 3") {
+  test("SPARK-52588: invalid estimate if state struct length is not 4") {
     val invalidState = StructType(Seq(
       StructField("sketch", BinaryType),
       StructField("itemDataType", IntegerType)
-      // Missing "maxItemsTracked"
+      // Missing "maxItemsTracked", "typeCode" fields
     ))
     val badEstimate = ApproxTopKEstimate(
       state = BoundReference(0, invalidState, nullable = false),
@@ -196,8 +192,9 @@ class ApproxTopKSuite extends SparkFunSuite {
     )
     assert(badEstimate.checkInputDataTypes().isFailure)
     assert(badEstimate.checkInputDataTypes() ==
-      TypeCheckFailure("State must be a struct with 3 fields. " +
-        "Expected struct: struct<sketch:binary,itemDataType:any,maxItemsTracked:int>. " +
+      TypeCheckFailure("State must be a struct with 4 fields. " +
+        "Expected struct: " +
+        "struct<sketch:binary,itemDataType:any,maxItemsTracked:int,typeCode:binary>. " +
         "Got: struct<sketch:binary,itemDataType:int>"))
   }

@@ -205,7 +202,8 @@ class ApproxTopKSuite extends SparkFunSuite {
     val invalidState = StructType(Seq(
       StructField("notSketch", IntegerType), // Should be BinaryType
       StructField("itemDataType", IntegerType),
-      StructField("maxItemsTracked", IntegerType)
+      StructField("maxItemsTracked", IntegerType),
+      StructField("typeCode", BinaryType)
     ))
     val badEstimate = ApproxTopKEstimate(
       state = BoundReference(0, invalidState, nullable = false),
@@ -227,7 +225,8 @@ class ApproxTopKSuite extends SparkFunSuite {
     val invalidState = StructType(Seq(
       StructField("sketch", BinaryType),
       StructField("itemDataType", dataType),
-      StructField("maxItemsTracked", IntegerType)
+      StructField("maxItemsTracked", IntegerType),
+      StructField("typeCode", BinaryType)
     ))
     val badEstimate = ApproxTopKEstimate(
       state = BoundReference(0, invalidState, nullable = false),
@@ -243,7 +242,8 @@ class ApproxTopKSuite extends SparkFunSuite {
     val invalidState = StructType(Seq(
       StructField("sketch", BinaryType),
       StructField("itemDataType", IntegerType),
-      StructField("maxItemsTracked", LongType) // Should be IntegerType
+      StructField("maxItemsTracked", LongType), // Should be IntegerType
+      StructField("typeCode", BinaryType)
     ))
     val badEstimate = ApproxTopKEstimate(
       state = BoundReference(0, invalidState, nullable = false),
@@ -253,4 +253,156 @@ class ApproxTopKSuite extends SparkFunSuite {
     assert(badEstimate.checkInputDataTypes() ==
       TypeCheckFailure("State struct must have the third field to be int. Got: bigint"))
   }
+
+  test("SPARK-52588: invalid estimate if state struct's fourth field is not binary") {
+    val invalidState = StructType(Seq(
+      StructField("sketch", BinaryType),
+      StructField("itemDataType", IntegerType),
+      StructField("maxItemsTracked", IntegerType),
+      StructField("typeCode", StringType) // Should be BinaryType
+    ))
+    val badEstimate = ApproxTopKEstimate(
+      state = BoundReference(0, invalidState, nullable = false),
+      k = Literal(5)
+    )
+    assert(badEstimate.checkInputDataTypes().isFailure)
+    assert(badEstimate.checkInputDataTypes() ==
+      TypeCheckFailure("State struct must have the fourth field to be binary. Got: string"))
+  }
+
+  /////////////////////////////
+  // ApproxTopKCombine tests
+  /////////////////////////////
+  test("SPARK-52798: invalid combine if maxItemsTracked is not foldable") {
+    val badCombine = ApproxTopKCombine(
+      state = BoundReference(0, StructType(Seq(
+        StructField("sketch", BinaryType),
+        StructField("itemDataType", IntegerType),
+        StructField("maxItemsTracked", IntegerType),
+        StructField("typeCode", BinaryType)
+      )), nullable = false),
+      maxItemsTracked = Sum(BoundReference(1, IntegerType, nullable = true))
+    )
+    assert(badCombine.checkInputDataTypes().isFailure)
+    assert(badCombine.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> ordinalNumber(1),
+          "requiredType" -> "\"INT\"",
+          "inputSql" -> "\"sum(boundreference())\"",
+          "inputType" -> "\"BIGINT\""
+        )
+      )
+    )
+  }
+
+  test("SPARK-52798: invalid combine if state is not a struct") {
+    val badCombine = ApproxTopKCombine(
+      state = BoundReference(0, IntegerType, nullable = false),
+      maxItemsTracked = Literal(10)
+    )
+    assert(badCombine.checkInputDataTypes().isFailure)
+    assert(badCombine.checkInputDataTypes() ==
+      DataTypeMismatch(
+        errorSubClass = "UNEXPECTED_INPUT_TYPE",
+        messageParameters = Map(
+          "paramIndex" -> ordinalNumber(0),
+          "requiredType" -> "\"STRUCT\"",
+          "inputSql" -> "\"boundreference()\"",
+          "inputType" -> "\"INT\""
+        )
+      )
+    )
+  }
+
+  test("SPARK-52798: invalid combine if state struct length is not 4") {
+    val invalidState = StructType(Seq(
+      StructField("sketch", BinaryType),
+      StructField("itemDataType", IntegerType)
+      // Missing "maxItemsTracked", "typeCode" fields
+    ))
+    val badCombine = ApproxTopKCombine(
+      state = BoundReference(0, invalidState, nullable = false),
+      maxItemsTracked = Literal(10)
+    )
+    assert(badCombine.checkInputDataTypes().isFailure)
+    assert(badCombine.checkInputDataTypes() ==
+      TypeCheckFailure("State must be a struct with 4 fields. " +
+        "Expected struct: " +
+        "struct<sketch:binary,itemDataType:any,maxItemsTracked:int,typeCode:binary>. " +
+        "Got: struct<sketch:binary,itemDataType:int>"))
+  }
+
+  test("SPARK-52798: invalid combine if state struct's first field is not binary") {
+    val invalidState = StructType(Seq(
+      StructField("sketch", IntegerType), // Should be BinaryType
+      StructField("itemDataType", IntegerType),
+      StructField("maxItemsTracked", IntegerType),
+      StructField("typeCode", BinaryType)
+    ))
+    val badCombine = ApproxTopKCombine(
+      state = BoundReference(0, invalidState, nullable = false),
+      maxItemsTracked = Literal(10)
+    )
+    assert(badCombine.checkInputDataTypes().isFailure)
+    assert(badCombine.checkInputDataTypes() ==
+      TypeCheckFailure("State struct must have the first field to be binary. Got: int"))
+  }
+
+  gridTest("SPARK-52798: invalid combine if state struct's second field is not supported")(
+    Seq(
+      ("array<int>", ArrayType(IntegerType)),
+      ("map<string,int>", MapType(StringType, IntegerType)),
+      ("struct<a:int>", StructType(Seq(StructField("a", IntegerType)))),
+      ("binary", BinaryType)
+    )) { unSupportedType =>
+    val (typeName, dataType) = unSupportedType
+    val invalidState = StructType(Seq(
+      StructField("sketch", BinaryType),
+      StructField("itemDataType", dataType),
+      StructField("maxItemsTracked", IntegerType),
+      StructField("typeCode", BinaryType)
+    ))
+    val badCombine = ApproxTopKCombine(
+      state = BoundReference(0, invalidState, nullable = false),
+      maxItemsTracked = Literal(10)
+    )
+    assert(badCombine.checkInputDataTypes().isFailure)
+    assert(badCombine.checkInputDataTypes() ==
+      TypeCheckFailure(s"State struct must have the second field to be a supported data type. " +
" + + s"Got: $typeName")) + } + + test("SPARK-52798: invalid combine if state struct's third field is not int") { + val invalidState = StructType(Seq( + StructField("sketch", BinaryType), + StructField("itemDataType", IntegerType), + StructField("maxItemsTracked", LongType), // Should be IntegerType + StructField("typeCode", BinaryType) + )) + val badCombine = ApproxTopKCombine( + state = BoundReference(0, invalidState, nullable = false), + maxItemsTracked = Literal(10) + ) + assert(badCombine.checkInputDataTypes().isFailure) + assert(badCombine.checkInputDataTypes() == + TypeCheckFailure("State struct must have the third field to be int. Got: bigint")) + } + + test("SPARK-52798: invalid combine if state struct's fourth field is not binary") { + val invalidState = StructType(Seq( + StructField("sketch", BinaryType), + StructField("itemDataType", IntegerType), + StructField("maxItemsTracked", IntegerType), + StructField("typeCode", StringType) // Should be BinaryType + )) + val badCombine = ApproxTopKCombine( + state = BoundReference(0, invalidState, nullable = false), + maxItemsTracked = Literal(10) + ) + assert(badCombine.checkInputDataTypes().isFailure) + assert(badCombine.checkInputDataTypes() == + TypeCheckFailure("State struct must have the fourth field to be binary. Got: string")) + } } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 3b0b21b9cd776..1812d98432322 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -403,6 +403,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue | any_value | SELECT any_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK | approx_top_k | SELECT approx_top_k(expr) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr) | struct>> | | org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopKAccumulate | approx_top_k_accumulate | SELECT approx_top_k_estimate(approx_top_k_accumulate(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr) | struct>> | +| org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopKCombine | approx_top_k_combine | SELECT approx_top_k_estimate(approx_top_k_combine(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr)) | struct>> | | org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile | approx_percentile | SELECT approx_percentile(col, array(0.5, 0.4, 0.1), 100) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile | percentile_approx | SELECT percentile_approx(col, array(0.5, 0.4, 0.1), 100) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.Average | avg | SELECT avg(col) FROM VALUES (1), (2), (3) AS tab(col) | struct | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala index 8219fce9b2178..040254ef038cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala @@ -21,7 +21,10 @@ import 
 import java.sql.{Date, Timestamp}
 import java.time.LocalDateTime

 import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType}

 class ApproxTopKSuite extends QueryTest with SharedSparkSession {

@@ -328,4 +331,274 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
       parameters = Map("maxItemsTracked" -> "5", "k" -> "10")
     )
   }
+
+  /////////////////////////////////
+  // approx_top_k_combine
+  /////////////////////////////////
+
+  def setupMixedSizeAccumulations(size1: Int, size2: Int): Unit = {
+    sql(s"SELECT approx_top_k_accumulate(expr, $size1) as acc " +
+      "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);")
+      .createOrReplaceTempView("accumulation1")
+
+    sql(s"SELECT approx_top_k_accumulate(expr, $size2) as acc " +
+      "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);")
+      .createOrReplaceTempView("accumulation2")
+
+    sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2")
+      .createOrReplaceTempView("unioned")
+  }
+
+  def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = {
+    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);")
+      .createOrReplaceTempView("accumulation1")
+
+    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);")
+      .createOrReplaceTempView("accumulation2")
+
+    sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2")
+      .createOrReplaceTempView("unioned")
+  }
+
+  val mixedNumberTypes: Seq[(DataType, String, Seq[Any])] = Seq(
+    (IntegerType, "INT",
+      Seq(0, 0, 0, 1, 1, 2, 2, 3)),
+    (ByteType, "TINYINT",
+      Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")),
+    (ShortType, "SMALLINT",
+      Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")),
+    (LongType, "BIGINT",
+      Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")),
+    (FloatType, "FLOAT",
+      Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")),
+    (DoubleType, "DOUBLE",
+      Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")),
+    (DecimalType(4, 2), "DECIMAL(4,2)",
+      Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", "cast(1 AS DECIMAL(4, 2))")),
+    (DecimalType(10, 2), "DECIMAL(10,2)",
+      Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", "cast(1 AS DECIMAL(10, 2))")),
+    (DecimalType(20, 3), "DECIMAL(20,3)",
+      Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", "cast(1 AS DECIMAL(20, 3))"))
+  )
+
+  val mixedDateTimeTypes: Seq[(DataType, String, Seq[String])] = Seq(
+    (DateType, "DATE",
+      Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")),
+    (TimestampType, "TIMESTAMP",
+      Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'")),
+    (TimestampNTZType, "TIMESTAMP_NTZ",
+      Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01 00:00:00'")
+    )
+  )
+
+  // positive tests for approx_top_k_combine on all supported types
+  gridTest("SPARK-52798: same type, same size, specified combine size - success")(itemsWithTopK) {
+    case (input, expected) =>
+      sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS tab(expr);")
+        .createOrReplaceTempView("accumulation1")
+      sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS tab(expr);")
+        .createOrReplaceTempView("accumulation2")
+      sql("SELECT approx_top_k_combine(acc, 30) as com " +
+        "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);")
+        .createOrReplaceTempView("combined")
+      val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+      // expected should be doubled because we combine two identical sketches
+      val expectedDoubled = expected.map {
+        case Row(value: Any, count: Int) => Row(value, count * 2)
+      }
+      checkAnswer(est, Row(expectedDoubled))
+  }
+
+  test("SPARK-52798: same type, same size, specified combine size - success") {
+    setupMixedSizeAccumulations(10, 10)
+
+    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+      .createOrReplaceTempView("combined")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+  }
+
+  test("SPARK-52798: same type, same size, unspecified combine size - success") {
+    setupMixedSizeAccumulations(10, 10)
+
+    sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+      .createOrReplaceTempView("combined")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+  }
+
+  test("SPARK-52798: same type, different size, specified combine size - success") {
+    setupMixedSizeAccumulations(10, 20)
+
+    sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned")
+      .createOrReplaceTempView("combination")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combination;")
+    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+  }
+
+  test("SPARK-52798: same type, different size, unspecified combine size - fail") {
+    setupMixedSizeAccumulations(10, 20)
+
+    val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned")
+
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        comb.collect()
+      },
+      condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH",
+      parameters = Map("size1" -> "10", "size2" -> "20")
+    )
+  }
+
+  gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10, 20))) {
+    case (size1, size2) =>
+      setupMixedSizeAccumulations(size1, size2)
+      checkError(
+        exception = intercept[SparkRuntimeException] {
+          sql("SELECT approx_top_k_combine(acc, 0) as com FROM unioned").collect()
+        },
+        condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
+        parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+      )
+  }
+
+  test("SPARK-52798: among different number or datetime types - fail at combine") {
+    def checkMixedTypeError(mixedTypeSeq: Seq[(DataType, String, Seq[Any])]): Unit = {
+      for (i <- 0 until mixedTypeSeq.size - 1) {
+        for (j <- i + 1 until mixedTypeSeq.size) {
+          val (type1, _, seq1) = mixedTypeSeq(i)
+          val (type2, _, seq2) = mixedTypeSeq(j)
+          setupMixedTypeAccumulation(seq1, seq2)
+          checkError(
+            exception = intercept[SparkRuntimeException] {
+              sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+            },
+            condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+            parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(type2))
+          )
+        }
+      }
+    }
+
+    checkMixedTypeError(mixedNumberTypes)
+    checkMixedTypeError(mixedDateTimeTypes)
+  }
+
+  // enumerate all combinations of number and datetime types
+  gridTest("SPARK-52798: number vs datetime - fail on UNION")(
+    for {
+      (type1, typeName1, seq1) <- mixedNumberTypes
+      (type2, typeName2, seq2) <- mixedDateTimeTypes
+    } yield ((type1, typeName1, seq1), (type2, typeName2, seq2))) {
+    case ((_, type1, seq1), (_, type2, seq2)) =>
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          setupMixedTypeAccumulation(seq1, seq2)
+        },
+        condition = "INCOMPATIBLE_COLUMN_TYPE",
+        parameters = Map(
+          "tableOrdinalNumber" -> "second",
+          "columnOrdinalNumber" -> "first",
+          "dataType2" -> ("\"STRUCT<sketch: BINARY NOT NULL, itemDataType: " + type2 +
+            ", maxItemsTracked: INT NOT NULL, typeCode: BINARY NOT NULL>\""),
+          "operator" -> "UNION",
+          "hint" -> "",
+          "dataType1" -> ("\"STRUCT<sketch: BINARY NOT NULL, itemDataType: " + type1 +
+            ", maxItemsTracked: INT NOT NULL, typeCode: BINARY NOT NULL>\"")
+        ),
+        queryContext = Array(
+          ExpectedContext(
+            "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
+      )
+  }
+
+  gridTest("SPARK-52798: number vs string - fail at combine")(mixedNumberTypes) {
+    case (type1, _, seq1) =>
+      setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+      checkError(
+        exception = intercept[SparkRuntimeException] {
+          sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+        },
+        condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+        parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(StringType))
+      )
+  }
+
+  gridTest("SPARK-52798: number vs boolean - fail at UNION")(mixedNumberTypes) {
+    case (_, type1, seq1) =>
+      val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          setupMixedTypeAccumulation(seq1, seq2)
+        },
+        condition = "INCOMPATIBLE_COLUMN_TYPE",
+        parameters = Map(
+          "tableOrdinalNumber" -> "second",
+          "columnOrdinalNumber" -> "first",
+          "dataType2" -> ("\"STRUCT<sketch: BINARY NOT NULL, itemDataType: BOOLEAN, " +
+            "maxItemsTracked: INT NOT NULL, typeCode: BINARY NOT NULL>\""),
+          "operator" -> "UNION",
+          "hint" -> "",
+          "dataType1" -> ("\"STRUCT<sketch: BINARY NOT NULL, itemDataType: " + type1 +
+            ", maxItemsTracked: INT NOT NULL, typeCode: BINARY NOT NULL>\"")
+        ),
+        queryContext = Array(
+          ExpectedContext(
+            "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
+      )
+  }
+
+  gridTest("SPARK-52798: datetime vs string - fail at combine")(mixedDateTimeTypes) {
+    case (type1, _, seq1) =>
+      setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+      checkError(
+        exception = intercept[SparkRuntimeException] {
+          sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+        },
+        condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+        parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(StringType))
+      )
+  }
+
+  gridTest("SPARK-52798: datetime vs boolean - fail at UNION")(mixedDateTimeTypes) {
+    case (_, type1, seq1) =>
+      val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          setupMixedTypeAccumulation(seq1, seq2)
+        },
+        condition = "INCOMPATIBLE_COLUMN_TYPE",
+        parameters = Map(
+          "tableOrdinalNumber" -> "second",
+          "columnOrdinalNumber" -> "first",
+          "dataType2" -> ("\"STRUCT<sketch: BINARY NOT NULL, itemDataType: BOOLEAN, " +
+            "maxItemsTracked: INT NOT NULL, typeCode: BINARY NOT NULL>\""),
+          "operator" -> "UNION",
+          "hint" -> "",
+          "dataType1" -> ("\"STRUCT<sketch: BINARY NOT NULL, itemDataType: " + type1 +
+            ", maxItemsTracked: INT NOT NULL, typeCode: BINARY NOT NULL>\"")
+        ),
+        queryContext = Array(
+          ExpectedContext(
+            "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
+      )
+  }
+
+  test("SPARK-52798: string vs boolean - fail at combine") {
+    val seq1 = Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")
+    val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+    setupMixedTypeAccumulation(seq1, seq2)
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+      },
+      condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH",
+      parameters = Map("type1" -> toSQLType(StringType), "type2" -> toSQLType(BooleanType))
+    )
+  }
 }