diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index 0d784c917969..ff8a6dfd8e49 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -64,10 +64,9 @@ use datafusion_catalog::ScanArgs;
 use datafusion_common::display::ToStringifiedPlan;
 use datafusion_common::format::ExplainAnalyzeLevel;
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
-use datafusion_common::TableReference;
 use datafusion_common::{
     exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
-    ScalarValue,
+    NullEquality, ScalarValue, TableReference,
 };
 use datafusion_datasource::file_groups::FileGroup;
 use datafusion_datasource::memory::MemorySourceConfig;
@@ -85,7 +84,7 @@ use datafusion_expr::{
     WindowFrame, WindowFrameBound, WriteOp,
 };
 use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
-use datafusion_physical_expr::expressions::Literal;
+use datafusion_physical_expr::expressions::{Column, Literal};
 use datafusion_physical_expr::{
     create_physical_sort_exprs, LexOrdering, PhysicalSortExpr,
 };
@@ -100,6 +99,7 @@ use datafusion_physical_plan::unnest::ListUnnest;

 use async_trait::async_trait;
 use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper};
+use datafusion_physical_plan::result_table::ResultTableExec;
 use futures::{StreamExt, TryStreamExt};
 use itertools::{multiunzip, Itertools};
 use log::debug;
@@ -1432,12 +1432,68 @@ impl DefaultPhysicalPlanner {
                 name, is_distinct, ..
             }) => {
                 let [static_term, recursive_term] = children.two()?;
-                Arc::new(RecursiveQueryExec::try_new(
-                    name.clone(),
-                    static_term,
-                    recursive_term,
-                    *is_distinct,
-                )?)
+                let inner_schema = static_term.schema();
+                let group_by = PhysicalGroupBy::new_single(
+                    inner_schema
+                        .fields()
+                        .iter()
+                        .enumerate()
+                        .map(|(i, f)| {
+                            (Arc::new(Column::new(f.name(), i)) as _, f.name().clone())
+                        })
+                        .collect(),
+                );
+                if *is_distinct {
+                    // We deduplicate each input to avoid duplicated values,
+                    // and we remove from the recursive term the already emitted
+                    // values, i.e. the contents of the results table.
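+                    //
+                    // The rewritten plan is roughly:
+                    //
+                    //   RecursiveQueryExec (is_distinct = true)
+                    //     AggregateExec       <- deduplicates the static term
+                    //     HashJoinExec        <- LeftAnti: drops already-emitted rows
+                    //       AggregateExec     <- deduplicates the recursive term
+                    //       ResultTableExec   <- mirrors all rows emitted so far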
+                    Arc::new(RecursiveQueryExec::try_new(
+                        name.clone(),
+                        Arc::new(AggregateExec::try_new(
+                            AggregateMode::Final,
+                            group_by.clone(),
+                            Vec::new(),
+                            Vec::new(),
+                            static_term,
+                            Arc::clone(&inner_schema),
+                        )?),
+                        Arc::new(HashJoinExec::try_new(
+                            Arc::new(AggregateExec::try_new(
+                                AggregateMode::Final,
+                                group_by,
+                                Vec::new(),
+                                Vec::new(),
+                                recursive_term,
+                                Arc::clone(&inner_schema),
+                            )?),
+                            Arc::new(ResultTableExec::new(
+                                "union".into(),
+                                Arc::clone(&inner_schema),
+                            )),
+                            inner_schema
+                                .fields()
+                                .iter()
+                                .enumerate()
+                                .map(|(i, f)| {
+                                    let col = Arc::new(Column::new(f.name(), i)) as _;
+                                    (Arc::clone(&col), col)
+                                })
+                                .collect(),
+                            None,
+                            &JoinType::LeftAnti,
+                            None,
+                            PartitionMode::CollectLeft,
+                            NullEquality::NullEqualsNull,
+                        )?),
+                        true,
+                    )?)
+                } else {
+                    Arc::new(RecursiveQueryExec::try_new(
+                        name.clone(),
+                        static_term,
+                        recursive_term,
+                        false,
+                    )?)
+                }
             }

             // N Children
diff --git a/datafusion/core/tests/data/recursive_cte/closure.csv b/datafusion/core/tests/data/recursive_cte/closure.csv
new file mode 100644
index 000000000000..d012e9777a2b
--- /dev/null
+++ b/datafusion/core/tests/data/recursive_cte/closure.csv
@@ -0,0 +1,5 @@
+start,end
+1,2
+2,3
+2,4
+4,1
\ No newline at end of file
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 7a283b0420d3..33b0ad41d570 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -54,7 +54,7 @@ use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
 use datafusion_common::display::ToStringifiedPlan;
 use datafusion_common::file_options::file_type::FileType;
 use datafusion_common::{
-    exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err,
+    exec_err, get_target_functional_dependencies, internal_datafusion_err,
     plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
     NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
 };
@@ -178,12 +178,6 @@ impl LogicalPlanBuilder {
         recursive_term: LogicalPlan,
         is_distinct: bool,
     ) -> Result<Self> {
-        // TODO: we need to do a bunch of validation here. Maybe more.
-        if is_distinct {
-            return not_impl_err!(
-                "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported"
-            );
-        }
         // Ensure that the static term and the recursive term have the same number of fields
         let static_fields_len = self.plan.schema().fields().len();
         let recursive_fields_len = recursive_term.schema().fields().len();
diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs
index 17628fd8ad1d..5696da140b5a 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -80,6 +80,7 @@ pub mod placeholder_row;
 pub mod projection;
 pub mod recursive_query;
 pub mod repartition;
+pub mod result_table;
 pub mod sorts;
 pub mod spill;
 pub mod stream;
diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs
index b4cdf2dff2bf..ecbba7381c17 100644
--- a/datafusion/physical-plan/src/recursive_query.rs
+++ b/datafusion/physical-plan/src/recursive_query.rs
@@ -18,11 +18,13 @@
 //! Defines the recursive query plan

 use std::any::Any;
+use std::mem::take;
 use std::sync::Arc;
 use std::task::{Context, Poll};

 use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
 use crate::execution_plan::{Boundedness, EmissionType};
+use crate::result_table::{ResultTable, ResultTableExec};
 use crate::{
     metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
     PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
@@ -64,8 +66,9 @@ pub struct RecursiveQueryExec {
     static_term: Arc<dyn ExecutionPlan>,
     /// The dynamic part (recursive term)
     recursive_term: Arc<dyn ExecutionPlan>,
-    /// Distinction
     is_distinct: bool,
+    /// If `is_distinct` is true, holds the result table that stores all previously emitted results
+    result_table: Option<Arc<ResultTable>>,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
     /// Cache holding plan properties like equivalences, output partitioning etc.
@@ -77,13 +80,22 @@ impl RecursiveQueryExec {
     pub fn try_new(
         name: String,
         static_term: Arc<dyn ExecutionPlan>,
-        recursive_term: Arc<dyn ExecutionPlan>,
+        mut recursive_term: Arc<dyn ExecutionPlan>,
         is_distinct: bool,
     ) -> Result<Self> {
         // Each recursive query needs its own work table
         let work_table = Arc::new(WorkTable::new());
         // Use the same work table for both the WorkTableExec and the recursive term
-        let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
+        recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
+        let result_table = if is_distinct {
+            let result_table = Arc::new(ResultTable::new());
+            // Use the same result table for both the ResultTableExec and the recursive term
+            recursive_term =
+                assign_work_table(recursive_term, Arc::clone(&result_table))?;
+            Some(result_table)
+        } else {
+            None
+        };
         let cache = Self::compute_properties(static_term.schema());
         Ok(RecursiveQueryExec {
             name,
@@ -93,6 +105,7 @@ impl RecursiveQueryExec {
             work_table,
             metrics: ExecutionPlanMetricsSet::new(),
             cache,
+            result_table,
         })
     }
@@ -193,6 +206,7 @@ impl ExecutionPlan for RecursiveQueryExec {
         Ok(Box::pin(RecursiveQueryStream::new(
             context,
             Arc::clone(&self.work_table),
+            self.result_table.as_ref().map(Arc::clone),
             Arc::clone(&self.recursive_term),
             static_stream,
             baseline_metrics,
@@ -237,16 +251,16 @@ impl DisplayAs for RecursiveQueryExec {
 ///
 /// while batch := static_stream.next():
 ///    buffer.push(batch)
-///    yield buffer
+///    yield batch
 ///
 /// while buffer.len() > 0:
 ///    sender, receiver = Channel()
-///    register_continuation(handle_name, receiver)
+///    register_work_table(handle_name, receiver)
 ///    sender.send(buffer.drain())
 ///    recursive_stream = recursive_term.execute()
 ///    while batch := recursive_stream.next():
 ///        buffer.append(batch)
-///        yield buffer
+///        yield batch
 ///
 struct RecursiveQueryStream {
     /// The context to be used for managing handlers & executing new tasks
@@ -268,6 +282,8 @@ struct RecursiveQueryStream {
     buffer: Vec<RecordBatch>,
     /// Tracks the memory used by the buffer
     reservation: MemoryReservation,
+    /// The result table used for deduplication, populated only when deduplication is enabled
+    results_table: Option<Arc<ResultTable>>,
     // /// Metrics.
     _baseline_metrics: BaselineMetrics,
 }
@@ -277,6 +293,7 @@ impl RecursiveQueryStream {
     fn new(
         task_context: Arc<TaskContext>,
         work_table: Arc<WorkTable>,
+        results_table: Option<Arc<ResultTable>>,
         recursive_term: Arc<dyn ExecutionPlan>,
         static_stream: SendableRecordBatchStream,
         baseline_metrics: BaselineMetrics,
@@ -294,6 +311,7 @@ impl RecursiveQueryStream {
             buffer: vec![],
             reservation,
             _baseline_metrics: baseline_metrics,
+            results_table,
         }
     }
@@ -327,11 +345,21 @@ impl RecursiveQueryStream {
             return Poll::Ready(None);
         }

+        // Update the result table with the current buffer
+        if let Some(results_table) = &self.results_table {
+            // Note it's fine to take the memory reservation here: we are not
+            // cloning the underlying data, and the result table outlives the
+            // work table.
+            let buffer = self.buffer.clone();
+            let reservation = self.reservation.take();
+            results_table.append(buffer, reservation);
+        }
+
         // Update the work table with the current buffer
-        let reserved_batches = ReservedBatches::new(
-            std::mem::take(&mut self.buffer),
-            self.reservation.take(),
-        );
+        let reserved_batches =
+            ReservedBatches::new(take(&mut self.buffer), self.reservation.take());
         self.work_table.update(reserved_batches);

         // We always execute (and re-execute iteratively) the first partition.
@@ -345,9 +373,9 @@ impl RecursiveQueryStream {
     }
 }

-fn assign_work_table(
+fn assign_work_table<T: Any + Send + Sync>(
     plan: Arc<dyn ExecutionPlan>,
-    work_table: Arc<WorkTable>,
+    work_table: Arc<T>,
 ) -> Result<Arc<dyn ExecutionPlan>> {
     let mut work_table_refs = 0;
     plan.transform_down(|plan| {
@@ -380,7 +408,7 @@ fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
     plan.transform_up(|plan| {
         // WorkTableExec's states have already been updated correctly.
-        if plan.as_any().is::<WorkTableExec>() {
+        if plan.as_any().is::<WorkTableExec>() || plan.as_any().is::<ResultTableExec>() {
             Ok(Transformed::no(plan))
         } else {
             let new_plan = Arc::clone(&plan).reset_state()?;
diff --git a/datafusion/physical-plan/src/result_table.rs b/datafusion/physical-plan/src/result_table.rs
new file mode 100644
index 000000000000..7ed4f3b4e0ce
--- /dev/null
+++ b/datafusion/physical-plan/src/result_table.rs
@@ -0,0 +1,229 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the result table query plan
+
+use std::any::Any;
+use std::fmt;
+use std::sync::{Arc, Mutex};
+
+use crate::coop::cooperative;
+use crate::execution_plan::{Boundedness, EmissionType, SchedulingType};
+use crate::memory::MemoryStream;
+use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
+use crate::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
+    SendableRecordBatchStream, Statistics,
+};
+
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use datafusion_common::{internal_err, Result};
+use datafusion_execution::memory_pool::MemoryReservation;
+use datafusion_execution::TaskContext;
+use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
+
+/// The name is from PostgreSQL's terminology.
+/// See <https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE>.
+/// This table serves as a mirror or buffer, allowing data to be appended and read back.
+#[derive(Debug)]
+pub struct ResultTable {
+    inner: Mutex<ResultTableInner>,
+}
+
+#[derive(Debug)]
+struct ResultTableInner {
+    batches: Vec<RecordBatch>,
+    reservation: Option<MemoryReservation>,
+}
+
+impl ResultTable {
+    /// Create a new result table.
+    pub(super) fn new() -> Self {
+        Self {
+            inner: Mutex::new(ResultTableInner {
+                batches: Vec::new(),
+                reservation: None,
+            }),
+        }
+    }
+
+    /// Return the content of the result table.
+    /// This will be called by the [`ResultTableExec`] when it is executed.
+    fn get(&self) -> Vec<RecordBatch> {
+        self.inner.lock().unwrap().batches.clone()
+    }
+
+    /// Add extra data to the table.
+    pub(super) fn append(
+        &self,
+        batches: Vec<RecordBatch>,
+        reservation: MemoryReservation,
+    ) {
+        let mut guard = self.inner.lock().unwrap();
+        if let Some(r) = &mut guard.reservation {
+            r.grow(reservation.size());
+        } else {
+            guard.reservation = Some(reservation);
+        }
+        guard.batches.extend(batches);
+    }
+}
+
+/// A temporary "result table" operation where the input data will be
+/// taken from the named handle during the execution and will be re-published
+/// as is (kind of like a mirror).
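+///
+/// Unlike the work table, which is drained and replaced on every iteration of
+/// the recursive query, the result table only grows: it accumulates every
+/// batch emitted across all iterations.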
+///
+/// It is used in deduplicating recursive queries to store previously emitted results and avoid
+/// considering them again in future iterations.
+///
+/// This is key to avoiding infinite loops in transitive closures on arbitrary graphs.
+#[derive(Clone, Debug)]
+pub struct ResultTableExec {
+    /// Name of the relation handler
+    name: String,
+    /// The schema of the stream
+    schema: SchemaRef,
+    /// The result table
+    result_table: Arc<ResultTable>,
+    /// Execution metrics
+    metrics: ExecutionPlanMetricsSet,
+    /// Cache holding plan properties like equivalences, output partitioning etc.
+    cache: PlanProperties,
+}
+
+impl ResultTableExec {
+    /// Create a new execution plan for a result table exec.
+    pub fn new(name: String, schema: SchemaRef) -> Self {
+        let cache = Self::compute_properties(Arc::clone(&schema));
+        Self {
+            name,
+            schema,
+            metrics: ExecutionPlanMetricsSet::new(),
+            result_table: Arc::new(ResultTable::new()),
+            cache,
+        }
+    }
+
+    /// Arc clone of ref to schema
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    /// This function creates the cache object that stores the plan properties
+    /// such as schema, equivalence properties, ordering, partitioning, etc.
+    fn compute_properties(schema: SchemaRef) -> PlanProperties {
+        PlanProperties::new(
+            EquivalenceProperties::new(schema),
+            Partitioning::UnknownPartitioning(1),
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        )
+        .with_scheduling_type(SchedulingType::Cooperative)
+    }
+}
+
+impl DisplayAs for ResultTableExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(f, "ResultTableExec: name={}", self.name)
+            }
+            DisplayFormatType::TreeRender => {
+                write!(f, "name={}", self.name)
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for ResultTableExec {
+    fn name(&self) -> &'static str {
+        "ResultTableExec"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        Vec::new()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::clone(&self) as _)
+    }
+
+    /// Stream the batches that were written to the result table.
+    fn execute(
+        &self,
+        partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        // ResultTable streams must be the plan base.
+        if partition != 0 {
+            return internal_err!(
+                "ResultTableExec got an invalid partition {partition} (expected 0)"
+            );
+        }
+        Ok(Box::pin(cooperative(MemoryStream::try_new(
+            self.result_table.get(),
+            Arc::clone(&self.schema),
+            None,
+        )?)))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
+    fn statistics(&self) -> Result<Statistics> {
+        Ok(Statistics::new_unknown(&self.schema()))
+    }
+
+    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
+        Ok(Statistics::new_unknown(&self.schema()))
+    }
+
+    /// Injects run-time state into this `ResultTableExec`.
+    ///
+    /// The only state this node currently understands is an [`Arc<ResultTable>`].
+    /// If `state` can be down-cast to that type, a new `ResultTableExec` backed
+    /// by the provided result table is returned. Otherwise `None` is returned
+    /// so that callers can attempt to propagate the state further down the
+    /// execution plan tree.
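+    ///
+    /// `RecursiveQueryExec` relies on this hook (via its work-table assignment
+    /// pass) to share its `ResultTable` with the `ResultTableExec` nodes
+    /// embedded in the recursive term.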
+    fn with_new_state(
+        &self,
+        state: Arc<dyn Any + Send + Sync>,
+    ) -> Option<Arc<dyn ExecutionPlan>> {
+        // Down-cast to the expected state type; propagate `None` on failure
+        let result_table = state.downcast::<ResultTable>().ok()?;
+
+        Some(Arc::new(Self {
+            name: self.name.clone(),
+            schema: Arc::clone(&self.schema),
+            metrics: ExecutionPlanMetricsSet::new(),
+            result_table,
+            cache: self.cache.clone(),
+        }))
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt
index a581bcb539a9..e5522f36b3f1 100644
--- a/datafusion/sqllogictest/test_files/cte.slt
+++ b/datafusion/sqllogictest/test_files/cte.slt
@@ -58,18 +58,6 @@ WITH RECURSIVE nodes AS (
 statement ok
 set datafusion.execution.enable_recursive_ctes = true;

-
-# DISTINCT UNION is not supported
-query error DataFusion error: This feature is not implemented: Recursive queries with a distinct 'UNION' \(in which the previous iteration's results will be de\-duplicated\) is not supported
-WITH RECURSIVE nodes AS (
-    SELECT 1 as id
-    UNION
-    SELECT id + 1 as id
-    FROM nodes
-    WHERE id < 3
-) SELECT * FROM nodes
-
-
 # trivial recursive CTE works
 query I rowsort
 WITH RECURSIVE nodes AS (
@@ -1049,6 +1037,75 @@ physical_plan
 05)----SortExec: TopK(fetch=1), expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false]
 06)------WorkTableExec: name=r

+# setup
+statement ok
+CREATE EXTERNAL TABLE closure STORED as CSV LOCATION '../core/tests/data/recursive_cte/closure.csv' OPTIONS ('format.has_header' 'true');
+
+# transitive closure with loop
+query II
+WITH RECURSIVE trans AS (
+    SELECT * FROM closure
+    UNION
+    SELECT l.start, r.end
+    FROM trans as l, closure AS r
+    WHERE l.end = r.start
+) SELECT * FROM trans ORDER BY start, end
+----
+1 1
+1 2
+1 3
+1 4
+2 1
+2 2
+2 3
+2 4
+4 1
+4 2
+4 3
+4 4
+
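+# The plan below shows the distinct rewrite: each term is deduplicated by an
+# AggregateExec, and a LeftAnti HashJoinExec against the ResultTableExec
+# filters out rows that were already emitted in earlier iterations.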
+query TT
+EXPLAIN WITH RECURSIVE trans AS (
+    SELECT * FROM closure
+    UNION
+    SELECT l.start, r.end
+    FROM trans as l, closure AS r
+    WHERE l.end = r.start
+) SELECT * FROM trans
+----
+logical_plan
+01)SubqueryAlias: trans
+02)--RecursiveQuery: is_distinct=true
+03)----Projection: closure.start, closure.end
+04)------TableScan: closure
+05)----Projection: l.start, r.end
+06)------Inner Join: l.end = r.start
+07)--------SubqueryAlias: l
+08)----------TableScan: trans
+09)--------SubqueryAlias: r
+10)----------TableScan: closure
+physical_plan
+01)RecursiveQueryExec: name=trans, is_distinct=true
+02)--AggregateExec: mode=Final, gby=[start@0 as start, end@1 as end], aggr=[]
+03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true
+04)--CoalescePartitionsExec
+05)----CoalesceBatchesExec: target_batch_size=8182
+06)------HashJoinExec: mode=CollectLeft, join_type=LeftAnti, on=[(start@0, start@0), (end@1, end@1)], NullsEqual: true
+07)--------AggregateExec: mode=Final, gby=[start@0 as start, end@1 as end], aggr=[]
+08)----------CoalescePartitionsExec
+09)------------CoalesceBatchesExec: target_batch_size=8182
+10)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(end@1, start@0)], projection=[start@0, end@3]
+11)----------------CoalesceBatchesExec: target_batch_size=8182
+12)------------------RepartitionExec: partitioning=Hash([end@1], 4), input_partitions=4
+13)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+14)----------------------WorkTableExec: name=trans
+15)----------------CoalesceBatchesExec: target_batch_size=8182
+16)------------------RepartitionExec: partitioning=Hash([start@0], 4), input_partitions=4
+17)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+18)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/closure.csv]]}, projection=[start, end], file_type=csv, has_header=true
+19)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+20)----------ResultTableExec: name=union
+
 statement count 0
 set datafusion.execution.enable_recursive_ctes = false;