@@ -131,6 +131,29 @@ object QueryPlanSerde extends Logging with CometExprShim {
131
131
classOf [SparkPartitionID ] -> CometSparkPartitionId ,
132
132
classOf [MonotonicallyIncreasingID ] -> CometMonotonicallyIncreasingId )
133
133
134
+ /**
135
+ * Mapping of Spark aggregate expression class to Comet expression handler.
136
+ */
137
+ private val aggrSerdeMap : Map [Class [_], CometAggregateExpressionSerde ] = Map (
138
+ classOf [Sum ] -> CometSum ,
139
+ classOf [Average ] -> CometAverage ,
140
+ classOf [Count ] -> CometCount ,
141
+ classOf [Min ] -> CometMin ,
142
+ classOf [Max ] -> CometMax ,
143
+ classOf [First ] -> CometFirst ,
144
+ classOf [Last ] -> CometLast ,
145
+ classOf [BitAndAgg ] -> CometBitAndAgg ,
146
+ classOf [BitOrAgg ] -> CometBitOrAgg ,
147
+ classOf [BitXorAgg ] -> CometBitXOrAgg ,
148
+ classOf [CovSample ] -> CometCovSample ,
149
+ classOf [CovPopulation ] -> CometCovPopulation ,
150
+ classOf [VarianceSamp ] -> CometVarianceSamp ,
151
+ classOf [VariancePop ] -> CometVariancePop ,
152
+ classOf [StddevSamp ] -> CometStddevSamp ,
153
+ classOf [StddevPop ] -> CometStddevPop ,
154
+ classOf [Corr ] -> CometCorr ,
155
+ classOf [BloomFilterAggregate ] -> CometBloomFilterAggregate )
156
+
134
157
def emitWarning (reason : String ): Unit = {
135
158
logWarning(s " Comet native execution is disabled due to: $reason" )
136
159
}
@@ -436,33 +459,17 @@ object QueryPlanSerde extends Logging with CometExprShim {
436
459
return None
437
460
}
438
461
439
- val cometExpr : CometAggregateExpressionSerde = aggExpr.aggregateFunction match {
440
- case _ : Sum => CometSum
441
- case _ : Average => CometAverage
442
- case _ : Count => CometCount
443
- case _ : Min => CometMin
444
- case _ : Max => CometMax
445
- case _ : First => CometFirst
446
- case _ : Last => CometLast
447
- case _ : BitAndAgg => CometBitAndAgg
448
- case _ : BitOrAgg => CometBitOrAgg
449
- case _ : BitXorAgg => CometBitXOrAgg
450
- case _ : CovSample => CometCovSample
451
- case _ : CovPopulation => CometCovPopulation
452
- case _ : VarianceSamp => CometVarianceSamp
453
- case _ : VariancePop => CometVariancePop
454
- case _ : StddevSamp => CometStddevSamp
455
- case _ : StddevPop => CometStddevPop
456
- case _ : Corr => CometCorr
457
- case _ : BloomFilterAggregate => CometBloomFilterAggregate
458
- case fn =>
462
+ val fn = aggExpr.aggregateFunction
463
+ val cometExpr = aggrSerdeMap.get(fn.getClass)
464
+ cometExpr match {
465
+ case Some (handler) =>
466
+ handler.convert(aggExpr, fn, inputs, binding, conf)
467
+ case _ =>
459
468
val msg = s " unsupported Spark aggregate function: ${fn.prettyName}"
460
469
emitWarning(msg)
461
470
withInfo(aggExpr, msg, fn.children: _* )
462
- return None
463
-
471
+ None
464
472
}
465
- cometExpr.convert(aggExpr, aggExpr.aggregateFunction, inputs, binding, conf)
466
473
}
467
474
468
475
def evalModeToProto (evalMode : CometEvalMode .Value ): ExprOuterClass .EvalMode = {
0 commit comments