Skip to content

Commit 527cb57

Browse files
authored
perf: Use DataFusion FilterExec for experimental native scans (#1395)
* Add boolean field to Filter's proto, set based on Comet native scan implementation. Planner uses that field to construct the correct FilterExec implementation. CometFilterExec does a deep copy of the batch due to logic in Comet Scan, while DF FilterExec can do a shallow copy because native Scans do not reuse batch buffers.
* Refactor to reduce duplicate code.
* Fix native test.
* Address nit.
1 parent 57a4dca commit 527cb57

File tree

3 files changed

+29
-13
lines changed

3 files changed

+29
-13
lines changed

native/core/src/execution/planner.rs

Lines changed: 16 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -18,7 +18,8 @@
1818
//! Converts Spark physical plan to DataFusion physical plan
1919
2020
use super::expressions::EvalMode;
21-
use crate::execution::operators::{CopyMode, FilterExec};
21+
use crate::execution::operators::CopyMode;
22+
use crate::execution::operators::FilterExec as CometFilterExec;
2223
use crate::{
2324
errors::ExpressionError,
2425
execution::{
@@ -79,6 +80,7 @@ use datafusion::datasource::listing::PartitionedFile;
7980
use datafusion::datasource::physical_plan::parquet::ParquetExecBuilder;
8081
use datafusion::datasource::physical_plan::FileScanConfig;
8182
use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
83+
use datafusion::physical_plan::filter::FilterExec as DataFusionFilterExec;
8284
use datafusion_comet_proto::{
8385
spark_expression::{
8486
self, agg_expr::ExprStruct as AggExprStruct, expr::ExprStruct, literal::Value, AggExpr,
@@ -992,10 +994,18 @@ impl PhysicalPlanner {
992994
let predicate =
993995
self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?;
994996

995-
let filter = Arc::new(FilterExec::try_new(
996-
predicate,
997-
Arc::clone(&child.native_plan),
998-
)?);
997+
let filter: Arc<dyn ExecutionPlan> = if filter.use_datafusion_filter {
998+
Arc::new(DataFusionFilterExec::try_new(
999+
predicate,
1000+
Arc::clone(&child.native_plan),
1001+
)?)
1002+
} else {
1003+
Arc::new(CometFilterExec::try_new(
1004+
predicate,
1005+
Arc::clone(&child.native_plan),
1006+
)?)
1007+
};
1008+
9991009
Ok((
10001010
scans,
10011011
Arc::new(SparkPlan::new(spark_plan.plan_id, filter, vec![child])),
@@ -2875,6 +2885,7 @@ mod tests {
28752885
children: vec![child_op],
28762886
op_struct: Some(OpStruct::Filter(spark_operator::Filter {
28772887
predicate: Some(expr),
2888+
use_datafusion_filter: false,
28782889
})),
28792890
}
28802891
}

native/proto/src/proto/operator.proto

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -99,6 +99,7 @@ message Projection {
9999

100100
message Filter {
101101
spark.spark_expression.Expr predicate = 1;
102+
bool use_datafusion_filter = 2;
102103
}
103104

104105
message Sort {

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 12 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -23,13 +23,13 @@ import scala.collection.JavaConverters._
2323

2424
import org.apache.spark.internal.Logging
2525
import org.apache.spark.sql.catalyst.expressions._
26-
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, BitAndAgg, BitOrAgg, BitXorAgg, BloomFilterAggregate, Complete, Corr, Count, CovPopulation, CovSample, Final, First, Last, Max, Min, Partial, StddevPop, StddevSamp, Sum, VariancePop, VarianceSamp}
26+
import org.apache.spark.sql.catalyst.expressions.aggregate._
2727
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
2828
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, NormalizeNaNAndZero}
2929
import org.apache.spark.sql.catalyst.plans._
30-
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
30+
import org.apache.spark.sql.catalyst.plans.physical._
3131
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
32-
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision}
32+
import org.apache.spark.sql.comet._
3333
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
3434
import org.apache.spark.sql.execution
3535
import org.apache.spark.sql.execution._
@@ -46,12 +46,11 @@ import org.apache.spark.unsafe.types.UTF8String
4646

4747
import org.apache.comet.CometConf
4848
import org.apache.comet.CometSparkSessionExtensions.{isCometScan, isSpark34Plus, withInfo}
49-
import org.apache.comet.expressions.{CometCast, CometEvalMode, Compatible, Incompatible, RegExp, Unsupported}
49+
import org.apache.comet.expressions._
5050
import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc}
51-
import org.apache.comet.serde.ExprOuterClass.DataType.{DataTypeInfo, DecimalInfo, ListInfo, MapInfo, StructInfo}
51+
import org.apache.comet.serde.ExprOuterClass.DataType._
5252
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator}
53-
import org.apache.comet.shims.CometExprShim
54-
import org.apache.comet.shims.ShimQueryPlanSerde
53+
import org.apache.comet.shims.{CometExprShim, ShimQueryPlanSerde}
5554

5655
/**
5756
* An utility object for query plan and expression serialization.
@@ -2724,7 +2723,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
27242723
val cond = exprToProto(condition, child.output)
27252724

27262725
if (cond.isDefined && childOp.nonEmpty) {
2727-
val filterBuilder = OperatorOuterClass.Filter.newBuilder().setPredicate(cond.get)
2726+
val filterBuilder = OperatorOuterClass.Filter
2727+
.newBuilder()
2728+
.setPredicate(cond.get)
2729+
.setUseDatafusionFilter(
2730+
CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_DATAFUSION ||
2731+
CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_ICEBERG_COMPAT)
27282732
Some(result.setFilter(filterBuilder).build())
27292733
} else {
27302734
withInfo(op, condition, child)

0 commit comments

Comments (0)