@@ -174,7 +174,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
174
174
/**
175
175
* Mapping of Spark aggregate expression class to Comet expression handler.
176
176
*/
177
- private val aggrSerdeMap : Map [Class [_], CometAggregateExpressionSerde ] = Map (
177
+ private val aggrSerdeMap : Map [Class [_], CometAggregateExpressionSerde [_] ] = Map (
178
178
classOf [Sum ] -> CometSum ,
179
179
classOf [Average ] -> CometAverage ,
180
180
classOf [Count ] -> CometCount ,
@@ -498,7 +498,9 @@ object QueryPlanSerde extends Logging with CometExprShim {
498
498
val cometExpr = aggrSerdeMap.get(fn.getClass)
499
499
cometExpr match {
500
500
case Some (handler) =>
501
- handler.convert(aggExpr, fn, inputs, binding, conf)
501
+ handler
502
+ .asInstanceOf [CometAggregateExpressionSerde [AggregateFunction ]]
503
+ .convert(aggExpr, fn, inputs, binding, conf)
502
504
case _ =>
503
505
withInfo(
504
506
aggExpr,
@@ -2456,7 +2458,7 @@ trait CometExpressionSerde[T <: Expression] {
2456
2458
/**
2457
2459
* Trait for providing serialization logic for aggregate expressions.
2458
2460
*/
2459
- trait CometAggregateExpressionSerde {
2461
+ trait CometAggregateExpressionSerde [ T <: AggregateFunction ] {
2460
2462
2461
2463
/**
2462
2464
* Convert a Spark expression into a protocol buffer representation that can be passed into
@@ -2479,7 +2481,7 @@ trait CometAggregateExpressionSerde {
2479
2481
*/
2480
2482
def convert (
2481
2483
aggExpr : AggregateExpression ,
2482
- expr : Expression ,
2484
+ expr : T ,
2483
2485
inputs : Seq [Attribute ],
2484
2486
binding : Boolean ,
2485
2487
conf : SQLConf ): Option [ExprOuterClass .AggExpr ]
0 commit comments