From fe61624c7e1f1c34848a059962dae02a62e939d5 Mon Sep 17 00:00:00 2001 From: Ahmed Mezghani Date: Mon, 23 Jun 2025 10:03:31 +0200 Subject: [PATCH 1/2] fix: extend recursive protection to prevent stack overflows in additional functions --- Cargo.lock | 2 ++ datafusion/core/Cargo.toml | 1 + datafusion/expr/src/expr.rs | 1 + datafusion/expr/src/logical_plan/invariants.rs | 2 ++ datafusion/expr/src/logical_plan/plan.rs | 1 + datafusion/expr/src/utils.rs | 3 +++ datafusion/optimizer/src/decorrelate.rs | 1 + datafusion/optimizer/src/decorrelate_predicate_subquery.rs | 1 + datafusion/optimizer/src/eliminate_cross_join.rs | 4 ++++ datafusion/optimizer/src/eliminate_group_by_constant.rs | 1 + datafusion/optimizer/src/eliminate_limit.rs | 1 + datafusion/optimizer/src/eliminate_outer_join.rs | 1 + datafusion/optimizer/src/push_down_filter.rs | 2 ++ datafusion/optimizer/src/push_down_limit.rs | 1 + datafusion/optimizer/src/scalar_subquery_to_join.rs | 1 + datafusion/optimizer/src/simplify_expressions/utils.rs | 4 ++++ datafusion/physical-expr/Cargo.toml | 4 ++++ datafusion/physical-expr/src/expressions/binary.rs | 3 +++ datafusion/physical-expr/src/intervals/utils.rs | 1 + datafusion/physical-expr/src/planner.rs | 2 ++ datafusion/physical-optimizer/src/enforce_distribution.rs | 2 ++ datafusion/physical-optimizer/src/filter_pushdown.rs | 1 + datafusion/physical-optimizer/src/limit_pushdown.rs | 1 + datafusion/substrait/Cargo.toml | 2 ++ .../src/logical_plan/consumer/expr/scalar_function.rs | 1 + 25 files changed, 44 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 194483b7ab3a..06e5c2adbb23 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2443,6 +2443,7 @@ dependencies = [ "paste", "petgraph 0.8.2", "rand 0.9.1", + "recursive", "rstest", ] @@ -2656,6 +2657,7 @@ dependencies = [ "object_store", "pbjson-types", "prost", + "recursive", "serde_json", "substrait", "tokio", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 9747f4424060..50bf1a8ff85d 
100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -76,6 +76,7 @@ recursive_protection = [ "datafusion-common/recursive_protection", "datafusion-expr/recursive_protection", "datafusion-optimizer/recursive_protection", + "datafusion-physical-expr/recursive_protection", "datafusion-physical-optimizer/recursive_protection", "datafusion-sql/recursive_protection", "sqlparser/recursive-protection", diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 97f83305dcbe..b514661a7c22 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2092,6 +2092,7 @@ impl Normalizeable for Expr { } impl NormalizeEq for Expr { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn normalize_eq(&self, other: &Self) -> bool { match (self, other) { ( diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs index d8d6739b0e8f..78731e9a8b30 100644 --- a/datafusion/expr/src/logical_plan/invariants.rs +++ b/datafusion/expr/src/logical_plan/invariants.rs @@ -71,6 +71,7 @@ pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> { /// /// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode) /// for more details of user-provided extension node invariants. 
+#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> { plan.apply_with_subqueries(|plan: &LogicalPlan| { if let LogicalPlan::Extension(Extension { node }) = plan { @@ -372,6 +373,7 @@ fn check_aggregation_in_scalar_subquery( Ok(()) } +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan { match inner_plan { LogicalPlan::Projection(projection) => { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 876c14f1000f..e8b7e39f444c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2278,6 +2278,7 @@ impl Filter { Self::try_new_internal(predicate, input) } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn is_allowed_filter_type(data_type: &DataType) -> bool { match data_type { // Interpret NULL as a missing boolean value. 
diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index b7851e530099..7fab05554425 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -934,6 +934,7 @@ pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { split_conjunction_impl(expr, vec![]) } +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> { match expr { Expr::BinaryExpr(BinaryExpr { @@ -1051,6 +1052,7 @@ pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { split_binary_owned_impl(expr, op, vec![]) } +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn split_binary_owned_impl( expr: Expr, operator: Operator, @@ -1078,6 +1080,7 @@ pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { split_binary_impl(expr, op, vec![]) } +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn split_binary_impl<'a>( expr: &'a Expr, operator: Operator, diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 63236787743a..07677d29ed1c 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -445,6 +445,7 @@ fn can_pullup_over_aggregation(expr: &Expr) -> bool { } } +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn collect_local_correlated_cols( plan: &LogicalPlan, all_cols_map: &HashMap>, diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index a72657bf689d..ae4185b31ca3 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -55,6 +55,7 @@ impl OptimizerRule for DecorrelatePredicateSubquery { true } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn rewrite( &self, plan: LogicalPlan, diff --git 
a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index ae1d7df46d52..357d3aa747b5 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -228,6 +228,7 @@ fn rewrite_children( /// Assumes can_flatten_join_inputs has returned true and thus the plan can be /// flattened. Adds all leaf inputs to `all_inputs` and join_keys to /// possible_join_keys +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn flatten_join_inputs( plan: LogicalPlan, possible_join_keys: &mut JoinKeySet, @@ -264,6 +265,7 @@ fn flatten_join_inputs( /// `flatten_join_inputs` /// /// Must stay in sync with `flatten_join_inputs` +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { // can only flatten inner / cross joins match plan { @@ -368,6 +370,7 @@ fn find_inner_join( } /// Extract join keys from a WHERE clause +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { match op { @@ -399,6 +402,7 @@ fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) { /// # Returns /// * `Some()` when there are few remaining predicates in filter_expr /// * `None` otherwise +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option { match expr { Expr::BinaryExpr(BinaryExpr { diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs index 9c47ce024f91..3c8fe1c5d0e7 100644 --- a/datafusion/optimizer/src/eliminate_group_by_constant.rs +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -95,6 +95,7 @@ impl OptimizerRule for EliminateGroupByConstant { /// /// Intended to be used only 
within this rule, helper function, which heavily /// relies on `SimplifyExpressions` result. +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn is_constant_expression(expr: &Expr) -> bool { match expr { Expr::Alias(e) => is_constant_expression(&e.expr), diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 2007e0c82045..dc5835ffe6bb 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -53,6 +53,7 @@ impl OptimizerRule for EliminateLimit { true } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn rewrite( &self, plan: LogicalPlan, diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 45877642f276..090280714144 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -172,6 +172,7 @@ pub fn eliminate_outer( /// For IS NOT NULL/NOT expr, always returns false for NULL input. /// extracts columns from these exprs. /// For all other exprs, fall through +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn extract_non_nullable_columns( expr: &Expr, non_nullable_cols: &mut Vec, diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7c4a02678899..8f14ed939f51 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -366,6 +366,7 @@ fn extract_or_clauses_for_join<'a>( /// Otherwise, return None. /// /// For other clause, apply the rule above to extract clause. 
+#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { let mut predicate = None; @@ -764,6 +765,7 @@ impl OptimizerRule for PushDownFilter { true } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn rewrite( &self, plan: LogicalPlan, diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index ec042dd350ca..f64b18562863 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -48,6 +48,7 @@ impl OptimizerRule for PushDownLimit { true } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn rewrite( &self, plan: LogicalPlan, diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 2f9a2f6bb9ed..58fae9a34e26 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -74,6 +74,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { true } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn rewrite( &self, plan: LogicalPlan, diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 4df0e125eb18..74518e9d2520 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -67,6 +67,7 @@ pub static POWS_OF_TEN: [i128; 38] = [ /// returns true if `needle` is found in a chain of search_op /// expressions. 
Such as: (A AND B) AND C +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { @@ -86,6 +87,7 @@ pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { /// expressions. Such as: A ^ (A ^ (B ^ A)) pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr { /// Deletes recursively 'needles' in a chain of xor expressions + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn recursive_delete_xor_in_expr( expr: &Expr, needle: &Expr, @@ -266,6 +268,7 @@ pub fn as_bool_lit(expr: &Expr) -> Result> { /// For Between, not (A between B and C) ===> (A not between B and C) /// not (A not between B and C) ===> (A between B and C) /// For others, use Not clause +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] pub fn negate_clause(expr: Expr) -> Expr { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => { @@ -335,6 +338,7 @@ pub fn negate_clause(expr: Expr) -> Expr { /// For Negative: /// ~(~A) ===> A /// For others, use Negative clause +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] pub fn distribute_negation(expr: Expr) -> Expr { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => { diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 881969ef32ad..14694de58d3c 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -37,6 +37,9 @@ workspace = true [lib] name = "datafusion_physical_expr" +[features] +recursive_protection = ["dep:recursive"] + [dependencies] ahash = { workspace = true } arrow = { workspace = true } @@ -52,6 +55,7 @@ itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" petgraph = "0.8.2" +recursive = { workspace = true, 
optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 798e68a459ce..b82bcb6a586d 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -343,6 +343,7 @@ impl PhysicalExpr for BinaryExpr { self } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn data_type(&self, input_schema: &Schema) -> Result { BinaryTypeCoercer::new( &self.left.data_type(input_schema)?, @@ -356,6 +357,7 @@ impl PhysicalExpr for BinaryExpr { Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?) } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn evaluate(&self, batch: &RecordBatch) -> Result { use arrow::compute::kernels::numeric::*; @@ -648,6 +650,7 @@ impl PhysicalExpr for BinaryExpr { } } + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn write_child( f: &mut std::fmt::Formatter, diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 910631ef4a43..269df59ce17f 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -35,6 +35,7 @@ use datafusion_expr::Operator; /// We do not support every type of [`Operator`]s either. Over time, this check /// will relax as more types of `PhysicalExpr`s and `Operator`s are supported. /// Currently, [`CastExpr`], [`NegativeExpr`], [`BinaryExpr`], [`Column`] and [`Literal`] are supported. 
+#[cfg_attr(feature = "recursive_protection", recursive::recursive)] pub fn check_support(expr: &Arc, schema: &SchemaRef) -> bool { let expr_any = expr.as_any(); if let Some(binary_expr) = expr_any.downcast_ref::() { diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index fbc19b1202ee..1bb0f783dd48 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -105,6 +105,7 @@ use datafusion_expr::{ /// * `e` - The logical expression /// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references /// to qualified or unqualified fields by name. +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] pub fn create_physical_expr( e: &Expr, input_dfschema: &DFSchema, @@ -385,6 +386,7 @@ pub fn create_physical_expr( } /// Create vector of Physical Expression from a vector of logical expression +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] pub fn create_physical_exprs<'a, I>( exprs: I, input_dfschema: &DFSchema, diff --git a/datafusion/physical-optimizer/src/enforce_distribution.rs b/datafusion/physical-optimizer/src/enforce_distribution.rs index 39eb557ea601..e61758d47237 100644 --- a/datafusion/physical-optimizer/src/enforce_distribution.rs +++ b/datafusion/physical-optimizer/src/enforce_distribution.rs @@ -691,6 +691,7 @@ pub fn reorder_join_keys_to_inputs( } /// Reorder the current join keys ordering based on either left partition or right partition +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn reorder_current_join_keys( join_keys: JoinKeyPairs, left_partition: Option<&Partitioning>, @@ -1011,6 +1012,7 @@ fn remove_dist_changing_operators( /// " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", /// " DataSourceExec: file_groups={2 groups: \[\[x], \[y]]}, projection=\[a, b, c, d, e], output_ordering=\[a@0 ASC], file_type=parquet", /// ``` +#[cfg_attr(feature = 
"recursive_protection", recursive::recursive)] pub fn replace_order_preserving_variants( mut context: DistributionContext, ) -> Result { diff --git a/datafusion/physical-optimizer/src/filter_pushdown.rs b/datafusion/physical-optimizer/src/filter_pushdown.rs index 885280576b4b..82bf60126a3b 100644 --- a/datafusion/physical-optimizer/src/filter_pushdown.rs +++ b/datafusion/physical-optimizer/src/filter_pushdown.rs @@ -428,6 +428,7 @@ enum ParentPredicateStates { Supported, } +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn push_down_filters( node: Arc, parent_predicates: Vec>, diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs index 7469c3af9344..21f02a26eff2 100644 --- a/datafusion/physical-optimizer/src/limit_pushdown.rs +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -262,6 +262,7 @@ pub fn pushdown_limit_helper( } /// Pushes down the limit through the plan. +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] pub(crate) fn pushdown_limits( pushdown_plan: Arc, global_state: GlobalRequirements, diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 8bc3eccc684d..ce701303ec0c 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -39,6 +39,7 @@ itertools = { workspace = true } object_store = { workspace = true } pbjson-types = { workspace = true } prost = { workspace = true } +recursive = { workspace = true, optional = true } substrait = { version = "0.57", features = ["serde"] } url = { workspace = true } tokio = { workspace = true, features = ["fs"] } @@ -54,6 +55,7 @@ insta = { workspace = true } default = ["physical"] physical = ["datafusion/parquet"] protoc = ["substrait/protoc"] +recursive_protection = ["dep:recursive", "datafusion/recursive_protection"] [package.metadata.docs.rs] # Use default features ("physical") for docs, plus "protoc". 
"protoc" is needed diff --git a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs index 7797c935211f..2e7752d5361b 100644 --- a/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs +++ b/datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs @@ -128,6 +128,7 @@ fn arg_list_to_binary_op_tree(op: Operator, mut args: Vec) -> Result /// /// `take_len` represents the number of elements to take from `args` before returning. /// We use `take_len` to avoid recursively building a `Take>>` type. +#[cfg_attr(feature = "recursive_protection", recursive::recursive)] fn arg_list_to_binary_op_tree_inner( op: Operator, args: &mut Drain, From 1a9584b64349acfea70b6e82dbaabbdeaa625758 Mon Sep 17 00:00:00 2001 From: Ahmed Mezghani Date: Mon, 30 Jun 2025 16:39:21 +0200 Subject: [PATCH 2/2] Add reproducing test cases --- Cargo.lock | 1 + datafusion/substrait/Cargo.toml | 1 + .../tests/cases/deeply_nested_plan.rs | 117 ++++++ datafusion/substrait/tests/cases/mod.rs | 1 + .../test_plans/deeply_nested_tpl.json | 382 ++++++++++++++++++ 5 files changed, 502 insertions(+) create mode 100644 datafusion/substrait/tests/cases/deeply_nested_plan.rs create mode 100644 datafusion/substrait/tests/testdata/test_plans/deeply_nested_tpl.json diff --git a/Cargo.lock b/Cargo.lock index 06e5c2adbb23..823e1ca27579 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2652,6 +2652,7 @@ dependencies = [ "chrono", "datafusion", "datafusion-functions-aggregate", + "futures", "insta", "itertools 0.14.0", "object_store", diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index ce701303ec0c..0cddec65f719 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -35,6 +35,7 @@ async-recursion = "1.0" async-trait = { workspace = true } chrono = { workspace = true } datafusion = { workspace = true } +futures = { workspace = true } itertools = { 
workspace = true } object_store = { workspace = true } pbjson-types = { workspace = true } diff --git a/datafusion/substrait/tests/cases/deeply_nested_plan.rs b/datafusion/substrait/tests/cases/deeply_nested_plan.rs new file mode 100644 index 000000000000..d96b3b89b662 --- /dev/null +++ b/datafusion/substrait/tests/cases/deeply_nested_plan.rs @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
Tests for deeply nested plans causing stack overflows + +#[cfg(test)] +mod tests { + use crate::utils::test::add_plan_schemas_to_ctx; + use datafusion::common::Result; + use datafusion::logical_expr::LogicalPlan; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use futures::StreamExt; + use serde_json::{json, Value}; + use substrait::proto::Plan; + + // The depth of the nested plan to generate, i.e. the total number of arguments of the extended scalar function (1 from the template + DEPTH - 1 appended) + const DEPTH: usize = 1000; + + #[tokio::test] + async fn test_stack_overflow_planning() -> Result<()> { + let (ctx, plan) = setup().await?; + ctx.state().create_physical_plan(&plan).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_stack_overflow_execution() -> Result<()> { + let (ctx, plan) = setup().await?; + let plan = ctx.state().create_physical_plan(&plan).await?; + let mut records = + datafusion::physical_plan::execute_stream(plan, ctx.task_ctx().clone())?; + while let Some(record) = records.next().await { + record?; + } + + Ok(()) + } + + /// Returns a session context (prepared via `add_plan_schemas_to_ctx`) and the logical plan that `from_substrait_plan` produces for the generated `DEPTH`-argument substrait plan. + async fn setup() -> Result<(SessionContext, LogicalPlan)> { + let proto = generate_deep_plan(DEPTH); + + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; + let plan = from_substrait_plan(&ctx.state(), &proto).await?; + Ok((ctx, plan)) + } + + /// Generate a deeply nested substrait plan by extending the arguments of the scalar function + /// in deeply_nested_tpl.json to `depth` entries in total (for `depth >= 1`). This avoids committing a large json file to the repo. 
+ fn generate_deep_plan(depth: usize) -> Plan { + let template = include_str!("../testdata/test_plans/deeply_nested_tpl.json"); + let mut data: Value = + serde_json::from_str(template).expect("failed to parse json"); + + // Locate the `arguments` array we want to extend + let args = data + .pointer_mut("/relations/0/root/input/project/input/aggregate/input/filter/condition/scalarFunction/arguments/2/value/scalarFunction/arguments") + .and_then(Value::as_array_mut) + .expect("couldn't find the arguments array"); + + // Append `depth - 1` more arguments (VALUE_1 onwards); the template already contains the VALUE_0 argument + for i in 1..depth { + let new_arg = json!( { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "string": format!("VALUE_{}", i) + } + } + } + ] + } + } + }); + args.push(new_arg); + } + + serde_json::from_value(data).expect("failed to deserialize from value") + } +} diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index 777246e4139b..18c854bb00c6 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -16,6 +16,7 @@ // under the License. 
mod consumer_integration; +mod deeply_nested_plan; mod emit_kind_tests; mod function_test; mod logical_plans; diff --git a/datafusion/substrait/tests/testdata/test_plans/deeply_nested_tpl.json b/datafusion/substrait/tests/testdata/test_plans/deeply_nested_tpl.json new file mode 100644 index 000000000000..cb8fd32df9e8 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/deeply_nested_tpl.json @@ -0,0 +1,382 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 4, + "uri": "/functions_aggregate_generic.yaml" + }, + { + "extensionUriAnchor": 5, + "uri": "/functions_arithmetic.yaml" + }, + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, + { + "extensionUriAnchor": 2, + "uri": "/functions_string.yaml" + }, + { + "extensionUriAnchor": 3, + "uri": "/functions_comparison.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, + { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "like:str_str" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "not_equal:any_any" + } + }, + { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 3, + "name": "count:" + } + }, + { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 4, + "name": "sum:fp64" + } + }, + { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 5, + "name": "coalesce:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 2, + 3 + ] + } + }, + "input": { + "aggregate": { + "common": { + "direct": {} + }, + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "size", + "name", + "id" + ], + "struct": { + "types": [ + { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + 
"nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { + "fields": [ + { + "fp64": 0.1, + "nullable": true + }, + { + "string": "field_1", + "nullable": true + }, + { + "string": "field_2", + "nullable": true + } + ] + }, + { + "fields": [ + { + "fp64": 0.1, + "nullable": true + }, + { + "string": "field_1", + "nullable": true + }, + { + "string": "field_2", + "nullable": true + } + ] + } + ] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "string": "%field_1%" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "string": "%field_2%" + } + } + } + ] + } + } + }, + { + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "string": "VALUE_0" + } + } + } + ] + } + } + } + ] + } + } + } + ] + } + } + } + }, 
+ "groupings": [ + {} + ], + "measures": [ + { + "measure": { + "functionReference": 3, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL" + } + }, + { + "measure": { + "functionReference": 4, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "fp64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + }, + { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "fp64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "fp64": 0 + } + } + } + ] + } + } + ] + } + }, + "names": [ + "count", + "size" + ] + } + } + ], + "version": { + "minorNumber": 1, + "producer": "producer" + } +} \ No newline at end of file