From 206324bf7b3e57cd69828ee027fea26bb483388d Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Tue, 15 Jul 2025 11:16:01 -0700
Subject: [PATCH 01/17] init commit for combine

---
 .../src/main/resources/error/error-conditions.json | 12 ++++++++++++
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  1 +
 2 files changed, 13 insertions(+)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 0a94270dd89f3..3dc00d5cfb407 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -114,6 +114,18 @@
     ],
     "sqlState" : "22004"
   },
+  "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" : {
+    "message" : [
+      "Combining approx_top_k sketches of different types is not allowed. Found sketches of type <type1> and <type2>."
+    ],
+    "sqlState": "42846"
+  },
+  "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED" : {
+    "message" : [
+      "Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size <size1> and <size2>."
+    ],
+    "sqlState": "42846"
+  },
   "ARITHMETIC_OVERFLOW" : {
     "message" : [
       "<message>. If necessary set <alternative> to \"false\" to bypass this error."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 76c3b1d80b294..89a8e8a9a6a6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -529,6 +529,7 @@ object FunctionRegistry {
     expression[HllUnionAgg]("hll_union_agg"),
     expression[ApproxTopK]("approx_top_k"),
     expression[ApproxTopKAccumulate]("approx_top_k_accumulate"),
+    expression[ApproxTopKCombine]("approx_top_k_combine"),
 
     // string functions
     expression[Ascii]("ascii"),

From cd8d647f31380badf7df263b86f64400f3fb9b1d Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Tue, 15 Jul 2025 12:55:02 -0700
Subject: [PATCH 02/17] update state struct

---
 .../expressions/ApproxTopKExpressions.scala | 10 ++--
 .../aggregate/ApproxTopKAggregates.scala    | 46 +++++++++++++++++--
 .../aggregate/ApproxTopKSuite.scala         | 44 ++++++++++++------
 3 files changed, 80 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
index 3c9440764a9a1..f6ae5df4b5825 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala
@@ -78,9 +78,10 @@ case class ApproxTopKEstimate(state: Expression, k: Expression)
 
   private def checkStateFieldAndType(state: Expression): TypeCheckResult = {
     val stateStructType = state.dataType.asInstanceOf[StructType]
-    if (stateStructType.length != 3) {
-      return TypeCheckFailure("State must be a struct with 3 fields. " +
-        "Expected struct: struct<sketch:binary,itemDataType:any,maxItemsTracked:int>. " +
+    if (stateStructType.length != 4) {
+      return TypeCheckFailure("State must be a struct with 4 fields. " +
+        "Expected struct: " +
+        "struct<sketch:binary,itemDataType:any,maxItemsTracked:int,typeCode:binary>. " +
         "Got: " + state.dataType.simpleString)
     }
 
     if (stateStructType.head.dataType != BinaryType) {
       TypeCheckFailure("State struct must have the first field to be binary. " +
         "Got: " + stateStructType.head.dataType.simpleString)
     } else if (!ApproxTopK.isDataTypeSupported(itemDataType)) {
       TypeCheckFailure("State struct must have the second field to be a supported data type. " +
         "Got: " + itemDataType.simpleString)
     } else if (stateStructType(2).dataType != IntegerType) {
       TypeCheckFailure("State struct must have the third field to be int. " +
" + "Got: " + stateStructType(2).dataType.simpleString) + } else if (stateStructType(3).dataType != BinaryType) { + TypeCheckFailure("State struct must have the fourth field to be binary. " + + "Got: " + stateStructType(3).dataType.simpleString) } else { TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index cefe0a14dee56..24bdd511d8cf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -318,7 +318,42 @@ object ApproxTopK { StructType( StructField("sketch", BinaryType, nullable = false) :: StructField("itemDataType", itemDataType) :: - StructField("maxItemsTracked", IntegerType, nullable = false) :: Nil) + StructField("maxItemsTracked", IntegerType, nullable = false) :: + StructField("typeCode", BinaryType, nullable = false) :: Nil) + + def dataTypeToBytes(dataType: DataType): Array[Byte] = { + dataType match { + case _: BooleanType => Array(0, 0, 0) + case _: ByteType => Array(1, 0, 0) + case _: ShortType => Array(2, 0, 0) + case _: IntegerType => Array(3, 0, 0) + case _: LongType => Array(4, 0, 0) + case _: FloatType => Array(5, 0, 0) + case _: DoubleType => Array(6, 0, 0) + case _: DateType => Array(7, 0, 0) + case _: TimestampType => Array(8, 0, 0) + case _: TimestampNTZType => Array(9, 0, 0) + case _: StringType => Array(10, 0, 0) + case dt: DecimalType => Array(11, dt.precision.toByte, dt.scale.toByte) + } + } + + def bytesToDataType(bytes: Array[Byte]): DataType = { + bytes(0) match { + case 0 => BooleanType + case 1 => ByteType + case 2 => ShortType + case 3 => IntegerType + case 4 => LongType + case 5 => FloatType + case 6 => DoubleType + case 7 => DateType + case 8 => TimestampType + case 9 => TimestampNTZType + case 10 => StringType + case 11 => DecimalType(bytes(1).toInt, bytes(2).toInt) + } + } } /** @@ -328,7 +363,11 @@ object ApproxTopK { * * The output of this function is a struct containing the sketch in binary format, * a null object indicating the type of items in the sketch, - * and the maximum number of items tracked by the sketch. + * the maximum number of items tracked by the sketch, + * and a binary typeCode encoding the data type of the items in the sketch. + * + * The null object is used in approx_top_k_estimate, + * while the typeCode is used in approx_top_k_combine. 
* * @param expr the child expression to accumulate items from * @param maxItemsTracked the maximum number of items to track in the sketch @@ -410,7 +449,8 @@ case class ApproxTopKAccumulate( override def eval(buffer: ItemsSketch[Any]): Any = { val sketchBytes = serialize(buffer) - InternalRow.apply(sketchBytes, null, maxItemsTrackedVal) + val typeCode = ApproxTopK.dataTypeToBytes(itemDataType) + InternalRow.apply(sketchBytes, null, maxItemsTrackedVal, typeCode) } override def serialize(buffer: ItemsSketch[Any]): Array[Byte] = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala index 2b339003abd4c..f674f186d6b7d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala @@ -137,18 +137,14 @@ class ApproxTopKSuite extends SparkFunSuite { // ApproxTopKEstimate tests ///////////////////////////// - val stateStructType: StructType = StructType(Seq( - StructField("sketch", BinaryType), - StructField("itemDataType", IntegerType), - StructField("maxItemsTracked", IntegerType) - )) - test("SPARK-52588: invalid estimate if k are not foldable") { val badEstimate = ApproxTopKEstimate( state = BoundReference(0, StructType(Seq( StructField("sketch", BinaryType), StructField("itemDataType", IntegerType), - StructField("maxItemsTracked", IntegerType))), nullable = false), + StructField("maxItemsTracked", IntegerType), + StructField("typeCode", BinaryType) + )), nullable = false), k = Sum(BoundReference(1, IntegerType, nullable = true)) ) assert(badEstimate.checkInputDataTypes().isFailure) @@ -184,11 +180,11 @@ class ApproxTopKSuite extends SparkFunSuite { ) } - test("SPARK-52588: invalid estimate if state struct length is not 3") { + test("SPARK-52588: invalid estimate if state struct length is not 4") { val invalidState = StructType(Seq( StructField("sketch", BinaryType), StructField("itemDataType", IntegerType) - // Missing "maxItemsTracked" + // Missing "maxItemsTracked", "typeCode" fields )) val badEstimate = ApproxTopKEstimate( state = BoundReference(0, invalidState, nullable = false), @@ -196,8 +192,9 @@ class ApproxTopKSuite extends SparkFunSuite { ) assert(badEstimate.checkInputDataTypes().isFailure) assert(badEstimate.checkInputDataTypes() == - TypeCheckFailure("State must be a struct with 3 fields. " + - "Expected struct: struct. " + + TypeCheckFailure("State must be a struct with 4 fields. " + + "Expected struct: " + + "struct. 
" + "Got: struct")) } @@ -205,7 +202,8 @@ class ApproxTopKSuite extends SparkFunSuite { val invalidState = StructType(Seq( StructField("notSketch", IntegerType), // Should be BinaryType StructField("itemDataType", IntegerType), - StructField("maxItemsTracked", IntegerType) + StructField("maxItemsTracked", IntegerType), + StructField("typeCode", BinaryType) )) val badEstimate = ApproxTopKEstimate( state = BoundReference(0, invalidState, nullable = false), @@ -227,7 +225,8 @@ class ApproxTopKSuite extends SparkFunSuite { val invalidState = StructType(Seq( StructField("sketch", BinaryType), StructField("itemDataType", dataType), - StructField("maxItemsTracked", IntegerType) + StructField("maxItemsTracked", IntegerType), + StructField("typeCode", BinaryType) )) val badEstimate = ApproxTopKEstimate( state = BoundReference(0, invalidState, nullable = false), @@ -243,7 +242,8 @@ class ApproxTopKSuite extends SparkFunSuite { val invalidState = StructType(Seq( StructField("sketch", BinaryType), StructField("itemDataType", IntegerType), - StructField("maxItemsTracked", LongType) // Should be IntegerType + StructField("maxItemsTracked", LongType), // Should be IntegerType + StructField("typeCode", BinaryType) )) val badEstimate = ApproxTopKEstimate( state = BoundReference(0, invalidState, nullable = false), @@ -253,4 +253,20 @@ class ApproxTopKSuite extends SparkFunSuite { assert(badEstimate.checkInputDataTypes() == TypeCheckFailure("State struct must have the third field to be int. Got: bigint")) } + + test("SPARK-52588: invalid estimate if state struct's fourth field is not binary") { + val invalidState = StructType(Seq( + StructField("sketch", BinaryType), + StructField("itemDataType", IntegerType), + StructField("maxItemsTracked", IntegerType), + StructField("typeCode", StringType) // Should be BinaryType + )) + val badEstimate = ApproxTopKEstimate( + state = BoundReference(0, invalidState, nullable = false), + k = Literal(5) + ) + assert(badEstimate.checkInputDataTypes().isFailure) + assert(badEstimate.checkInputDataTypes() == + TypeCheckFailure("State struct must have the fourth field to be binary. Got: string")) + } } From c42fecc429db864446b634799d65a8782ddae97f Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Tue, 15 Jul 2025 14:43:11 -0700 Subject: [PATCH 03/17] upload combine impl --- .../expressions/ApproxTopKExpressions.scala | 29 +-- .../aggregate/ApproxTopKAggregates.scala | 220 ++++++++++++++++++ 2 files changed, 221 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala index f6ae5df4b5825..ed95a5265c2dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApproxTopKExpressions.scala @@ -76,39 +76,12 @@ case class ApproxTopKEstimate(state: Expression, k: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType) - private def checkStateFieldAndType(state: Expression): TypeCheckResult = { - val stateStructType = state.dataType.asInstanceOf[StructType] - if (stateStructType.length != 4) { - return TypeCheckFailure("State must be a struct with 4 fields. " + - "Expected struct: " + - "struct. 
" + - "Got: " + state.dataType.simpleString) - } - - if (stateStructType.head.dataType != BinaryType) { - TypeCheckFailure("State struct must have the first field to be binary. " + - "Got: " + stateStructType.head.dataType.simpleString) - } else if (!ApproxTopK.isDataTypeSupported(itemDataType)) { - TypeCheckFailure("State struct must have the second field to be a supported data type. " + - "Got: " + itemDataType.simpleString) - } else if (stateStructType(2).dataType != IntegerType) { - TypeCheckFailure("State struct must have the third field to be int. " + - "Got: " + stateStructType(2).dataType.simpleString) - } else if (stateStructType(3).dataType != BinaryType) { - TypeCheckFailure("State struct must have the fourth field to be binary. " + - "Got: " + stateStructType(3).dataType.simpleString) - } else { - TypeCheckSuccess - } - } - - override def checkInputDataTypes(): TypeCheckResult = { val defaultCheck = super.checkInputDataTypes() if (defaultCheck.isFailure) { defaultCheck } else { - val stateCheck = checkStateFieldAndType(state) + val stateCheck = ApproxTopK.checkStateFieldAndType(state) if (stateCheck.isFailure) { stateCheck } else if (!k.foldable) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 24bdd511d8cf4..5480ad2fcf2fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -21,6 +21,7 @@ import org.apache.datasketches.common._ import org.apache.datasketches.frequencies.{ErrorType, ItemsSketch} import org.apache.datasketches.memory.Memory +import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -177,6 +178,8 @@ object ApproxTopK { val DEFAULT_K: Int = 5 val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000 private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000 + val VOID_MAX_ITEMS_TRACKED = -1 + val SKETCH_SIZE_PLACEHOLDER = 8 def checkExpressionNotNull(expr: Expression, exprName: String): Unit = { if (expr == null || expr.eval() == null) { @@ -354,6 +357,36 @@ object ApproxTopK { case 11 => DecimalType(bytes(1).toInt, bytes(2).toInt) } } + + def checkStateFieldAndType(state: Expression): TypeCheckResult = { + val stateStructType = state.dataType.asInstanceOf[StructType] + if (stateStructType.length != 4) { + return TypeCheckFailure("State must be a struct with 4 fields. " + + "Expected struct: " + + "struct. " + + "Got: " + state.dataType.simpleString) + } + + val fieldType1 = stateStructType.head.dataType + val fieldType2 = stateStructType(1).dataType + val fieldType3 = stateStructType(2).dataType + val fieldType4 = stateStructType(3).dataType + if (fieldType1 != BinaryType) { + TypeCheckFailure("State struct must have the first field to be binary. " + + "Got: " + fieldType1.simpleString) + } else if (!ApproxTopK.isDataTypeSupported(fieldType2)) { + TypeCheckFailure("State struct must have the second field to be a supported data type. " + + "Got: " + fieldType2.simpleString) + } else if (fieldType3 != IntegerType) { + TypeCheckFailure("State struct must have the third field to be int. 
" + + "Got: " + fieldType3.simpleString) + } else if (fieldType4 != BinaryType) { + TypeCheckFailure("State struct must have the fourth field to be binary. " + + "Got: " + fieldType4.simpleString) + } else { + TypeCheckSuccess + } + } } /** @@ -475,3 +508,190 @@ case class ApproxTopKAccumulate( override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate") } + +class CombineInternal[T]( + sketch: ItemsSketch[T], + var itemDataType: DataType, + var maxItemsTracked: Int) { + def getSketch: ItemsSketch[T] = sketch + + def getItemDataType: DataType = itemDataType + + def setItemDataType(dataType: DataType): Unit = { + if (this.itemDataType == null) { + this.itemDataType = dataType + } else if (this.itemDataType != dataType) { + throw new SparkUnsupportedOperationException( + errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + messageParameters = Map( + "type1" -> this.itemDataType.typeName, + "type2" -> dataType.typeName)) + } + } + + def getMaxItemsTracked: Int = maxItemsTracked + + def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked +} + +case class ApproxTopKCombine( + state: Expression, + maxItemsTracked: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[CombineInternal[Any]] + with ImplicitCastInputTypes + with BinaryLike[Expression] { + + def this(child: Expression, maxItemsTracked: Expression) = { + this(child, maxItemsTracked, 0, 0) + ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked") + ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal) + } + + def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked)) + + def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0) + + private lazy val uncheckedItemDataType: DataType = + state.dataType.asInstanceOf[StructType](1).dataType + private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int] + private lazy val combineSizeSpecified: Boolean = + maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED + + override def left: Expression = state + + override def right: Expression = maxItemsTracked + + override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else { + val stateCheck = ApproxTopK.checkStateFieldAndType(state) + if (stateCheck.isFailure) { + stateCheck + } else if (!maxItemsTracked.foldable) { + TypeCheckFailure("Number of items tracked must be a constant literal") + } else { + TypeCheckSuccess + } + } + } + + override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType) + + override def createAggregationBuffer(): CombineInternal[Any] = { + if (combineSizeSpecified) { + val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal) + new CombineInternal[Any]( + new ItemsSketch[Any](maxMapSize), + null, + maxItemsTrackedVal) + } else { + new CombineInternal[Any]( + new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER), + null, + ApproxTopK.VOID_MAX_ITEMS_TRACKED) + } + } + + override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = { + val inputState = state.eval(input).asInstanceOf[InternalRow] + val inputSketchBytes = inputState.getBinary(0) + val inputMaxItemsTracked = inputState.getInt(2) + val typeCode = inputState.getBinary(3) + val 
+    buffer.setItemDataType(actualItemDataType)
+    val inputSketch = ItemsSketch.getInstance(
+      Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    buffer.getSketch.merge(inputSketch)
+    if (!combineSizeSpecified) {
+      buffer.setMaxItemsTracked(inputMaxItemsTracked)
+    }
+    buffer
+  }
+
+  override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
+  : CombineInternal[Any] = {
+    if (!combineSizeSpecified) {
+      // check size
+      if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
+        // If buffer is a placeholder sketch, set it to the input sketch's max items tracked
+        buffer.setMaxItemsTracked(input.getMaxItemsTracked)
+      }
+      if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
+        throw new SparkUnsupportedOperationException(
+          errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED",
+          messageParameters = Map(
+            "size1" -> buffer.getMaxItemsTracked.toString,
+            "size2" -> input.getMaxItemsTracked.toString))
+      }
+    }
+    // check item data type
+    if (buffer.getItemDataType != null && input.getItemDataType != null &&
+      buffer.getItemDataType != input.getItemDataType) {
+      throw new SparkUnsupportedOperationException(
+        errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
+        messageParameters = Map(
+          "type1" -> buffer.getItemDataType.typeName,
+          "type2" -> input.getItemDataType.typeName))
+    } else if (buffer.getItemDataType == null) {
+      // If buffer is a placeholder sketch, set it to the input sketch's item data type
+      buffer.setItemDataType(input.getItemDataType)
+    }
+    buffer.getSketch.merge(input.getSketch)
+    buffer
+  }
+
+  override def eval(buffer: CombineInternal[Any]): Any = {
+    val sketchBytes = try {
+      buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    } catch {
+      case _: ArrayStoreException =>
+        throw new SparkUnsupportedOperationException(
+          errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED"
+        )
+    }
+    val maxItemsTracked = buffer.getMaxItemsTracked
+    val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+    InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
+  }
+
+  override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
+    val sketchBytes = buffer.getSketch.toByteArray(
+      ApproxTopK.genSketchSerDe(buffer.getItemDataType))
+    val maxItemsTrackedByte = buffer.getMaxItemsTracked.toByte
+    val itemDataTypeBytes = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
+    val byteArray = new Array[Byte](sketchBytes.length + 4)
+    byteArray(0) = maxItemsTrackedByte
+    System.arraycopy(itemDataTypeBytes, 0, byteArray, 1, itemDataTypeBytes.length)
+    System.arraycopy(sketchBytes, 0, byteArray, 4, sketchBytes.length)
+    byteArray
+  }
+
+  override def deserialize(buffer: Array[Byte]): CombineInternal[Any] = {
+    val maxItemsTracked = buffer(0).toInt
+    val itemDataTypeBytes = buffer.slice(1, 4)
+    val actualItemDataType = ApproxTopK.bytesToDataType(itemDataTypeBytes)
+    val sketchBytes = buffer.slice(4, buffer.length)
+    val sketch = ItemsSketch.getInstance(
+      Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(actualItemDataType))
+    new CombineInternal[Any](sketch, actualItemDataType, maxItemsTracked)
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression,
+      newRight: Expression): Expression =
+    copy(state = newLeft, maxItemsTracked = newRight)
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def nullable: Boolean = false
+}

From 1cd9704c51e7eb119738c911d9628f282a1447bf Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Tue, 15 Jul 2025 16:16:11 -0700
Subject: [PATCH 04/17] add combine tests

---
 .../aggregate/ApproxTopKSuite.scala        |   2 +
 .../apache/spark/sql/ApproxTopKSuite.scala | 216 +++++++++++++++++-
 2 files changed, 217 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala
index f674f186d6b7d..e52583e7d0b4b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala
@@ -269,4 +269,6 @@ class ApproxTopKSuite extends SparkFunSuite {
     assert(badEstimate.checkInputDataTypes() ==
       TypeCheckFailure("State struct must have the fourth field to be binary. Got: string"))
   }
+
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index 8219fce9b2178..08e9cd4ca883d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql
 import java.sql.{Date, Timestamp}
 import java.time.LocalDateTime
 
-import org.apache.spark.{SparkArithmeticException, SparkRuntimeException}
+import org.apache.spark.{SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException}
+import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{ByteType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType}
 
 class ApproxTopKSuite extends QueryTest with SharedSparkSession {
 
       parameters = Map("maxItemsTracked" -> "5", "k" -> "10")
     )
   }
+
+  /////////////////////////////////
+  // approx_top_k_combine
+  /////////////////////////////////
+
+  def setupAccumulations(size1: Int, size2: Int): Unit = {
+    sql(s"SELECT approx_top_k_accumulate(expr, $size1) as acc " +
+      "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);")
+      .createOrReplaceTempView("accumulation1")
+
+    sql(s"SELECT approx_top_k_accumulate(expr, $size2) as acc " +
+      "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);")
+      .createOrReplaceTempView("accumulation2")
+  }
+
+  test("SPARK-52798: same type, same size, specified combine size - success") {
+    setupAccumulations(10, 10)
+
+    sql("SELECT approx_top_k_combine(acc, 30) as com " +
+      "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);")
+      .createOrReplaceTempView("combined")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+  }
+
+  test("SPARK-52798: same type, same size, unspecified combine size - success") {
+    setupAccumulations(10, 10)
+
+    sql("SELECT approx_top_k_combine(acc) as com " +
+      "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);")
+      .createOrReplaceTempView("combined")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combined;")
+    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+  }
+
+  test("SPARK-52798: same type, different size, specified combine size - success") {
+    setupAccumulations(10, 20)
+
+    sql("SELECT approx_top_k_combine(acc, 30) as com " +
+      "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);")
+      .createOrReplaceTempView("combination")
+
+    val est = sql("SELECT approx_top_k_estimate(com) FROM combination;")
+    checkAnswer(est, Row(Seq(Row(2, 4), Row(1, 4), Row(0, 3), Row(3, 3), Row(4, 2))))
+  }
+
+  test("SPARK-52798: same type, different size, unspecified combine size - fail") {
+    setupAccumulations(10, 20)
+
+    val comb = sql("SELECT approx_top_k_combine(acc) as com " +
+      "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);")
+
+    checkError(
+      exception = intercept[SparkUnsupportedOperationException] {
+        comb.collect()
+      },
+      condition = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED",
+      parameters = Map("size1" -> "10", "size2" -> "20")
+    )
+  }
+
+  gridTest("SPARK-combine: invalid combine size - fail")(Seq((10, 10), (10, 20))) {
+    case (size1, size2) =>
+      setupAccumulations(size1, size2)
+      checkError(
+        exception = intercept[SparkRuntimeException] {
+          sql("SELECT approx_top_k_combine(acc, 0) as com " +
+            "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);")
+            .collect()
+        },
+        condition = "APPROX_TOP_K_NON_POSITIVE_ARG",
+        parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0")
+      )
+  }
+
+  def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = {
+    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);")
+      .createOrReplaceTempView("accumulation1")
+
+    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);")
+      .createOrReplaceTempView("accumulation2")
+
+    sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2")
+      .createOrReplaceTempView("unioned")
+  }
+
+
+  val mixedNumberTypeSeqs: Seq[(String, String, Seq[Any])] = Seq(
+    (IntegerType.typeName, "INT",
+      Seq(0, 0, 0, 1, 1, 2, 2, 3)),
+    (ByteType.typeName, "TINYINT",
+      Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")),
+    (ShortType.typeName, "SMALLINT",
+      Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")),
+    (LongType.typeName, "BIGINT",
+      Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")),
+    (FloatType.typeName, "FLOAT",
+      Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")),
+    (DoubleType.typeName, "DOUBLE",
+      Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")),
+    (DecimalType(4, 2).typeName, "DECIMAL(4,2)",
+      Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", "cast(1 AS DECIMAL(4, 2))")),
+    (DecimalType(10, 2).typeName, "DECIMAL(10,2)",
+      Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", "cast(1 AS DECIMAL(10, 2))")),
+    (DecimalType(20, 3).typeName, "DECIMAL(20,3)",
+      Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", "cast(1 AS DECIMAL(20, 3))"))
+  )
+
+  val mixedDateTimeSeqs: Seq[(String, String, Seq[String])] = Seq(
+    (DateType.typeName, "DATE",
+      Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")),
+    (TimestampType.typeName, "TIMESTAMP",
+      Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'")),
+    (TimestampNTZType.typeName, "TIMESTAMP_NTZ",
+      Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01 00:00:00'")
+    )
+  )
+
+  def checkMixedTypeError(mixedTypeSeq: Seq[(String, String, Seq[Any])]): Unit = {
+    for (i <- 0 until mixedTypeSeq.size - 1) {
+      for (j <- i + 1 until mixedTypeSeq.size) {
+        val (type1, _, seq1) = mixedTypeSeq(i)
+        val (type2, _, seq2) = mixedTypeSeq(j)
+        setupMixedTypeAccumulation(seq1, seq2)
+        checkError(
+          exception = intercept[SparkUnsupportedOperationException] {
+            sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+          },
+          condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
+          parameters = Map("type1" -> type1, "type2" -> type2)
+        )
+      }
+    }
+  }
+
+  test("SPARK-combine: among different number or datetime types - fail") {
+    checkMixedTypeError(mixedNumberTypeSeqs)
+    checkMixedTypeError(mixedDateTimeSeqs)
+  }
+
+  gridTest("SPARK-combine: string vs number - fail")(mixedNumberTypeSeqs) {
+    case (type1, _, seq1) =>
+      setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+      checkError(
+        exception = intercept[SparkUnsupportedOperationException] {
+          sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+        },
+        condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
+        parameters = Map("type1" -> type1, "type2" -> StringType.typeName)
+      )
+  }
+
+
+  gridTest("SPARK-combine: different types - fail on UNION")(
+    Seq(
+      ("INT", Seq(0, 0, 0, 1, 1, 2, 2, 3),
+        "BOOLEAN", Seq("(true)", "(true)", "(false)", "(false)")),
+      ("INT", Seq(0, 0, 0, 1, 1, 2, 2, 3),
+        "DATE", Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")),
+      ("BIGINT", Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)"),
+        "TIMESTAMP", Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'"))
+    )) {
+    case (type1, seq1, type2, seq2) =>
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          setupMixedTypeAccumulation(seq1, seq2)
+        },
+        condition = "INCOMPATIBLE_COLUMN_TYPE",
+        parameters = Map(
+          "tableOrdinalNumber" -> "second",
+          "columnOrdinalNumber" -> "first",
+          "dataType2" -> ("\"STRUCT\""),
+          "operator" -> "UNION",
+          "hint" -> "",
+          "dataType1" -> ("\"STRUCT\"")
+        ),
+        queryContext = Array(
+          ExpectedContext(
+            "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
+      )
+  }
+
+  gridTest("SPARK-combine: boolean and number types - fail on UNION")(mixedNumberTypeSeqs) {
+    numberItems =>
+      val (_, type1, seq1) = numberItems
+      val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          setupMixedTypeAccumulation(seq1, seq2)
+        },
+        condition = "INCOMPATIBLE_COLUMN_TYPE",
+        parameters = Map(
+          "tableOrdinalNumber" -> "second",
+          "columnOrdinalNumber" -> "first",
+          "dataType2" -> ("\"STRUCT\""),
+          "operator" -> "UNION",
+          "hint" -> "",
+          "dataType1" -> ("\"STRUCT\"")
+        ),
+        queryContext = Array(
+          ExpectedContext(
+            "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
+      )
+  }
 }

From c41d81cd464ddacbedc1527f55a8661338357e37 Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Tue, 15 Jul 2025 17:05:51 -0700
Subject: [PATCH 05/17] all type mismatch tests

---
 .../apache/spark/sql/ApproxTopKSuite.scala | 79 ++++++++++++++-----
 1 file changed, 59 insertions(+), 20 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index 08e9cd4ca883d..3860e09c0857f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -420,7 +420,6 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
       .createOrReplaceTempView("unioned")
   }
 
-
   val mixedNumberTypeSeqs: Seq[(String, String, Seq[Any])] = Seq(
     (IntegerType.typeName, "INT",
       Seq(0, 0, 0, 1, 1, 2, 2, 3)),
-  test("SPARK-combine: among different number or datetime types - fail") {
+  test("SPARK-combine: among different number or datetime types - fail at combine") {
     checkMixedTypeError(mixedNumberTypeSeqs)
     checkMixedTypeError(mixedDateTimeSeqs)
   }
 
-  gridTest("SPARK-combine: string vs number - fail")(mixedNumberTypeSeqs) {
+  gridTest("SPARK-combine: string vs number - fail at combine")(mixedNumberTypeSeqs) {
     case (type1, _, seq1) =>
       setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
       checkError(
         exception = intercept[SparkUnsupportedOperationException] {
           sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
         },
         condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
         parameters = Map("type1" -> type1, "type2" -> StringType.typeName)
       )
   }
 
-
-  gridTest("SPARK-combine: different types - fail on UNION")(
-    Seq(
-      ("INT", Seq(0, 0, 0, 1, 1, 2, 2, 3),
-        "BOOLEAN", Seq("(true)", "(true)", "(false)", "(false)")),
-      ("INT", Seq(0, 0, 0, 1, 1, 2, 2, 3),
-        "DATE", Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")),
-      ("BIGINT", Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)"),
-        "TIMESTAMP", Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'"))
-    )) {
-    case (type1, seq1, type2, seq2) =>
+  gridTest("SPARK-combine: boolean vs number - fail at UNION")(mixedNumberTypeSeqs) {
+    case (_, type1, seq1) =>
+      val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
       checkError(
         exception = intercept[ExtendedAnalysisException] {
           setupMixedTypeAccumulation(seq1, seq2)
         },
         condition = "INCOMPATIBLE_COLUMN_TYPE",
         parameters = Map(
           "tableOrdinalNumber" -> "second",
           "columnOrdinalNumber" -> "first",
-          "dataType2" -> ("\"STRUCT\""),
+          "dataType2" -> ("\"STRUCT\""),
           "operator" -> "UNION",
           "hint" -> "",
-          "dataType1" -> ("\"STRUCT\""),
+          "dataType1" -> ("\"STRUCT\"")
         ),
         queryContext = Array(
           ExpectedContext(
             "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
       )
   }
 
-  gridTest("SPARK-combine: boolean and number types - fail on UNION")(mixedNumberTypeSeqs) {
-    numberItems =>
-      val (_, type1, seq1) = numberItems
+  gridTest("SPARK-combine: string vs datetime - fail at combine")(mixedDateTimeSeqs) {
+    case (type1, _, seq1) =>
+      setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'"))
+      checkError(
+        exception = intercept[SparkUnsupportedOperationException] {
+          sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect()
+        },
+        condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
+        parameters = Map("type1" -> type1, "type2" -> StringType.typeName)
+      )
+  }
+
+  gridTest("SPARK-combine: boolean vs datetime - fail at UNION")(mixedDateTimeSeqs) {
+    case (_, type1, seq1) =>
       val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
       checkError(
         exception = intercept[ExtendedAnalysisException] {
           setupMixedTypeAccumulation(seq1, seq2)
         },
         condition = "INCOMPATIBLE_COLUMN_TYPE",
         parameters = Map(
           "tableOrdinalNumber" -> "second",
           "columnOrdinalNumber" -> "first",
           "dataType2" -> ("\"STRUCT\""),
           "operator" -> "UNION",
           "hint" -> "",
           "dataType1" -> ("\"STRUCT\"")
         ),
         queryContext = Array(
           ExpectedContext(
             "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
       )
   }
+
+  test("SPARK-combine: string vs boolean - fail at combine") {
+    val seq1 = Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")
+    val seq2 = Seq("(true)", "(true)", "(false)", "(false)")
+    setupMixedTypeAccumulation(seq1, seq2)
+  }
+
+  // enumerate all combinations of number and datetime types
+  val mixedNumberAndDateTimeSeqs: Seq[((String, String, Seq[Any]), (String, String, Seq[Any]))] =
+    for {
+      (type1, typeName1, seq1) <- mixedNumberTypeSeqs
+      (type2, typeName2, seq2) <- mixedDateTimeSeqs
+    } yield ((type1, typeName1, seq1), (type2, typeName2, seq2))
+
+
+  gridTest("SPARK-combine: number vs datetime - fail on UNION")(mixedNumberAndDateTimeSeqs) {
+    case ((_, type1, seq1), (_, type2, seq2)) =>
+      checkError(
+        exception = intercept[ExtendedAnalysisException] {
+          setupMixedTypeAccumulation(seq1, seq2)
+        },
+        condition = "INCOMPATIBLE_COLUMN_TYPE",
+        parameters = Map(
+          "tableOrdinalNumber" -> "second",
+          "columnOrdinalNumber" -> "first",
+          "dataType2" -> ("\"STRUCT\""),
+          "operator" -> "UNION",
+          "hint" -> "",
+          "dataType1" -> ("\"STRUCT\"")
+        ),
+        queryContext = Array(
+          ExpectedContext(
+            "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68))
+      )
+  }
 }

From fa6e86768f98a1ce3d8d943e8a99b0ac6397dcfe Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Tue, 15 Jul 2025 17:39:08 -0700
Subject: [PATCH 06/17] all type mismatch tests

---
 .../apache/spark/sql/ApproxTopKSuite.scala | 185 +++++++++---------
 1 file changed, 91 insertions(+), 94 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
index 3860e09c0857f..3e17931b09b16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala
@@ -345,6 +345,50 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession {
       .createOrReplaceTempView("accumulation2")
   }
 
+  def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = {
+    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);")
+      .createOrReplaceTempView("accumulation1")
+
+    sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " +
+      s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);")
+      .createOrReplaceTempView("accumulation2")
+
+    sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2")
+      .createOrReplaceTempView("unioned")
+  }
+
+  val mixedNumberTypeSeqs: Seq[(String, String, Seq[Any])] = Seq(
+    (IntegerType.typeName, "INT",
+      Seq(0, 0, 0, 1, 1, 2, 2, 3)),
+    (ByteType.typeName, "TINYINT",
+      Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")),
+    (ShortType.typeName, "SMALLINT",
+      Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")),
+    (LongType.typeName, "BIGINT",
+      Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")),
+    (FloatType.typeName, "FLOAT",
+      Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")),
+    (DoubleType.typeName, "DOUBLE",
+      Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")),
+    (DecimalType(4, 2).typeName, "DECIMAL(4,2)",
+      Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", "cast(1 AS DECIMAL(4, 2))")),
+    (DecimalType(10, 2).typeName, "DECIMAL(10,2)",
+      Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", "cast(1 AS DECIMAL(10, 2))")),
+    (DecimalType(20, 3).typeName, "DECIMAL(20,3)",
+      Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", "cast(1 AS DECIMAL(20, 3))"))
+  )
+
+  val mixedDateTimeSeqs: Seq[(String, String, Seq[String])] = Seq(
+    (DateType.typeName, "DATE",
+      Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")),
(TimestampType.typeName, "TIMESTAMP", + Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'")), + (TimestampNTZType.typeName, "TIMESTAMP_NTZ", + Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01 00:00:00'") + ) + ) + test("SPARK-52798: same type, same size, specified combine size - success") { setupAccumulations(10, 10) @@ -407,73 +451,57 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = { - sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " + - s"FROM VALUES ${seq1.mkString(", ")} AS tab(expr);") - .createOrReplaceTempView("accumulation1") - - sql(s"SELECT approx_top_k_accumulate(expr, 10) as acc " + - s"FROM VALUES ${seq2.mkString(", ")} AS tab(expr);") - .createOrReplaceTempView("accumulation2") - - sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2") - .createOrReplaceTempView("unioned") - } - - val mixedNumberTypeSeqs: Seq[(String, String, Seq[Any])] = Seq( - (IntegerType.typeName, "INT", - Seq(0, 0, 0, 1, 1, 2, 2, 3)), - (ByteType.typeName, "TINYINT", - Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")), - (ShortType.typeName, "SMALLINT", - Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")), - (LongType.typeName, "BIGINT", - Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")), - (FloatType.typeName, "FLOAT", - Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")), - (DoubleType.typeName, "DOUBLE", - Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")), - (DecimalType(4, 2).typeName, "DECIMAL(4,2)", - Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", "cast(1 AS DECIMAL(4, 2))")), - (DecimalType(10, 2).typeName, "DECIMAL(10,2)", - Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", "cast(1 AS DECIMAL(10, 2))")), - (DecimalType(20, 3).typeName, "DECIMAL(20,3)", - Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", "cast(1 AS DECIMAL(20, 3))")) - ) - - val mixedDateTimeSeqs: Seq[(String, String, Seq[String])] = Seq( - (DateType.typeName, "DATE", - Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")), - (TimestampType.typeName, "TIMESTAMP", - Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'")), - (TimestampNTZType.typeName, "TIMESTAMP_NTZ", - Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01 00:00:00'") - ) - ) - - def checkMixedTypeError(mixedTypeSeq: Seq[(String, String, Seq[Any])]): Unit = { - for (i <- 0 until mixedTypeSeq.size - 1) { - for (j <- i + 1 until mixedTypeSeq.size) { - val (type1, _, seq1) = mixedTypeSeq(i) - val (type2, _, seq2) = mixedTypeSeq(j) - setupMixedTypeAccumulation(seq1, seq2) - checkError( - exception = intercept[SparkUnsupportedOperationException] { - sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() - }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", - parameters = Map("type1" -> type1, "type2" -> type2) - ) + test("SPARK-combine: among different number or datetime types - fail at combine") { + def checkMixedTypeError(mixedTypeSeq: Seq[(String, String, Seq[Any])]): Unit = { + for (i <- 0 until mixedTypeSeq.size - 1) { + for (j <- i + 1 until mixedTypeSeq.size) { + val (type1, _, seq1) = mixedTypeSeq(i) + val (type2, _, seq2) = mixedTypeSeq(j) + setupMixedTypeAccumulation(seq1, seq2) + checkError( + exception = intercept[SparkUnsupportedOperationException] { + sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() 
+ }, + condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + parameters = Map("type1" -> type1, "type2" -> type2) + ) + } } } - } - test("SPARK-combine: among different number or datetime types - fail at combine") { checkMixedTypeError(mixedNumberTypeSeqs) checkMixedTypeError(mixedDateTimeSeqs) } - gridTest("SPARK-combine: string vs number - fail at combine")(mixedNumberTypeSeqs) { + // enumerate all combinations of number and datetime types + gridTest("SPARK-combine: number vs datetime - fail on UNION")( + for { + (type1, typeName1, seq1) <- mixedNumberTypeSeqs + (type2, typeName2, seq2) <- mixedDateTimeSeqs + } yield ((type1, typeName1, seq1), (type2, typeName2, seq2))) { + case ((_, type1, seq1), (_, type2, seq2)) => + checkError( + exception = intercept[ExtendedAnalysisException] { + setupMixedTypeAccumulation(seq1, seq2) + }, + condition = "INCOMPATIBLE_COLUMN_TYPE", + parameters = Map( + "tableOrdinalNumber" -> "second", + "columnOrdinalNumber" -> "first", + "dataType2" -> ("\"STRUCT\""), + "operator" -> "UNION", + "hint" -> "", + "dataType1" -> ("\"STRUCT\"") + ), + queryContext = Array( + ExpectedContext( + "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68)) + ) + } + + gridTest("SPARK-combine: number vs string - fail at combine")(mixedNumberTypeSeqs) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( @@ -485,7 +513,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-combine: boolean vs number - fail at UNION")(mixedNumberTypeSeqs) { + gridTest("SPARK-combine: number vs boolean - fail at UNION")(mixedNumberTypeSeqs) { case (_, type1, seq1) => val seq2 = Seq("(true)", "(true)", "(false)", "(false)") checkError( @@ -509,7 +537,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-combine: string vs datetime - fail at combine")(mixedDateTimeSeqs) { + gridTest("SPARK-combine: datetime vs string - fail at combine")(mixedDateTimeSeqs) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( @@ -521,7 +549,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-combine: boolean vs datetime - fail at UNION")(mixedDateTimeSeqs) { + gridTest("SPARK-combine: datetime vs boolean - fail at UNION")(mixedDateTimeSeqs) { case (_, type1, seq1) => val seq2 = Seq("(true)", "(true)", "(false)", "(false)") checkError( @@ -550,35 +578,4 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { val seq2 = Seq("(true)", "(true)", "(false)", "(false)") setupMixedTypeAccumulation(seq1, seq2) } - - // enumerate all combinations of number and datetime types - val mixedNumberAndDateTimeSeqs: Seq[((String, String, Seq[Any]), (String, String, Seq[Any]))] = - for { - (type1, typeName1, seq1) <- mixedNumberTypeSeqs - (type2, typeName2, seq2) <- mixedDateTimeSeqs - } yield ((type1, typeName1, seq1), (type2, typeName2, seq2)) - - - gridTest("SPARK-combine: number vs datetime - fail on UNION")(mixedNumberAndDateTimeSeqs) { - case ((_, type1, seq1), (_, type2, seq2)) => - checkError( - exception = intercept[ExtendedAnalysisException] { - setupMixedTypeAccumulation(seq1, seq2) - }, - condition = "INCOMPATIBLE_COLUMN_TYPE", - parameters = Map( - "tableOrdinalNumber" -> "second", - "columnOrdinalNumber" -> "first", - "dataType2" -> ("\"STRUCT\""), - "operator" -> "UNION", - "hint" -> "", - "dataType1" -> 
("\"STRUCT\"") - ), - queryContext = Array( - ExpectedContext( - "SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2", 0, 68)) - ) - } } From f7b4d90b5ab932568deb32935f96794bbe6e4225 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Tue, 15 Jul 2025 17:48:02 -0700 Subject: [PATCH 07/17] sql tests --- .../apache/spark/sql/ApproxTopKSuite.scala | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala index 3e17931b09b16..99e6ec26137e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala @@ -389,6 +389,24 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) ) + // positive tests for approx_top_k_combine on every types + gridTest("SPARK-52798: same type, same size, specified combine size - success")(itemsWithTopK) { + case (input, expected) => + sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS tab(expr);") + .createOrReplaceTempView("accumulation1") + sql(s"SELECT approx_top_k_accumulate(expr) AS acc FROM VALUES $input AS tab(expr);") + .createOrReplaceTempView("accumulation2") + sql("SELECT approx_top_k_combine(acc, 30) as com " + + "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + .createOrReplaceTempView("combined") + val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") + // expected should be doubled because we combine two identical sketches + val expectedDoubled = expected.map { + case Row(value: Any, count: Int) => Row(value, count * 2) + } + checkAnswer(est, Row(expectedDoubled)) + } + test("SPARK-52798: same type, same size, specified combine size - success") { setupAccumulations(10, 10) @@ -437,7 +455,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-combine: invalid combine size - fail")(Seq((10, 10), (10, 20))) { + gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10, 20))) { case (size1, size2) => setupAccumulations(size1, size2) checkError( @@ -451,7 +469,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - test("SPARK-combine: among different number or datetime types - fail at combine") { + test("SPARK-52798: among different number or datetime types - fail at combine") { def checkMixedTypeError(mixedTypeSeq: Seq[(String, String, Seq[Any])]): Unit = { for (i <- 0 until mixedTypeSeq.size - 1) { for (j <- i + 1 until mixedTypeSeq.size) { @@ -474,7 +492,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { } // enumerate all combinations of number and datetime types - gridTest("SPARK-combine: number vs datetime - fail on UNION")( + gridTest("SPARK-52798: number vs datetime - fail on UNION")( for { (type1, typeName1, seq1) <- mixedNumberTypeSeqs (type2, typeName2, seq2) <- mixedDateTimeSeqs @@ -501,7 +519,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-combine: number vs string - fail at combine")(mixedNumberTypeSeqs) { + gridTest("SPARK-52798: number vs string - fail at combine")(mixedNumberTypeSeqs) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( @@ -513,7 +531,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-combine: number vs boolean - fail at 
UNION")(mixedNumberTypeSeqs) { + gridTest("SPARK-52798: number vs boolean - fail at UNION")(mixedNumberTypeSeqs) { case (_, type1, seq1) => val seq2 = Seq("(true)", "(true)", "(false)", "(false)") checkError( @@ -537,7 +555,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-combine: datetime vs string - fail at combine")(mixedDateTimeSeqs) { + gridTest("SPARK-52798: datetime vs string - fail at combine")(mixedDateTimeSeqs) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( @@ -549,7 +567,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-combine: datetime vs boolean - fail at UNION")(mixedDateTimeSeqs) { + gridTest("SPARK-52798: datetime vs boolean - fail at UNION")(mixedDateTimeSeqs) { case (_, type1, seq1) => val seq2 = Seq("(true)", "(true)", "(false)", "(false)") checkError( @@ -573,7 +591,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - test("SPARK-combine: string vs boolean - fail at combine") { + test("SPARK-52798: string vs boolean - fail at combine") { val seq1 = Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'") val seq2 = Seq("(true)", "(true)", "(false)", "(false)") setupMixedTypeAccumulation(seq1, seq2) From 04cc87bfbf73a2ba90dc0ff9b181af2e83be4bb0 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Tue, 15 Jul 2025 17:59:06 -0700 Subject: [PATCH 08/17] expression tests --- .../aggregate/ApproxTopKSuite.scala | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala index e52583e7d0b4b..ffa0dabca9a97 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKSuite.scala @@ -270,5 +270,139 @@ class ApproxTopKSuite extends SparkFunSuite { TypeCheckFailure("State struct must have the fourth field to be binary. 
Got: string")) } + ///////////////////////////// + // ApproxTopKCombine tests + ///////////////////////////// + test("SPARK-52798: invalid combine if maxItemsTracked is not foldable") { + val badCombine = ApproxTopKCombine( + state = BoundReference(0, StructType(Seq( + StructField("sketch", BinaryType), + StructField("itemDataType", IntegerType), + StructField("maxItemsTracked", IntegerType), + StructField("typeCode", BinaryType) + )), nullable = false), + maxItemsTracked = Sum(BoundReference(1, IntegerType, nullable = true)) + ) + assert(badCombine.checkInputDataTypes().isFailure) + assert(badCombine.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(1), + "requiredType" -> "\"INT\"", + "inputSql" -> "\"sum(boundreference())\"", + "inputType" -> "\"BIGINT\"" + ) + ) + ) + } + + test("SPARK-52798: invalid combine if state is not a struct") { + val badCombine = ApproxTopKCombine( + state = BoundReference(0, IntegerType, nullable = false), + maxItemsTracked = Literal(10) + ) + assert(badCombine.checkInputDataTypes().isFailure) + assert(badCombine.checkInputDataTypes() == + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(0), + "requiredType" -> "\"STRUCT\"", + "inputSql" -> "\"boundreference()\"", + "inputType" -> "\"INT\"" + ) + ) + ) + } + + test("SPARK-52798: invalid combine if state struct length is not 4") { + val invalidState = StructType(Seq( + StructField("sketch", BinaryType), + StructField("itemDataType", IntegerType) + // Missing "maxItemsTracked", "typeCode" fields + )) + val badCombine = ApproxTopKCombine( + state = BoundReference(0, invalidState, nullable = false), + maxItemsTracked = Literal(10) + ) + assert(badCombine.checkInputDataTypes().isFailure) + assert(badCombine.checkInputDataTypes() == + TypeCheckFailure("State must be a struct with 4 fields. " + + "Expected struct: " + + "struct. " + + "Got: struct")) + } + + test("SPARK-52798: invalid combine if state struct's first field is not binary") { + val invalidState = StructType(Seq( + StructField("sketch", IntegerType), // Should be BinaryType + StructField("itemDataType", IntegerType), + StructField("maxItemsTracked", IntegerType), + StructField("typeCode", BinaryType) + )) + val badCombine = ApproxTopKCombine( + state = BoundReference(0, invalidState, nullable = false), + maxItemsTracked = Literal(10) + ) + assert(badCombine.checkInputDataTypes().isFailure) + assert(badCombine.checkInputDataTypes() == + TypeCheckFailure("State struct must have the first field to be binary. Got: int")) + } + + gridTest("SPARK-52798: invalid combine if state struct's second field is not supported")( + Seq( + ("array", ArrayType(IntegerType)), + ("map", MapType(StringType, IntegerType)), + ("struct", StructType(Seq(StructField("a", IntegerType)))), + ("binary", BinaryType) + )) { unSupportedType => + val (typeName, dataType) = unSupportedType + val invalidState = StructType(Seq( + StructField("sketch", BinaryType), + StructField("itemDataType", dataType), + StructField("maxItemsTracked", IntegerType), + StructField("typeCode", BinaryType) + )) + val badCombine = ApproxTopKCombine( + state = BoundReference(0, invalidState, nullable = false), + maxItemsTracked = Literal(10) + ) + assert(badCombine.checkInputDataTypes().isFailure) + assert(badCombine.checkInputDataTypes() == + TypeCheckFailure(s"State struct must have the second field to be a supported data type. 
" + + s"Got: $typeName")) + } + + test("SPARK-52798: invalid combine if state struct's third field is not int") { + val invalidState = StructType(Seq( + StructField("sketch", BinaryType), + StructField("itemDataType", IntegerType), + StructField("maxItemsTracked", LongType), // Should be IntegerType + StructField("typeCode", BinaryType) + )) + val badCombine = ApproxTopKCombine( + state = BoundReference(0, invalidState, nullable = false), + maxItemsTracked = Literal(10) + ) + assert(badCombine.checkInputDataTypes().isFailure) + assert(badCombine.checkInputDataTypes() == + TypeCheckFailure("State struct must have the third field to be int. Got: bigint")) + } + test("SPARK-52798: invalid combine if state struct's fourth field is not binary") { + val invalidState = StructType(Seq( + StructField("sketch", BinaryType), + StructField("itemDataType", IntegerType), + StructField("maxItemsTracked", IntegerType), + StructField("typeCode", StringType) // Should be BinaryType + )) + val badCombine = ApproxTopKCombine( + state = BoundReference(0, invalidState, nullable = false), + maxItemsTracked = Literal(10) + ) + assert(badCombine.checkInputDataTypes().isFailure) + assert(badCombine.checkInputDataTypes() == + TypeCheckFailure("State struct must have the fourth field to be binary. Got: string")) + } } From a9fd60c47e37419c66bd545ebc6fe725bb01fd3e Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Tue, 15 Jul 2025 18:09:09 -0700 Subject: [PATCH 09/17] add doc --- .../aggregate/ApproxTopKAggregates.scala | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 5480ad2fcf2fe..5a4528020b48d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -534,6 +534,28 @@ class CombineInternal[T]( def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked } +/** + * An aggregate function that combines multiple sketches into a single sketch. + * + * @param state the expression containing the sketches to combine + * @param maxItemsTracked the maximum number of items to track in the sketch + * @param mutableAggBufferOffset the offset for mutable aggregation buffer + * @param inputAggBufferOffset the offset for input aggregation buffer + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single sketch. + `maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000. 
+    """,
+  examples = """
+    Examples:
+      > SELECT approx_top_k_estimate(_FUNC_(sketch)) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr));
+       [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+    """,
+  group = "agg_funcs",
+  since = "4.1.0")
+// scalastyle:on line.size.limit
 case class ApproxTopKCombine(
     state: Expression,
     maxItemsTracked: Expression,

From 12524a7a55e34fa57eedc417039ab14e51c186ab Mon Sep 17 00:00:00 2001
From: yhuang-db
Date: Wed, 16 Jul 2025 10:01:22 -0700
Subject: [PATCH 10/17] fix error and schema tests

---
 .../resources/error/error-conditions.json    | 12 +++---
 .../aggregate/ApproxTopKAggregates.scala     | 41 ++++++++-----------
 .../sql/errors/QueryExecutionErrors.scala    | 16 ++++++++
 .../sql-functions/sql-expression-schema.md   |  1 +
 4 files changed, 40 insertions(+), 30 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 3dc00d5cfb407..7135abe57df5e 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -114,17 +114,17 @@
     ],
     "sqlState" : "22004"
   },
-  "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" : {
+  "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED" : {
     "message" : [
-      "Combining approx_top_k sketches of different types is not allowed. Found sketches of type <type1> and <type2>."
+      "Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size <size1> and <size2>."
     ],
-    "sqlState": "42846"
+    "sqlState" : "42846"
   },
-  "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED" : {
+  "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" : {
     "message" : [
-      "Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size <size1> and <size2>."
+      "Combining approx_top_k sketches of different types is not allowed. Found sketches of type <type1> and <type2>."
], - "sqlState": "42846" + "sqlState" : "42846" }, "ARITHMETIC_OVERFLOW" : { "message" : [ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 5a4528020b48d..08e5db4e093a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -21,7 +21,6 @@ import org.apache.datasketches.common._ import org.apache.datasketches.frequencies.{ErrorType, ItemsSketch} import org.apache.datasketches.memory.Memory -import org.apache.spark.SparkUnsupportedOperationException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} @@ -521,11 +520,10 @@ class CombineInternal[T]( if (this.itemDataType == null) { this.itemDataType = dataType } else if (this.itemDataType != dataType) { - throw new SparkUnsupportedOperationException( - errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", - messageParameters = Map( - "type1" -> this.itemDataType.typeName, - "type2" -> dataType.typeName)) + throw QueryExecutionErrors.approxTopKSketchTypeUnmatched( + this.itemDataType.typeName, + dataType.typeName + ) } } @@ -550,7 +548,7 @@ class CombineInternal[T]( """, examples = """ Examples: - > SELECT approx_top_k_estimate(_FUNC_(sketch)) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr)); + > SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr)); [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}] """, group = "agg_funcs", @@ -645,21 +643,19 @@ case class ApproxTopKCombine( buffer.setMaxItemsTracked(input.getMaxItemsTracked) } if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) { - throw new SparkUnsupportedOperationException( - errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", - messageParameters = Map( - "size1" -> buffer.getMaxItemsTracked.toString, - "size2" -> input.getMaxItemsTracked.toString)) + throw QueryExecutionErrors.approxTopKSketchSizeUnmatched( + buffer.getMaxItemsTracked, + input.getMaxItemsTracked + ) } } // check item data type if (buffer.getItemDataType != null && input.getItemDataType != null && buffer.getItemDataType != input.getItemDataType) { - throw new SparkUnsupportedOperationException( - errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", - messageParameters = Map( - "type1" -> buffer.getItemDataType.typeName, - "type2" -> input.getItemDataType.typeName)) + throw QueryExecutionErrors.approxTopKSketchTypeUnmatched( + buffer.getItemDataType.typeName, + input.getItemDataType.typeName + ) } else if (buffer.getItemDataType == null) { // If buffer is a placeholder sketch, set it to the input sketch's item data type buffer.setItemDataType(input.getItemDataType) @@ -669,14 +665,8 @@ case class ApproxTopKCombine( } override def eval(buffer: CombineInternal[Any]): Any = { - val 
sketchBytes = try {
+    val sketchBytes =
       buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
-    } catch {
-      case _: ArrayStoreException =>
-        throw new SparkUnsupportedOperationException(
-          errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED"
-        )
-    }
     val maxItemsTracked = buffer.getMaxItemsTracked
     val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
     InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
@@ -716,4 +706,7 @@ case class ApproxTopKCombine(
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
   override def nullable: Boolean = false
+
+  override def prettyName: String =
+    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_combine")
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 680d7b4b172db..022dd6abb1a05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2814,6 +2814,22 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       "limit" -> toSQLValue(limit, IntegerType)))
   }
 
+  def approxTopKSketchSizeUnmatched(size1: Int, size2: Int): Throwable = {
+    new SparkRuntimeException(
+      errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED",
+      messageParameters = Map(
+        "leftSize" -> toSQLValue(size1, IntegerType),
+        "rightSize" -> toSQLValue(size2, IntegerType)))
+  }
+
+  def approxTopKSketchTypeUnmatched(leftType: String, rightType: String): Throwable = {
+    new SparkRuntimeException(
+      errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED",
+      messageParameters = Map(
+        "leftType" -> toSQLType(leftType),
+        "rightType" -> toSQLType(rightType)))
+  }
+
   def mergeCardinalityViolationError(): SparkRuntimeException = {
     new SparkRuntimeException(
       errorClass = "MERGE_CARDINALITY_VIOLATION",
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 3b0b21b9cd776..1812d98432322 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -403,6 +403,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue | any_value | SELECT any_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<any_value(col):int> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopK | approx_top_k | SELECT approx_top_k(expr) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr) | struct<approx_top_k(expr, 5, 10000):array<struct<item:int,count:bigint>>> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopKAccumulate | approx_top_k_accumulate | SELECT approx_top_k_estimate(approx_top_k_accumulate(expr)) FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) AS tab(expr) | struct<approx_top_k_estimate(approx_top_k_accumulate(expr, 10000), 5):array<struct<item:int,count:bigint>>> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.ApproxTopKCombine | approx_top_k_combine | SELECT approx_top_k_estimate(approx_top_k_combine(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr)) | struct<approx_top_k_estimate(approx_top_k_combine(sketch, 10000), 5):array<struct<item:int,count:bigint>>> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile | approx_percentile | SELECT approx_percentile(col, array(0.5, 0.4, 0.1), 100) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct<approx_percentile(col, array(0.5, 0.4, 0.1), 100):array<int>> |
 | 
org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile | percentile_approx | SELECT percentile_approx(col, array(0.5, 0.4, 0.1), 100) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.Average | avg | SELECT avg(col) FROM VALUES (1), (2), (3) AS tab(col) | struct | From 8d50b3e887f9f126a9abccdf6bcf2eb62a066693 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 16 Jul 2025 10:42:58 -0700 Subject: [PATCH 11/17] update errors --- .../sql/errors/QueryExecutionErrors.scala | 53 ++++++++++--------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 022dd6abb1a05..d303d113123cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -152,13 +152,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE s: UTF8String, fmt: UTF8String, hint: String): SparkIllegalArgumentException = { - new SparkIllegalArgumentException( - errorClass = "CONVERSION_INVALID_INPUT", - messageParameters = Map( - "str" -> toSQLValue(s, StringType), - "fmt" -> toSQLValue(fmt, StringType), - "targetType" -> toSQLType(to), - "suggestion" -> toSQLId(hint))) + new SparkIllegalArgumentException( + errorClass = "CONVERSION_INVALID_INPUT", + messageParameters = Map( + "str" -> toSQLValue(s, StringType), + "fmt" -> toSQLValue(fmt, StringType), + "targetType" -> toSQLType(to), + "suggestion" -> toSQLId(hint))) } def cannotCastFromNullTypeError(to: DataType): Throwable = { @@ -249,7 +249,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE } def invalidBitmapPositionError(bitPosition: Long, - bitmapNumBytes: Long): ArrayIndexOutOfBoundsException = { + bitmapNumBytes: Long): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_BITMAP_POSITION", messageParameters = Map( @@ -346,7 +346,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE errorClass = "UNSUPPORTED_FEATURE.LITERAL_TYPE", messageParameters = Map( "value" -> v.toString, - "type" -> v.getClass.toString)) + "type" -> v.getClass.toString)) } def pivotColumnUnsupportedError(v: Any, expr: Expression): RuntimeException = { @@ -668,7 +668,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE } def withoutSuggestionIntervalArithmeticOverflowError( - context: QueryContext): SparkArithmeticException = { + context: QueryContext): SparkArithmeticException = { new SparkArithmeticException( errorClass = "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION", messageParameters = Map(), @@ -728,7 +728,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "windowFunc" -> windowFuncList.map(toSQLStmt(_)).mkString(","), "columnName" -> columnNameList.map(toSQLId(_)).mkString(","), "windowSpec" -> windowSpecList.map(toSQLStmt(_)).mkString(",")), - origin = origin) + origin = origin) } def multiplePathsSpecifiedError(allPaths: Seq[String]): SparkIllegalArgumentException = { @@ -1421,7 +1421,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE } def arrayFunctionWithElementsExceedLimitError( - prettyName: String, numberOfElements: Long): SparkRuntimeException = { + prettyName: 
String, numberOfElements: Long): SparkRuntimeException = { new SparkRuntimeException( errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.FUNCTION", messageParameters = Map( @@ -1432,7 +1432,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE } def createArrayWithElementsExceedLimitError( - prettyName: String, count: Any): SparkRuntimeException = { + prettyName: String, count: Any): SparkRuntimeException = { new SparkRuntimeException( errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER", messageParameters = Map( @@ -1547,7 +1547,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE errorClass = "_LEGACY_ERROR_TEMP_2176", messageParameters = Map( "numElements" -> numElements.toString(), - "maxRoundedArrayLength"-> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), + "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(), "additionalErrorMessage" -> additionalErrorMessage)) } @@ -1742,7 +1742,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE } def stateNotDefinedOrAlreadyRemovedError(): Throwable = { - new NoSuchElementException("State is either not defined or has already been removed") + new NoSuchElementException("State is either not defined or has already been removed") } def statefulOperatorNotMatchInStateMetadataError( @@ -1750,6 +1750,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE opsInCurBatchSeq: Map[Long, String]): SparkRuntimeException = { def formatPairString(pair: (Long, String)): String = s"(OperatorId: ${pair._1} -> OperatorName: ${pair._2})" + new SparkRuntimeException( errorClass = s"STREAMING_STATEFUL_OPERATOR_NOT_MATCH_IN_STATE_METADATA", messageParameters = Map( @@ -2645,7 +2646,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE acquiredThreadInfo: String, timeWaitedMs: Long, stackTraceOutput: String): Throwable = { - new SparkException ( + new SparkException( errorClass = "CANNOT_LOAD_STATE_STORE.UNRELEASED_THREAD_ERROR", messageParameters = Map( "loggingId" -> loggingId, @@ -2658,7 +2659,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE } def cannotReadCheckpoint(expectedVersion: String, actualVersion: String): Throwable = { - new SparkException ( + new SparkException( errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_READ_CHECKPOINT", messageParameters = Map( "expectedVersion" -> expectedVersion, @@ -2667,7 +2668,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE } def cannotFindBaseSnapshotCheckpoint(lineage: String): Throwable = { - new SparkException ( + new SparkException( errorClass = "CANNOT_LOAD_STATE_STORE.CANNOT_FIND_BASE_SNAPSHOT_CHECKPOINT", messageParameters = Map("lineage" -> lineage), @@ -2818,16 +2819,16 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE new SparkRuntimeException( errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", messageParameters = Map( - "leftSize" -> toSQLValue(size1, IntegerType), - "rightSize" -> toSQLValue(size2, IntegerType))) + "size1" -> toSQLValue(size1, IntegerType), + "size2" -> toSQLValue(size2, IntegerType))) } - def approxTopKSketchTypeUnmatched(leftType: String, rightType: String): Throwable = { + def approxTopKSketchTypeUnmatched(type1: String, type2: String): Throwable = { new SparkRuntimeException( errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", messageParameters = Map( - "leftType" -> toSQLType(leftType), - "rightType" -> 
toSQLType(rightType))) + "type1" -> toSQLType(type1), + "type2" -> toSQLType(type2))) } def mergeCardinalityViolationError(): SparkRuntimeException = { @@ -2856,7 +2857,11 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE val errorParmsMutable = collection.mutable.Map[String, String]() errorParms.foreach(StringType, StringType, { case (key, value) => errorParmsMutable += (key.toString -> - (if (value == null) { "null" } else { value.toString } )) + (if (value == null) { + "null" + } else { + value.toString + })) }) errorParmsMutable.toMap } else { From 66a67785dd257b326325ab87db0f48e16711b546 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 16 Jul 2025 10:43:53 -0700 Subject: [PATCH 12/17] change error class --- .../org/apache/spark/sql/ApproxTopKSuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala index 99e6ec26137e5..a3802d6ceced7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.time.LocalDateTime -import org.apache.spark.{SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{ByteType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType} @@ -447,7 +447,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") checkError( - exception = intercept[SparkUnsupportedOperationException] { + exception = intercept[SparkRuntimeException] { comb.collect() }, condition = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", @@ -469,7 +469,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - test("SPARK-52798: among different number or datetime types - fail at combine") { + test("SPARK-52798T: among different number or datetime types - fail at combine") { def checkMixedTypeError(mixedTypeSeq: Seq[(String, String, Seq[Any])]): Unit = { for (i <- 0 until mixedTypeSeq.size - 1) { for (j <- i + 1 until mixedTypeSeq.size) { @@ -477,7 +477,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { val (type2, _, seq2) = mixedTypeSeq(j) setupMixedTypeAccumulation(seq1, seq2) checkError( - exception = intercept[SparkUnsupportedOperationException] { + exception = intercept[SparkRuntimeException] { sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() }, condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", @@ -519,11 +519,11 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-52798: number vs string - fail at combine")(mixedNumberTypeSeqs) { + gridTest("SPARK-52798T: number vs string - fail at combine")(mixedNumberTypeSeqs) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( - exception = intercept[SparkUnsupportedOperationException] { + exception = intercept[SparkRuntimeException] { sql("SELECT approx_top_k_combine(acc, 30) as com FROM 
unioned;").collect() }, condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", @@ -555,11 +555,11 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-52798: datetime vs string - fail at combine")(mixedDateTimeSeqs) { + gridTest("SPARK-52798T: datetime vs string - fail at combine")(mixedDateTimeSeqs) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( - exception = intercept[SparkUnsupportedOperationException] { + exception = intercept[SparkRuntimeException] { sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() }, condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", From 0c22de1641201fc2cbaf56022afb5fcdf8da4794 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 16 Jul 2025 10:53:14 -0700 Subject: [PATCH 13/17] fix type tests --- .../apache/spark/sql/ApproxTopKSuite.scala | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala index a3802d6ceced7..808455c5b90fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala @@ -22,8 +22,9 @@ import java.time.LocalDateTime import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{ByteType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType} class ApproxTopKSuite extends QueryTest with SharedSparkSession { @@ -469,7 +470,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - test("SPARK-52798T: among different number or datetime types - fail at combine") { + test("SPARK-52798: among different number or datetime types - fail at combine") { def checkMixedTypeError(mixedTypeSeq: Seq[(String, String, Seq[Any])]): Unit = { for (i <- 0 until mixedTypeSeq.size - 1) { for (j <- i + 1 until mixedTypeSeq.size) { @@ -481,7 +482,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() }, condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", - parameters = Map("type1" -> type1, "type2" -> type2) + parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(type2)) ) } } @@ -519,7 +520,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-52798T: number vs string - fail at combine")(mixedNumberTypeSeqs) { + gridTest("SPARK-52798: number vs string - fail at combine")(mixedNumberTypeSeqs) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( @@ -527,7 +528,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() }, condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", - parameters = Map("type1" -> type1, "type2" -> StringType.typeName) + parameters = Map("type1" -> toSQLType(type1), 
"type2" -> toSQLType(StringType)) ) } @@ -555,7 +556,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-52798T: datetime vs string - fail at combine")(mixedDateTimeSeqs) { + gridTest("SPARK-52798: datetime vs string - fail at combine")(mixedDateTimeSeqs) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( @@ -563,7 +564,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() }, condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", - parameters = Map("type1" -> type1, "type2" -> StringType.typeName) + parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(StringType)) ) } @@ -595,5 +596,12 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { val seq1 = Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'") val seq2 = Seq("(true)", "(true)", "(false)", "(false)") setupMixedTypeAccumulation(seq1, seq2) + checkError( + exception = intercept[SparkRuntimeException] { + sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() + }, + condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + parameters = Map("type1" -> toSQLType(StringType), "type2" -> toSQLType(BooleanType)) + ) } } From 7f547ba1aed796b84f5f5693ab29acbcb0e32f8b Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 16 Jul 2025 11:07:28 -0700 Subject: [PATCH 14/17] unify type error output --- .../aggregate/ApproxTopKAggregates.scala | 9 ++---- .../sql/errors/QueryExecutionErrors.scala | 2 +- .../apache/spark/sql/ApproxTopKSuite.scala | 32 +++++++++---------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 08e5db4e093a3..532373705fed9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -520,10 +520,7 @@ class CombineInternal[T]( if (this.itemDataType == null) { this.itemDataType = dataType } else if (this.itemDataType != dataType) { - throw QueryExecutionErrors.approxTopKSketchTypeUnmatched( - this.itemDataType.typeName, - dataType.typeName - ) + throw QueryExecutionErrors.approxTopKSketchTypeUnmatched(this.itemDataType, dataType) } } @@ -653,8 +650,8 @@ case class ApproxTopKCombine( if (buffer.getItemDataType != null && input.getItemDataType != null && buffer.getItemDataType != input.getItemDataType) { throw QueryExecutionErrors.approxTopKSketchTypeUnmatched( - buffer.getItemDataType.typeName, - input.getItemDataType.typeName + buffer.getItemDataType, + input.getItemDataType ) } else if (buffer.getItemDataType == null) { // If buffer is a placeholder sketch, set it to the input sketch's item data type diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index d303d113123cf..74ccfc907e495 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2823,7 +2823,7 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase with ExecutionE "size2" -> toSQLValue(size2, IntegerType))) } - def approxTopKSketchTypeUnmatched(type1: String, type2: String): Throwable = { + def approxTopKSketchTypeUnmatched(type1: DataType, type2: DataType): Throwable = { new SparkRuntimeException( errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", messageParameters = Map( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala index 808455c5b90fb..aedf73b009e33 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.{SparkArithmeticException, SparkRuntimeException} import org.apache.spark.sql.catalyst.ExtendedAnalysisException import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType} class ApproxTopKSuite extends QueryTest with SharedSparkSession { @@ -359,33 +359,33 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { .createOrReplaceTempView("unioned") } - val mixedNumberTypeSeqs: Seq[(String, String, Seq[Any])] = Seq( - (IntegerType.typeName, "INT", + val mixedNumberTypeSeqs: Seq[(DataType, String, Seq[Any])] = Seq( + (IntegerType, "INT", Seq(0, 0, 0, 1, 1, 2, 2, 3)), - (ByteType.typeName, "TINYINT", + (ByteType, "TINYINT", Seq("cast(0 AS BYTE)", "cast(0 AS BYTE)", "cast(1 AS BYTE)")), - (ShortType.typeName, "SMALLINT", + (ShortType, "SMALLINT", Seq("cast(0 AS SHORT)", "cast(0 AS SHORT)", "cast(1 AS SHORT)")), - (LongType.typeName, "BIGINT", + (LongType, "BIGINT", Seq("cast(0 AS LONG)", "cast(0 AS LONG)", "cast(1 AS LONG)")), - (FloatType.typeName, "FLOAT", + (FloatType, "FLOAT", Seq("cast(0 AS FLOAT)", "cast(0 AS FLOAT)", "cast(1 AS FLOAT)")), - (DoubleType.typeName, "DOUBLE", + (DoubleType, "DOUBLE", Seq("cast(0 AS DOUBLE)", "cast(0 AS DOUBLE)", "cast(1 AS DOUBLE)")), - (DecimalType(4, 2).typeName, "DECIMAL(4,2)", + (DecimalType(4, 2), "DECIMAL(4,2)", Seq("cast(0 AS DECIMAL(4, 2))", "cast(0 AS DECIMAL(4, 2))", "cast(1 AS DECIMAL(4, 2))")), - (DecimalType(10, 2).typeName, "DECIMAL(10,2)", + (DecimalType(10, 2), "DECIMAL(10,2)", Seq("cast(0 AS DECIMAL(10, 2))", "cast(0 AS DECIMAL(10, 2))", "cast(1 AS DECIMAL(10, 2))")), - (DecimalType(20, 3).typeName, "DECIMAL(20,3)", + (DecimalType(20, 3), "DECIMAL(20,3)", Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", "cast(1 AS DECIMAL(20, 3))")) ) - val mixedDateTimeSeqs: Seq[(String, String, Seq[String])] = Seq( - (DateType.typeName, "DATE", + val mixedDateTimeSeqs: Seq[(DataType, String, Seq[String])] = Seq( + (DateType, "DATE", Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")), - (TimestampType.typeName, "TIMESTAMP", + (TimestampType, "TIMESTAMP", Seq("TIMESTAMP'2025-01-01 00:00:00'", "TIMESTAMP'2025-01-01 00:00:00'")), - (TimestampNTZType.typeName, "TIMESTAMP_NTZ", + (TimestampNTZType, "TIMESTAMP_NTZ", Seq("TIMESTAMP_NTZ'2025-01-01 00:00:00'", "TIMESTAMP_NTZ'2025-01-01 00:00:00'") ) ) @@ -471,7 +471,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { } 
test("SPARK-52798: among different number or datetime types - fail at combine") { - def checkMixedTypeError(mixedTypeSeq: Seq[(String, String, Seq[Any])]): Unit = { + def checkMixedTypeError(mixedTypeSeq: Seq[(DataType, String, Seq[Any])]): Unit = { for (i <- 0 until mixedTypeSeq.size - 1) { for (j <- i + 1 until mixedTypeSeq.size) { val (type1, _, seq1) = mixedTypeSeq(i) From a32d8c7bd9e64b4c3dd18ceaa3854db044e05a54 Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 16 Jul 2025 11:48:18 -0700 Subject: [PATCH 15/17] rename error class --- .../src/main/resources/error/error-conditions.json | 4 ++-- .../expressions/aggregate/ApproxTopKAggregates.scala | 6 +++--- .../apache/spark/sql/errors/QueryExecutionErrors.scala | 8 ++++---- .../scala/org/apache/spark/sql/ApproxTopKSuite.scala | 10 +++++----- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 7135abe57df5e..dce1befe58240 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -114,13 +114,13 @@ ], "sqlState" : "22004" }, - "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED" : { + "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH" : { "message" : [ "Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size and ." ], "sqlState" : "42846" }, - "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED" : { + "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH" : { "message" : [ "Combining approx_top_k sketches of different types is not allowed. Found sketches of type and ." ], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala index 532373705fed9..1989c190aa609 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxTopKAggregates.scala @@ -520,7 +520,7 @@ class CombineInternal[T]( if (this.itemDataType == null) { this.itemDataType = dataType } else if (this.itemDataType != dataType) { - throw QueryExecutionErrors.approxTopKSketchTypeUnmatched(this.itemDataType, dataType) + throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(this.itemDataType, dataType) } } @@ -640,7 +640,7 @@ case class ApproxTopKCombine( buffer.setMaxItemsTracked(input.getMaxItemsTracked) } if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) { - throw QueryExecutionErrors.approxTopKSketchSizeUnmatched( + throw QueryExecutionErrors.approxTopKSketchSizeNotMatch( buffer.getMaxItemsTracked, input.getMaxItemsTracked ) @@ -649,7 +649,7 @@ case class ApproxTopKCombine( // check item data type if (buffer.getItemDataType != null && input.getItemDataType != null && buffer.getItemDataType != input.getItemDataType) { - throw QueryExecutionErrors.approxTopKSketchTypeUnmatched( + throw QueryExecutionErrors.approxTopKSketchTypeNotMatch( buffer.getItemDataType, input.getItemDataType ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 74ccfc907e495..bdd252c245620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2815,17 +2815,17 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "limit" -> toSQLValue(limit, IntegerType))) } - def approxTopKSketchSizeUnmatched(size1: Int, size2: Int): Throwable = { + def approxTopKSketchSizeNotMatch(size1: Int, size2: Int): Throwable = { new SparkRuntimeException( - errorClass = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", + errorClass = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH", messageParameters = Map( "size1" -> toSQLValue(size1, IntegerType), "size2" -> toSQLValue(size2, IntegerType))) } - def approxTopKSketchTypeUnmatched(type1: DataType, type2: DataType): Throwable = { + def approxTopKSketchTypeNotMatch(type1: DataType, type2: DataType): Throwable = { new SparkRuntimeException( - errorClass = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + errorClass = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH", messageParameters = Map( "type1" -> toSQLType(type1), "type2" -> toSQLType(type2))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala index aedf73b009e33..5ebb762afe465 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala @@ -451,7 +451,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { comb.collect() }, - condition = "APPROX_TOP_K_SKETCH_SIZE_UNMATCHED", + condition = "APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH", parameters = Map("size1" -> "10", "size2" -> "20") ) } @@ -481,7 +481,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH", parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(type2)) ) } @@ -527,7 +527,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH", parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(StringType)) ) } @@ -563,7 +563,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH", parameters = Map("type1" -> toSQLType(type1), "type2" -> toSQLType(StringType)) ) } @@ -600,7 +600,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { exception = intercept[SparkRuntimeException] { sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned;").collect() }, - condition = "APPROX_TOP_K_SKETCH_TYPE_UNMATCHED", + condition = "APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH", parameters = Map("type1" -> toSQLType(StringType), "type2" -> toSQLType(BooleanType)) ) } From abe4a99c4d732b4c3061bd287422348c56ad445d Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 16 Jul 2025 11:53:16 -0700 Subject: [PATCH 16/17] update setup --- .../apache/spark/sql/ApproxTopKSuite.scala | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala index 5ebb762afe465..58c50dbd7ed00 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala @@ -344,6 +344,9 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { sql(s"SELECT approx_top_k_accumulate(expr, $size2) as acc " + "FROM VALUES (1), (1), (2), (2), (3), (3), (4), (4) AS tab(expr);") .createOrReplaceTempView("accumulation2") + + sql("SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2") + .createOrReplaceTempView("unioned") } def setupMixedTypeAccumulation(seq1: Seq[Any], seq2: Seq[Any]): Unit = { @@ -411,8 +414,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { test("SPARK-52798: same type, same size, specified combine size - success") { setupAccumulations(10, 10) - sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned") .createOrReplaceTempView("combined") val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") @@ -422,8 +424,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { test("SPARK-52798: same type, same size, unspecified combine size - success") { setupAccumulations(10, 10) - sql("SELECT approx_top_k_combine(acc) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + sql("SELECT approx_top_k_combine(acc) as com FROM unioned") .createOrReplaceTempView("combined") val est = sql("SELECT approx_top_k_estimate(com) FROM combined;") @@ -433,8 +434,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { test("SPARK-52798: same type, different size, specified combine size - success") { setupAccumulations(10, 20) - sql("SELECT approx_top_k_combine(acc, 30) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned") .createOrReplaceTempView("combination") val est = sql("SELECT approx_top_k_estimate(com) FROM combination;") @@ -444,8 +444,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { test("SPARK-52798: same type, different size, unspecified combine size - fail") { setupAccumulations(10, 20) - val comb = sql("SELECT approx_top_k_combine(acc) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") + val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned") checkError( exception = intercept[SparkRuntimeException] { @@ -461,9 +460,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { setupAccumulations(size1, size2) checkError( exception = intercept[SparkRuntimeException] { - sql("SELECT approx_top_k_combine(acc, 0) as com " + - "FROM (SELECT acc from accumulation1 UNION ALL SELECT acc FROM accumulation2);") - .collect() + sql("SELECT approx_top_k_combine(acc, 0) as com FROM unioned").collect() }, condition = "APPROX_TOP_K_NON_POSITIVE_ARG", parameters = Map("argName" -> "`maxItemsTracked`", "argValue" -> "0") From 6cd62225251be7dd9722f3e6364881a687cf225a Mon Sep 17 00:00:00 2001 From: yhuang-db Date: Wed, 16 Jul 2025 11:55:53 -0700 Subject: [PATCH 17/17] refactor names --- .../apache/spark/sql/ApproxTopKSuite.scala | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) 
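
A quick recap before the rename-only diff below: the workflow these suites exercise is
accumulate per source, UNION ALL the sketch states, then combine and estimate. The SQL
below is a minimal sketch of that flow assembled from the suite's own fixtures (the
`unioned` temp view and the literal values mirror the test setup; they are illustrative
and not part of this diff):

    -- One sketch per source, each tracking at most 10 items.
    SELECT approx_top_k_accumulate(expr, 10) AS acc
    FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);

    -- After UNION ALL-ing the accumulations into `unioned`, merge the sketches
    -- under a combined budget of 30 and read off the top 5 items.
    SELECT approx_top_k_estimate(approx_top_k_combine(acc, 30), 5) FROM unioned;
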
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala index 58c50dbd7ed00..040254ef038cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxTopKSuite.scala @@ -336,7 +336,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { // approx_top_k_combine ///////////////////////////////// - def setupAccumulations(size1: Int, size2: Int): Unit = { + def setupMixedSizeAccumulations(size1: Int, size2: Int): Unit = { sql(s"SELECT approx_top_k_accumulate(expr, $size1) as acc " + "FROM VALUES (0), (0), (0), (1), (1), (2), (2), (3) AS tab(expr);") .createOrReplaceTempView("accumulation1") @@ -362,7 +362,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { .createOrReplaceTempView("unioned") } - val mixedNumberTypeSeqs: Seq[(DataType, String, Seq[Any])] = Seq( + val mixedNumberTypes: Seq[(DataType, String, Seq[Any])] = Seq( (IntegerType, "INT", Seq(0, 0, 0, 1, 1, 2, 2, 3)), (ByteType, "TINYINT", @@ -383,7 +383,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { Seq("cast(0 AS DECIMAL(20, 3))", "cast(0 AS DECIMAL(20, 3))", "cast(1 AS DECIMAL(20, 3))")) ) - val mixedDateTimeSeqs: Seq[(DataType, String, Seq[String])] = Seq( + val mixedDateTimeTypes: Seq[(DataType, String, Seq[String])] = Seq( (DateType, "DATE", Seq("DATE'2025-01-01'", "DATE'2025-01-01'", "DATE'2025-01-02'")), (TimestampType, "TIMESTAMP", @@ -412,7 +412,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { } test("SPARK-52798: same type, same size, specified combine size - success") { - setupAccumulations(10, 10) + setupMixedSizeAccumulations(10, 10) sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned") .createOrReplaceTempView("combined") @@ -422,7 +422,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { } test("SPARK-52798: same type, same size, unspecified combine size - success") { - setupAccumulations(10, 10) + setupMixedSizeAccumulations(10, 10) sql("SELECT approx_top_k_combine(acc) as com FROM unioned") .createOrReplaceTempView("combined") @@ -432,7 +432,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { } test("SPARK-52798: same type, different size, specified combine size - success") { - setupAccumulations(10, 20) + setupMixedSizeAccumulations(10, 20) sql("SELECT approx_top_k_combine(acc, 30) as com FROM unioned") .createOrReplaceTempView("combination") @@ -442,7 +442,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { } test("SPARK-52798: same type, different size, unspecified combine size - fail") { - setupAccumulations(10, 20) + setupMixedSizeAccumulations(10, 20) val comb = sql("SELECT approx_top_k_combine(acc) as com FROM unioned") @@ -457,7 +457,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { gridTest("SPARK-52798: invalid combine size - fail")(Seq((10, 10), (10, 20))) { case (size1, size2) => - setupAccumulations(size1, size2) + setupMixedSizeAccumulations(size1, size2) checkError( exception = intercept[SparkRuntimeException] { sql("SELECT approx_top_k_combine(acc, 0) as com FROM unioned").collect() @@ -485,15 +485,15 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { } } - checkMixedTypeError(mixedNumberTypeSeqs) - checkMixedTypeError(mixedDateTimeSeqs) + checkMixedTypeError(mixedNumberTypes) + checkMixedTypeError(mixedDateTimeTypes) } // enumerate all combinations of 
number and datetime types gridTest("SPARK-52798: number vs datetime - fail on UNION")( for { - (type1, typeName1, seq1) <- mixedNumberTypeSeqs - (type2, typeName2, seq2) <- mixedDateTimeSeqs + (type1, typeName1, seq1) <- mixedNumberTypes + (type2, typeName2, seq2) <- mixedDateTimeTypes } yield ((type1, typeName1, seq1), (type2, typeName2, seq2))) { case ((_, type1, seq1), (_, type2, seq2)) => checkError( @@ -517,7 +517,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-52798: number vs string - fail at combine")(mixedNumberTypeSeqs) { + gridTest("SPARK-52798: number vs string - fail at combine")(mixedNumberTypes) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( @@ -529,7 +529,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-52798: number vs boolean - fail at UNION")(mixedNumberTypeSeqs) { + gridTest("SPARK-52798: number vs boolean - fail at UNION")(mixedNumberTypes) { case (_, type1, seq1) => val seq2 = Seq("(true)", "(true)", "(false)", "(false)") checkError( @@ -553,7 +553,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-52798: datetime vs string - fail at combine")(mixedDateTimeSeqs) { + gridTest("SPARK-52798: datetime vs string - fail at combine")(mixedDateTimeTypes) { case (type1, _, seq1) => setupMixedTypeAccumulation(seq1, Seq("'a'", "'b'", "'c'", "'c'", "'c'", "'c'", "'d'", "'d'")) checkError( @@ -565,7 +565,7 @@ class ApproxTopKSuite extends QueryTest with SharedSparkSession { ) } - gridTest("SPARK-52798: datetime vs boolean - fail at UNION")(mixedDateTimeSeqs) { + gridTest("SPARK-52798: datetime vs boolean - fail at UNION")(mixedDateTimeTypes) { case (_, type1, seq1) => val seq2 = Seq("(true)", "(true)", "(false)", "(false)") checkError(
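
For reference, the type-mismatch path that these grid tests pin down can be reproduced
directly in SQL. A minimal sketch built from the suite's fixtures (the values are
illustrative; the quoted type names in the error parameters come from toSQLType, exactly
as the tests assert):

    -- A STRING accumulation and a BOOLEAN accumulation still UNION ALL cleanly,
    -- so the mismatch only surfaces at runtime when the sketches merge.
    SELECT approx_top_k_combine(acc, 30) AS com FROM (
      SELECT approx_top_k_accumulate(expr) AS acc
      FROM VALUES ('a'), ('b'), ('c'), ('d') AS tab(expr)
      UNION ALL
      SELECT approx_top_k_accumulate(expr) AS acc
      FROM VALUES (true), (true), (false), (false) AS tab(expr)
    );
    -- Fails with APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH: type1 = "STRING", type2 = "BOOLEAN".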