diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index a45036be45..fc8df5ec10 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1107,15 +1107,8 @@ impl PhysicalPlanner {
                     .collect();
 
                 let fetch = sort.fetch.map(|num| num as usize);
-
-                // SortExec caches batches so we need to make a copy of incoming batches. Also,
-                // SortExec fails in some cases if we do not unpack dictionary-encoded arrays, and
-                // it would be more efficient if we could avoid that.
-                // https://github.com/apache/datafusion-comet/issues/963
-                let child_copied = Self::wrap_in_copy_exec(Arc::clone(&child.native_plan));
-
                 let sort = Arc::new(
-                    SortExec::new(LexOrdering::new(exprs?), Arc::clone(&child_copied))
+                    SortExec::new(LexOrdering::new(exprs?), Arc::clone(&child.native_plan))
                         .with_fetch(fetch),
                 );
@@ -1285,7 +1278,7 @@ impl PhysicalPlanner {
                 }?;
 
                 let shuffle_writer = Arc::new(ShuffleWriterExec::try_new(
-                    Self::wrap_in_copy_exec(Arc::clone(&child.native_plan)),
+                    Arc::clone(&child.native_plan),
                     partitioning,
                     codec,
                     writer.output_data_file.clone(),
@@ -1344,6 +1337,7 @@ impl PhysicalPlanner {
         // if the child operator is `ScanExec`, because other operators after `ScanExec`
         // will create new arrays for the output batch.
         let input = if can_reuse_input_batch(&child.native_plan) {
+            // FIXME: handle this in the Spark planner
            Arc::new(CopyExec::new(
                Arc::clone(&child.native_plan),
                CopyMode::UnpackOrDeepCopy,
@@ -1446,8 +1440,8 @@ impl PhysicalPlanner {
         // to copy the input batch to avoid the data corruption from reusing the input
         // batch. We also need to unpack dictionary arrays, because the join operators
         // do not support them.
-        let left = Self::wrap_in_copy_exec(Arc::clone(&join_params.left.native_plan));
-        let right = Self::wrap_in_copy_exec(Arc::clone(&join_params.right.native_plan));
+        let left = Arc::clone(&join_params.left.native_plan);
+        let right = Arc::clone(&join_params.right.native_plan);
 
         let hash_join = Arc::new(HashJoinExec::try_new(
             left,
@@ -1535,6 +1529,20 @@ impl PhysicalPlanner {
                     Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])),
                 ))
             }
+            OpStruct::Copy(copy) => {
+                assert_eq!(children.len(), 1);
+                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
+                let copy_mode = if copy.mode == 0 {
+                    CopyMode::UnpackOrDeepCopy
+                } else {
+                    CopyMode::UnpackOrClone
+                };
+                let copy = Arc::new(CopyExec::new(Arc::clone(&child.native_plan), copy_mode));
+                Ok((
+                    scans,
+                    Arc::new(SparkPlan::new(spark_plan.plan_id, copy, vec![child])),
+                ))
+            }
         }
     }
@@ -1679,16 +1687,6 @@ impl PhysicalPlanner {
         ))
     }
 
-    /// Wrap an ExecutionPlan in a CopyExec, which will unpack any dictionary-encoded arrays
-    /// and make a deep copy of other arrays if the plan re-uses batches.
-    fn wrap_in_copy_exec(plan: Arc<dyn ExecutionPlan>) -> Arc<dyn ExecutionPlan> {
-        if can_reuse_input_batch(&plan) {
-            Arc::new(CopyExec::new(plan, CopyMode::UnpackOrDeepCopy))
-        } else {
-            Arc::new(CopyExec::new(plan, CopyMode::UnpackOrClone))
-        }
-    }
-
     /// Create a DataFusion physical aggregate expression from Spark physical aggregate expression
     fn create_agg_expr(
         &self,
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 7ccce21a20..2d614cb81f 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -47,6 +47,7 @@ message Operator {
     HashJoin hash_join = 109;
     Window window = 110;
     NativeScan native_scan = 111;
+    Copy copy = 112;
   }
 }
@@ -244,3 +245,13 @@ message Window {
   repeated spark.spark_expression.Expr partition_by_list = 3;
   Operator child = 4;
 }
+
+
+enum CopyMode {
+  UnpackOrDeepCopy = 0;
+  UnpackOrClone = 1;
+}
+
+message Copy {
+  CopyMode mode = 3;
+}
\ No newline at end of file
diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
index 2383dd8440..89ac755852 100644
--- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -19,6 +19,7 @@
 
 package org.apache.comet.rules
 
+import scala.annotation.tailrec
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.sql.SparkSession
@@ -338,6 +339,11 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
             op.right,
             SerializedPlan(None)))
 
+      case op: CopyExec if op.children.forall(isCometNative) =>
+        newPlanWithProto(
+          op,
+          CometCopyExec(_, op, op.output, op.copyMode, op.child, SerializedPlan(None)))
+
       case op: SortMergeJoinExec
           if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) &&
             !op.children.forall(isCometNative) =>
@@ -671,7 +677,9 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       normalizePlan(plan)
     }
 
-    var newPlan = transform(normalizedPlan)
+    // FIXME: should this move to a separate rule?
+    var newPlan = transformAndAddCopyExec(normalizedPlan)
+    newPlan = transform(newPlan)
 
     // if the plan cannot be run fully natively then explain why (when appropriate
     // config is enabled)
@@ -751,6 +759,40 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     }
   }
 
+  private def transformAndAddCopyExec(plan: SparkPlan): SparkPlan = plan.transform {
+    case shj: ShuffledHashJoinExec =>
+      val newLeft = wrapInCopyExec(shj.left)
+      val newRight = wrapInCopyExec(shj.right)
+      shj.copy(left = newLeft, right = newRight)
+    case se: SortExec =>
+      val newChild = wrapInCopyExec(se.child)
+      se.copy(child = newChild)
+    case ee: ExpandExec =>
+      val newChild = wrapInCopyExec(ee.child)
+      ee.copy(child = newChild)
+  }
+
+  /// Returns true if the given operator can return its input array as its output array
+  /// without modification. This is used to determine whether we need to copy the input
+  /// batch to avoid data corruption from reusing the input batch.
+  @tailrec
+  private def canReuseInputBatch(plan: SparkPlan): Boolean = {
+    if (plan.isInstanceOf[ProjectExec] || plan.isInstanceOf[LocalLimitExec]) {
+      canReuseInputBatch(plan.children.head)
+    } else {
+      // FIXME
+      plan.isInstanceOf[CometScanExec]
+    }
+  }
+
+  private def wrapInCopyExec(plan: SparkPlan): SparkPlan = {
+    if (canReuseInputBatch(plan)) {
+      CopyExec(plan, UnpackOrDeepCopy)
+    } else {
+      CopyExec(plan, UnpackOrClone)
+    }
+  }
+
   /**
    * Find the first Comet partial aggregate in the plan.
    * If it reaches a Spark HashAggregate with partial mode, it will return None.
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index c46df52a10..856ad3c13b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2306,6 +2306,23 @@ object QueryPlanSerde extends Logging with CometExprShim {
           None
         }
 
+      case CopyExec(child, copyMode) => {
+        if (childOp.nonEmpty) {
+          val copyModeBuilder = if (copyMode == UnpackOrDeepCopy) {
+            OperatorOuterClass.CopyMode.UnpackOrDeepCopy
+          } else {
+            OperatorOuterClass.CopyMode.UnpackOrClone
+          }
+          val copyBuilder = OperatorOuterClass.Copy
+            .newBuilder()
+            .setMode(copyModeBuilder)
+          Some(result.setCopy(copyBuilder).build())
+        } else {
+          withInfo(op, child)
+          None
+        }
+      }
+
       case FilterExec(condition, child)
           if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) =>
         val cond = exprToProto(condition, child.output)
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CopyExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CopyExec.scala
new file mode 100644
index 0000000000..42a2d1af40
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CopyExec.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+
+case class CopyExec(override val child: SparkPlan, copyMode: CopyMode) extends UnaryExecNode {
+  override protected def doExecute(): RDD[InternalRow] = {
+    // This method should never be invoked: CopyExec is an internal operator that marks
+    // where record batches must be deep-copied or cloned during native execution offload.
+    // The actual execution happens in the native layer through CometExecNode.
+    throw new UnsupportedOperationException(
+      "CopyExec is an internal operator and should not be executed directly")
+  }
+  override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+}
+
+sealed abstract class CopyMode
+case object UnpackOrDeepCopy extends CopyMode
+case object UnpackOrClone extends CopyMode
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 3b9c6bdbca..791b839a72 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -1000,3 +1000,22 @@ case class CometSinkPlaceHolder(
 
   override def stringArgs: Iterator[Any] = Iterator(originalPlan.output, child)
 }
+
+case class CometCopyExec(
+    override val nativeOp: Operator,
+    override val originalPlan: SparkPlan,
+    override val output: Seq[Attribute],
+    copyMode: CopyMode,
+    child: SparkPlan,
+    override val serializedPlanOpt: SerializedPlan)
+    extends CometUnaryExec {
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    this.copy(child = newChild)
+
+  override def verboseStringWithOperatorId(): String = {
+    s"""
+       |$formattedNodeName
+       |${ExplainUtils.generateFieldString("Input", child.output)}
+       |""".stripMargin
+  }
+}
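
For reviewers, here is a minimal, self-contained sketch of the decision this patch moves from the native planner's `wrap_in_copy_exec` into `CometExecRule`: a child whose output may alias a reused input buffer gets `UnpackOrDeepCopy`, everything else gets the cheaper `UnpackOrClone`. The toy `Plan` ADT (`Scan`, `Project`, `Sort`, `Copy`) is hypothetical and exists only to make the snippet runnable; only the wrapping logic mirrors `wrapInCopyExec` / `canReuseInputBatch` above.

```scala
// Hypothetical stand-ins for SparkPlan nodes; not part of the patch.
sealed trait Plan { def children: Seq[Plan] }
case object Scan extends Plan { val children: Seq[Plan] = Nil }
case class Project(child: Plan) extends Plan { val children: Seq[Plan] = Seq(child) }
case class Sort(child: Plan) extends Plan { val children: Seq[Plan] = Seq(child) }
case class Copy(child: Plan, mode: String) extends Plan { val children: Seq[Plan] = Seq(child) }

object CopyModeSketch extends App {
  // A pass-through chain over a scan can hand the scan's reused buffer straight
  // to its parent, so the parent must deep-copy; anything else produces fresh
  // arrays and only needs dictionary unpacking plus a cheap clone.
  @annotation.tailrec
  def canReuseInputBatch(plan: Plan): Boolean = plan match {
    case Project(c) => canReuseInputBatch(c)
    case Scan       => true
    case _          => false
  }

  def wrapInCopy(plan: Plan): Plan =
    if (canReuseInputBatch(plan)) Copy(plan, "UnpackOrDeepCopy")
    else Copy(plan, "UnpackOrClone")

  // Sort caches incoming batches, so the rule wraps its child before offload:
  println(Sort(wrapInCopy(Project(Scan))))
  // Sort(Copy(Project(Scan),UnpackOrDeepCopy))
}
```

Note the design choice this illustrates: the rule inspects the child being wrapped (the producer of the batch), not the operator that caches batches, so a deep copy is only forced when the producer can actually return a reused buffer.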