Skip to content

[SPARK-52798] [SQL] Add function approx_top_k_combine #51505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
12 changes: 12 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@
],
"sqlState" : "22004"
},
"APPROX_TOP_K_SKETCH_SIZE_NOT_MATCH" : {
"message" : [
"Combining approx_top_k sketches of different sizes is not allowed. Found sketches of size <size1> and <size2>."
],
"sqlState" : "42846"
},
"APPROX_TOP_K_SKETCH_TYPE_NOT_MATCH" : {
"message" : [
"Combining approx_top_k sketches of different types is not allowed. Found sketches of type <type1> and <type2>."
],
"sqlState" : "42846"
},
"ARITHMETIC_OVERFLOW" : {
"message" : [
"<message>.<alternative> If necessary set <config> to \"false\" to bypass this error."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ object FunctionRegistry {
expression[HllUnionAgg]("hll_union_agg"),
expression[ApproxTopK]("approx_top_k"),
expression[ApproxTopKAccumulate]("approx_top_k_accumulate"),
expression[ApproxTopKCombine]("approx_top_k_combine"),

// string functions
expression[Ascii]("ascii"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,35 +76,12 @@ case class ApproxTopKEstimate(state: Expression, k: Expression)

override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)

private def checkStateFieldAndType(state: Expression): TypeCheckResult = {
val stateStructType = state.dataType.asInstanceOf[StructType]
if (stateStructType.length != 3) {
return TypeCheckFailure("State must be a struct with 3 fields. " +
"Expected struct: struct<sketch:binary,itemDataType:any,maxItemsTracked:int>. " +
"Got: " + state.dataType.simpleString)
}

if (stateStructType.head.dataType != BinaryType) {
TypeCheckFailure("State struct must have the first field to be binary. " +
"Got: " + stateStructType.head.dataType.simpleString)
} else if (!ApproxTopK.isDataTypeSupported(itemDataType)) {
TypeCheckFailure("State struct must have the second field to be a supported data type. " +
"Got: " + itemDataType.simpleString)
} else if (stateStructType(2).dataType != IntegerType) {
TypeCheckFailure("State struct must have the third field to be int. " +
"Got: " + stateStructType(2).dataType.simpleString)
} else {
TypeCheckSuccess
}
}


override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
defaultCheck
} else {
val stateCheck = checkStateFieldAndType(state)
val stateCheck = ApproxTopK.checkStateFieldAndType(state)
if (stateCheck.isFailure) {
stateCheck
} else if (!k.foldable) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ object ApproxTopK {
val DEFAULT_K: Int = 5
val DEFAULT_MAX_ITEMS_TRACKED: Int = 10000
private val MAX_ITEMS_TRACKED_LIMIT: Int = 1000000
val VOID_MAX_ITEMS_TRACKED = -1
val SKETCH_SIZE_PLACEHOLDER = 8

def checkExpressionNotNull(expr: Expression, exprName: String): Unit = {
if (expr == null || expr.eval() == null) {
Expand Down Expand Up @@ -318,7 +320,72 @@ object ApproxTopK {
StructType(
StructField("sketch", BinaryType, nullable = false) ::
StructField("itemDataType", itemDataType) ::
StructField("maxItemsTracked", IntegerType, nullable = false) :: Nil)
StructField("maxItemsTracked", IntegerType, nullable = false) ::
StructField("typeCode", BinaryType, nullable = false) :: Nil)

def dataTypeToBytes(dataType: DataType): Array[Byte] = {
dataType match {
case _: BooleanType => Array(0, 0, 0)
case _: ByteType => Array(1, 0, 0)
case _: ShortType => Array(2, 0, 0)
case _: IntegerType => Array(3, 0, 0)
case _: LongType => Array(4, 0, 0)
case _: FloatType => Array(5, 0, 0)
case _: DoubleType => Array(6, 0, 0)
case _: DateType => Array(7, 0, 0)
case _: TimestampType => Array(8, 0, 0)
case _: TimestampNTZType => Array(9, 0, 0)
case _: StringType => Array(10, 0, 0)
case dt: DecimalType => Array(11, dt.precision.toByte, dt.scale.toByte)
}
}

def bytesToDataType(bytes: Array[Byte]): DataType = {
bytes(0) match {
case 0 => BooleanType
case 1 => ByteType
case 2 => ShortType
case 3 => IntegerType
case 4 => LongType
case 5 => FloatType
case 6 => DoubleType
case 7 => DateType
case 8 => TimestampType
case 9 => TimestampNTZType
case 10 => StringType
case 11 => DecimalType(bytes(1).toInt, bytes(2).toInt)
}
}

def checkStateFieldAndType(state: Expression): TypeCheckResult = {
val stateStructType = state.dataType.asInstanceOf[StructType]
if (stateStructType.length != 4) {
return TypeCheckFailure("State must be a struct with 4 fields. " +
"Expected struct: " +
"struct<sketch:binary,itemDataType:any,maxItemsTracked:int,typeCode:binary>. " +
"Got: " + state.dataType.simpleString)
}

val fieldType1 = stateStructType.head.dataType
val fieldType2 = stateStructType(1).dataType
val fieldType3 = stateStructType(2).dataType
val fieldType4 = stateStructType(3).dataType
if (fieldType1 != BinaryType) {
TypeCheckFailure("State struct must have the first field to be binary. " +
"Got: " + fieldType1.simpleString)
} else if (!ApproxTopK.isDataTypeSupported(fieldType2)) {
TypeCheckFailure("State struct must have the second field to be a supported data type. " +
"Got: " + fieldType2.simpleString)
} else if (fieldType3 != IntegerType) {
TypeCheckFailure("State struct must have the third field to be int. " +
"Got: " + fieldType3.simpleString)
} else if (fieldType4 != BinaryType) {
TypeCheckFailure("State struct must have the fourth field to be binary. " +
"Got: " + fieldType4.simpleString)
} else {
TypeCheckSuccess
}
}
}

/**
Expand All @@ -328,7 +395,11 @@ object ApproxTopK {
*
* The output of this function is a struct containing the sketch in binary format,
* a null object indicating the type of items in the sketch,
* and the maximum number of items tracked by the sketch.
* the maximum number of items tracked by the sketch,
* and a binary typeCode encoding the data type of the items in the sketch.
*
* The null object is used in approx_top_k_estimate,
* while the typeCode is used in approx_top_k_combine.
*
* @param expr the child expression to accumulate items from
* @param maxItemsTracked the maximum number of items to track in the sketch
Expand Down Expand Up @@ -410,7 +481,8 @@ case class ApproxTopKAccumulate(

override def eval(buffer: ItemsSketch[Any]): Any = {
val sketchBytes = serialize(buffer)
InternalRow.apply(sketchBytes, null, maxItemsTrackedVal)
val typeCode = ApproxTopK.dataTypeToBytes(itemDataType)
InternalRow.apply(sketchBytes, null, maxItemsTrackedVal, typeCode)
}

override def serialize(buffer: ItemsSketch[Any]): Array[Byte] =
Expand All @@ -435,3 +507,203 @@ case class ApproxTopKAccumulate(
override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_accumulate")
}

class CombineInternal[T](
sketch: ItemsSketch[T],
var itemDataType: DataType,
var maxItemsTracked: Int) {
def getSketch: ItemsSketch[T] = sketch

def getItemDataType: DataType = itemDataType

def setItemDataType(dataType: DataType): Unit = {
if (this.itemDataType == null) {
this.itemDataType = dataType
} else if (this.itemDataType != dataType) {
throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(this.itemDataType, dataType)
}
}

def getMaxItemsTracked: Int = maxItemsTracked

def setMaxItemsTracked(maxItemsTracked: Int): Unit = this.maxItemsTracked = maxItemsTracked
}

/**
* An aggregate function that combines multiple sketches into a single sketch.
*
* @param state the expression containing the sketches to combine
* @param maxItemsTracked the maximum number of items to track in the sketch
* @param mutableAggBufferOffset the offset for mutable aggregation buffer
* @param inputAggBufferOffset the offset for input aggregation buffer
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(state, maxItemsTracked) - Combines multiple sketches into a single sketch.
`maxItemsTracked` An optional positive INTEGER literal with upper limit of 1000000. If maxItemsTracked is not specified, it defaults to 10000.
""",
examples = """
Examples:
> SELECT approx_top_k_estimate(_FUNC_(sketch, 10000), 5) FROM (SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (0), (0), (1), (1) AS tab(expr) UNION ALL SELECT approx_top_k_accumulate(expr) AS sketch FROM VALUES (2), (3), (4), (4) AS tab(expr));
[{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
""",
group = "agg_funcs",
since = "4.1.0")
// scalastyle:on line.size.limit
case class ApproxTopKCombine(
state: Expression,
maxItemsTracked: Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[CombineInternal[Any]]
with ImplicitCastInputTypes
with BinaryLike[Expression] {

def this(child: Expression, maxItemsTracked: Expression) = {
this(child, maxItemsTracked, 0, 0)
ApproxTopK.checkExpressionNotNull(maxItemsTracked, "maxItemsTracked")
ApproxTopK.checkMaxItemsTracked(maxItemsTrackedVal)
}

def this(child: Expression, maxItemsTracked: Int) = this(child, Literal(maxItemsTracked))

def this(child: Expression) = this(child, Literal(ApproxTopK.VOID_MAX_ITEMS_TRACKED), 0, 0)

private lazy val uncheckedItemDataType: DataType =
state.dataType.asInstanceOf[StructType](1).dataType
private lazy val maxItemsTrackedVal: Int = maxItemsTracked.eval().asInstanceOf[Int]
private lazy val combineSizeSpecified: Boolean =
maxItemsTrackedVal != ApproxTopK.VOID_MAX_ITEMS_TRACKED

override def left: Expression = state

override def right: Expression = maxItemsTracked

override def inputTypes: Seq[AbstractDataType] = Seq(StructType, IntegerType)

override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
defaultCheck
} else {
val stateCheck = ApproxTopK.checkStateFieldAndType(state)
if (stateCheck.isFailure) {
stateCheck
} else if (!maxItemsTracked.foldable) {
TypeCheckFailure("Number of items tracked must be a constant literal")
} else {
TypeCheckSuccess
}
}
}

override def dataType: DataType = ApproxTopK.getSketchStateDataType(uncheckedItemDataType)

override def createAggregationBuffer(): CombineInternal[Any] = {
if (combineSizeSpecified) {
val maxMapSize = ApproxTopK.calMaxMapSize(maxItemsTrackedVal)
new CombineInternal[Any](
new ItemsSketch[Any](maxMapSize),
null,
maxItemsTrackedVal)
} else {
new CombineInternal[Any](
new ItemsSketch[Any](ApproxTopK.SKETCH_SIZE_PLACEHOLDER),
null,
ApproxTopK.VOID_MAX_ITEMS_TRACKED)
}
}

override def update(buffer: CombineInternal[Any], input: InternalRow): CombineInternal[Any] = {
val inputState = state.eval(input).asInstanceOf[InternalRow]
val inputSketchBytes = inputState.getBinary(0)
val inputMaxItemsTracked = inputState.getInt(2)
val typeCode = inputState.getBinary(3)
val actualItemDataType = ApproxTopK.bytesToDataType(typeCode)
buffer.setItemDataType(actualItemDataType)
val inputSketch = ItemsSketch.getInstance(
Memory.wrap(inputSketchBytes), ApproxTopK.genSketchSerDe(buffer.getItemDataType))
buffer.getSketch.merge(inputSketch)
if (!combineSizeSpecified) {
buffer.setMaxItemsTracked(inputMaxItemsTracked)
}
buffer
}

override def merge(buffer: CombineInternal[Any], input: CombineInternal[Any])
: CombineInternal[Any] = {
if (!combineSizeSpecified) {
// check size
if (buffer.getMaxItemsTracked == ApproxTopK.VOID_MAX_ITEMS_TRACKED) {
// If buffer is a placeholder sketch, set it to the input sketch's max items tracked
buffer.setMaxItemsTracked(input.getMaxItemsTracked)
}
if (buffer.getMaxItemsTracked != input.getMaxItemsTracked) {
throw QueryExecutionErrors.approxTopKSketchSizeNotMatch(
buffer.getMaxItemsTracked,
input.getMaxItemsTracked
)
}
}
// check item data type
if (buffer.getItemDataType != null && input.getItemDataType != null &&
buffer.getItemDataType != input.getItemDataType) {
throw QueryExecutionErrors.approxTopKSketchTypeNotMatch(
buffer.getItemDataType,
input.getItemDataType
)
} else if (buffer.getItemDataType == null) {
// If buffer is a placeholder sketch, set it to the input sketch's item data type
buffer.setItemDataType(input.getItemDataType)
}
buffer.getSketch.merge(input.getSketch)
buffer
}

override def eval(buffer: CombineInternal[Any]): Any = {
val sketchBytes =
buffer.getSketch.toByteArray(ApproxTopK.genSketchSerDe(buffer.getItemDataType))
val maxItemsTracked = buffer.getMaxItemsTracked
val typeCode = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
InternalRow.apply(sketchBytes, null, maxItemsTracked, typeCode)
}

override def serialize(buffer: CombineInternal[Any]): Array[Byte] = {
val sketchBytes = buffer.getSketch.toByteArray(
ApproxTopK.genSketchSerDe(buffer.getItemDataType))
val maxItemsTrackedByte = buffer.getMaxItemsTracked.toByte
val itemDataTypeBytes = ApproxTopK.dataTypeToBytes(buffer.getItemDataType)
val byteArray = new Array[Byte](sketchBytes.length + 4)
byteArray(0) = maxItemsTrackedByte
System.arraycopy(itemDataTypeBytes, 0, byteArray, 1, itemDataTypeBytes.length)
System.arraycopy(sketchBytes, 0, byteArray, 4, sketchBytes.length)
byteArray
}

override def deserialize(buffer: Array[Byte]): CombineInternal[Any] = {
val maxItemsTracked = buffer(0).toInt
val itemDataTypeBytes = buffer.slice(1, 4)
val actualItemDataType = ApproxTopK.bytesToDataType(itemDataTypeBytes)
val sketchBytes = buffer.slice(4, buffer.length)
val sketch = ItemsSketch.getInstance(
Memory.wrap(sketchBytes), ApproxTopK.genSketchSerDe(actualItemDataType))
new CombineInternal[Any](sketch, actualItemDataType, maxItemsTracked)
}

override protected def withNewChildrenInternal(
newLeft: Expression,
newRight: Expression): Expression =
copy(state = newLeft, maxItemsTracked = newRight)

override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def nullable: Boolean = false

override def prettyName: String =
getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("approx_top_k_combine")
}
Loading