Skip to content

Commit af27b37

Browse files
authored
chore: Refactor aggregate serde to use map (#2055)
1 parent d688655 commit af27b37

File tree

1 file changed

+30
-23
lines changed

1 file changed

+30
-23
lines changed

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,29 @@ object QueryPlanSerde extends Logging with CometExprShim {
131131
classOf[SparkPartitionID] -> CometSparkPartitionId,
132132
classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId)
133133

134+
/**
135+
* Mapping of Spark aggregate expression class to Comet expression handler.
136+
*/
137+
private val aggrSerdeMap: Map[Class[_], CometAggregateExpressionSerde] = Map(
138+
classOf[Sum] -> CometSum,
139+
classOf[Average] -> CometAverage,
140+
classOf[Count] -> CometCount,
141+
classOf[Min] -> CometMin,
142+
classOf[Max] -> CometMax,
143+
classOf[First] -> CometFirst,
144+
classOf[Last] -> CometLast,
145+
classOf[BitAndAgg] -> CometBitAndAgg,
146+
classOf[BitOrAgg] -> CometBitOrAgg,
147+
classOf[BitXorAgg] -> CometBitXOrAgg,
148+
classOf[CovSample] -> CometCovSample,
149+
classOf[CovPopulation] -> CometCovPopulation,
150+
classOf[VarianceSamp] -> CometVarianceSamp,
151+
classOf[VariancePop] -> CometVariancePop,
152+
classOf[StddevSamp] -> CometStddevSamp,
153+
classOf[StddevPop] -> CometStddevPop,
154+
classOf[Corr] -> CometCorr,
155+
classOf[BloomFilterAggregate] -> CometBloomFilterAggregate)
156+
134157
def emitWarning(reason: String): Unit = {
135158
logWarning(s"Comet native execution is disabled due to: $reason")
136159
}
@@ -436,33 +459,17 @@ object QueryPlanSerde extends Logging with CometExprShim {
436459
return None
437460
}
438461

439-
val cometExpr: CometAggregateExpressionSerde = aggExpr.aggregateFunction match {
440-
case _: Sum => CometSum
441-
case _: Average => CometAverage
442-
case _: Count => CometCount
443-
case _: Min => CometMin
444-
case _: Max => CometMax
445-
case _: First => CometFirst
446-
case _: Last => CometLast
447-
case _: BitAndAgg => CometBitAndAgg
448-
case _: BitOrAgg => CometBitOrAgg
449-
case _: BitXorAgg => CometBitXOrAgg
450-
case _: CovSample => CometCovSample
451-
case _: CovPopulation => CometCovPopulation
452-
case _: VarianceSamp => CometVarianceSamp
453-
case _: VariancePop => CometVariancePop
454-
case _: StddevSamp => CometStddevSamp
455-
case _: StddevPop => CometStddevPop
456-
case _: Corr => CometCorr
457-
case _: BloomFilterAggregate => CometBloomFilterAggregate
458-
case fn =>
462+
val fn = aggExpr.aggregateFunction
463+
val cometExpr = aggrSerdeMap.get(fn.getClass)
464+
cometExpr match {
465+
case Some(handler) =>
466+
handler.convert(aggExpr, fn, inputs, binding, conf)
467+
case _ =>
459468
val msg = s"unsupported Spark aggregate function: ${fn.prettyName}"
460469
emitWarning(msg)
461470
withInfo(aggExpr, msg, fn.children: _*)
462-
return None
463-
471+
None
464472
}
465-
cometExpr.convert(aggExpr, aggExpr.aggregateFunction, inputs, binding, conf)
466473
}
467474

468475
def evalModeToProto(evalMode: CometEvalMode.Value): ExprOuterClass.EvalMode = {

0 commit comments

Comments
 (0)