Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 65 additions & 9 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,9 @@ use datafusion_catalog::ScanArgs;
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::format::ExplainAnalyzeLevel;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
use datafusion_common::TableReference;
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
ScalarValue,
NullEquality, ScalarValue, TableReference,
};
use datafusion_datasource::file_groups::FileGroup;
use datafusion_datasource::memory::MemorySourceConfig;
Expand All @@ -85,7 +84,7 @@ use datafusion_expr::{
WindowFrame, WindowFrameBound, WriteOp,
};
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr::{
create_physical_sort_exprs, LexOrdering, PhysicalSortExpr,
};
Expand All @@ -100,6 +99,7 @@ use datafusion_physical_plan::unnest::ListUnnest;

use async_trait::async_trait;
use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper};
use datafusion_physical_plan::result_table::ResultTableExec;
use futures::{StreamExt, TryStreamExt};
use itertools::{multiunzip, Itertools};
use log::debug;
Expand Down Expand Up @@ -1432,12 +1432,68 @@ impl DefaultPhysicalPlanner {
name, is_distinct, ..
}) => {
let [static_term, recursive_term] = children.two()?;
Arc::new(RecursiveQueryExec::try_new(
name.clone(),
static_term,
recursive_term,
*is_distinct,
)?)
let inner_schema = static_term.schema();
let group_by = PhysicalGroupBy::new_single(
inner_schema
.fields()
.iter()
.enumerate()
.map(|(i, f)| {
(Arc::new(Column::new(f.name(), i)) as _, f.name().clone())
})
.collect(),
);
if *is_distinct {
// We deduplicate each input to avoid duplicated values
// And we remove from the recursive term the only emitted values i.e. the results table.
Arc::new(RecursiveQueryExec::try_new(
name.clone(),
Arc::new(AggregateExec::try_new(
AggregateMode::Final,
group_by.clone(),
Vec::new(),
Vec::new(),
static_term,
Arc::clone(&inner_schema),
)?),
Arc::new(HashJoinExec::try_new(
Arc::new(AggregateExec::try_new(
AggregateMode::Final,
group_by,
Vec::new(),
Vec::new(),
recursive_term,
Arc::clone(&inner_schema),
)?),
Arc::new(ResultTableExec::new(
"union".into(),
Arc::clone(&inner_schema),
)),
inner_schema
.fields()
.iter()
.enumerate()
.map(|(i, f)| {
let col = Arc::new(Column::new(f.name(), i)) as _;
(Arc::clone(&col), col)
})
.collect(),
None,
&JoinType::LeftAnti,
None,
PartitionMode::CollectLeft,
NullEquality::NullEqualsNull,
)?),
true,
)?)
} else {
Arc::new(RecursiveQueryExec::try_new(
name.clone(),
static_term,
recursive_term,
false,
)?)
}
}

// N Children
Expand Down
5 changes: 5 additions & 0 deletions datafusion/core/tests/data/recursive_cte/closure.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
start,end
1,2
2,3
2,4
4,1
8 changes: 1 addition & 7 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef};
use datafusion_common::display::ToStringifiedPlan;
use datafusion_common::file_options::file_type::FileType;
use datafusion_common::{
exec_err, get_target_functional_dependencies, internal_datafusion_err, not_impl_err,
exec_err, get_target_functional_dependencies, internal_datafusion_err,
plan_datafusion_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef,
NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions,
};
Expand Down Expand Up @@ -178,12 +178,6 @@ impl LogicalPlanBuilder {
recursive_term: LogicalPlan,
is_distinct: bool,
) -> Result<Self> {
// TODO: we need to do a bunch of validation here. Maybe more.
if is_distinct {
return not_impl_err!(
"Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported"
);
}
// Ensure that the static term and the recursive term have the same number of fields
let static_fields_len = self.plan.schema().fields().len();
let recursive_fields_len = recursive_term.schema().fields().len();
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-plan/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ pub mod placeholder_row;
pub mod projection;
pub mod recursive_query;
pub mod repartition;
pub mod result_table;
pub mod sorts;
pub mod spill;
pub mod stream;
Expand Down
54 changes: 41 additions & 13 deletions datafusion/physical-plan/src/recursive_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
//! Defines the recursive query plan

use std::any::Any;
use std::mem::take;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::work_table::{ReservedBatches, WorkTable, WorkTableExec};
use crate::execution_plan::{Boundedness, EmissionType};
use crate::result_table::ResultTable;
use crate::{
metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet},
PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
Expand Down Expand Up @@ -64,8 +66,9 @@ pub struct RecursiveQueryExec {
static_term: Arc<dyn ExecutionPlan>,
/// The dynamic part (recursive term)
recursive_term: Arc<dyn ExecutionPlan>,
/// Distinction
is_distinct: bool,
/// If is_distinct is true, holds the result table that saves all previous results
result_table: Option<Arc<ResultTable>>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
/// Cache holding plan properties like equivalences, output partitioning etc.
Expand All @@ -77,13 +80,22 @@ impl RecursiveQueryExec {
pub fn try_new(
name: String,
static_term: Arc<dyn ExecutionPlan>,
recursive_term: Arc<dyn ExecutionPlan>,
mut recursive_term: Arc<dyn ExecutionPlan>,
is_distinct: bool,
) -> Result<Self> {
// Each recursive query needs its own work table
let work_table = Arc::new(WorkTable::new());
// Use the same work table for both the WorkTableExec and the recursive term
let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?;
let result_table = if is_distinct {
let result_table = Arc::new(ResultTable::new());
// Use the same result table for both the ResultTableExec and the result term
recursive_term =
assign_work_table(recursive_term, Arc::clone(&result_table))?;
Some(result_table)
} else {
None
};
let cache = Self::compute_properties(static_term.schema());
Ok(RecursiveQueryExec {
name,
Expand All @@ -93,6 +105,7 @@ impl RecursiveQueryExec {
work_table,
metrics: ExecutionPlanMetricsSet::new(),
cache,
result_table,
})
}

Expand Down Expand Up @@ -193,6 +206,7 @@ impl ExecutionPlan for RecursiveQueryExec {
Ok(Box::pin(RecursiveQueryStream::new(
context,
Arc::clone(&self.work_table),
self.result_table.as_ref().map(Arc::clone),
Arc::clone(&self.recursive_term),
static_stream,
baseline_metrics,
Expand Down Expand Up @@ -237,16 +251,16 @@ impl DisplayAs for RecursiveQueryExec {
///
/// while batch := static_stream.next():
/// buffer.push(batch)
/// yield buffer
/// yield batch
///
/// while buffer.len() > 0:
/// sender, receiver = Channel()
/// register_continuation(handle_name, receiver)
/// register_work_table(handle_name, receiver)
/// sender.send(buffer.drain())
/// recursive_stream = recursive_term.execute()
/// while batch := recursive_stream.next():
/// buffer.append(batch)
/// yield buffer
/// yield batch
///
struct RecursiveQueryStream {
/// The context to be used for managing handlers & executing new tasks
Expand All @@ -268,6 +282,8 @@ struct RecursiveQueryStream {
buffer: Vec<RecordBatch>,
/// Tracks the memory used by the buffer
reservation: MemoryReservation,
/// The result table state, representing the table used for deduplication in case it is enabled
results_table: Option<Arc<ResultTable>>,
// /// Metrics.
_baseline_metrics: BaselineMetrics,
}
Expand All @@ -277,6 +293,7 @@ impl RecursiveQueryStream {
fn new(
task_context: Arc<TaskContext>,
work_table: Arc<WorkTable>,
results_table: Option<Arc<ResultTable>>,
recursive_term: Arc<dyn ExecutionPlan>,
static_stream: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
Expand All @@ -294,6 +311,7 @@ impl RecursiveQueryStream {
buffer: vec![],
reservation,
_baseline_metrics: baseline_metrics,
results_table,
}
}

Expand Down Expand Up @@ -327,11 +345,21 @@ impl RecursiveQueryStream {
return Poll::Ready(None);
}

// Update the union table with the current buffer
if self.results_table.is_some() {
// Note it's fine to take the memory reservation here, we are not cloning the underlying data,
// and the result table is going to outlive the work table.
let buffer = self.buffer.clone();
let reservation = self.reservation.take();
self.results_table
.as_mut()
.unwrap()
.append(buffer, reservation);
}

// Update the work table with the current buffer
let reserved_batches = ReservedBatches::new(
std::mem::take(&mut self.buffer),
self.reservation.take(),
);
let reserved_batches =
ReservedBatches::new(take(&mut self.buffer), self.reservation.take());
self.work_table.update(reserved_batches);

// We always execute (and re-execute iteratively) the first partition.
Expand All @@ -345,9 +373,9 @@ impl RecursiveQueryStream {
}
}

fn assign_work_table(
fn assign_work_table<T: Any + Send + Sync>(
plan: Arc<dyn ExecutionPlan>,
work_table: Arc<WorkTable>,
work_table: Arc<T>,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut work_table_refs = 0;
plan.transform_down(|plan| {
Expand Down Expand Up @@ -380,7 +408,7 @@ fn assign_work_table(
fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
plan.transform_up(|plan| {
// WorkTableExec's states have already been updated correctly.
if plan.as_any().is::<WorkTableExec>() {
if plan.as_any().is::<WorkTableExec>() || plan.as_any().is::<ResultTable>() {
Ok(Transformed::no(plan))
} else {
let new_plan = Arc::clone(&plan).reset_state()?;
Expand Down
Loading