From f23469be7a78607000ee68b862b87e48a674a180 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Sun, 13 Jul 2025 22:41:17 +0800 Subject: [PATCH 1/3] feat: improve LiteralGuarantee for the case like (a=1 AND b=1) OR (a=2 AND b=3) --- .../physical-expr/src/utils/guarantee.rs | 181 ++++++++++++++++-- 1 file changed, 163 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 8092dc3c1a61..61f78795cb70 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -184,16 +184,6 @@ impl LiteralGuarantee { .filter_map(|expr| ColOpLit::try_new(expr)) .collect::>(); - if terms.is_empty() { - return builder; - } - - // if not all terms are of the form (col literal), - // can't infer any guarantees - if terms.len() != disjunctions.len() { - return builder; - } - // if all terms are 'col literal' with the same column // and operation we can infer any guarantees // @@ -203,18 +193,68 @@ impl LiteralGuarantee { // foo is required for the expression to be true. // So we can only create a multi value guarantee for `=` // (or a single value). (e.g. 
ignore `a != foo OR a != bar`) - let first_term = &terms[0]; - if terms.iter().all(|term| { - term.col.name() == first_term.col.name() - && term.guarantee == Guarantee::In - }) { + let first_term = terms.first(); + if !terms.is_empty() + && terms.len() == disjunctions.len() + && terms.iter().all(|term| { + term.col.name() == first_term.unwrap().col.name() + && term.guarantee == Guarantee::In + }) + { builder.aggregate_multi_conjunct( - first_term.col, + first_term.unwrap().col, Guarantee::In, terms.iter().map(|term| term.lit.value()), ) } else { - // can't infer anything + // Handle disjunctions with conjunctions like (a = 1 AND b = 2) OR (a = 2 AND b = 3) + // Extract termsets from each disjunction + // if in each termset, they have same column, and the guarantee is In, + // we can infer a guarantee for the column + // e.g. (a = 1 AND b = 2) OR (a = 2 AND b = 3) is `a IN (1, 2) AND b IN (2, 3)` + // otherwise, we can't infer a guarantee + let termsets: Vec> = disjunctions + .iter() + .map(|expr| { + split_conjunction(expr) + .into_iter() + .filter_map(ColOpLit::try_new) + .filter(|term| term.guarantee == Guarantee::In) + .collect() + }) + .collect(); + + // Early return if any termset is empty (can't infer guarantees) + if termsets.iter().any(|terms| terms.is_empty()) { + return builder; + } + + // Find columns that appear in all termsets + let common_cols = find_common_columns(&termsets); + if common_cols.is_empty() { + return builder; + } + + // Build guarantees for common columns + let mut builder = builder; + for col in common_cols { + let literals: Vec<_> = termsets + .iter() + .filter_map(|terms| { + terms + .iter() + .find(|term| term.col == col) + .map(|term| term.lit.value()) + }) + .collect(); + + builder = builder.aggregate_multi_conjunct( + col, + Guarantee::In, + literals.into_iter(), + ); + } + builder } } @@ -410,6 +450,36 @@ impl<'a> ColOpLit<'a> { } } +/// Find columns that appear in all termsets +fn find_common_columns<'a>( + termsets: &[Vec>], 
+) -> Vec<&'a crate::expressions::Column> { + if termsets.is_empty() { + return Vec::new(); + } + + // Start with columns from the first termset + let mut common_cols: HashSet<_> = termsets[0].iter().map(|term| term.col).collect(); + + // check if any common_col in one termset occur many times + // e.g. (a = 1 AND a = 2) OR (a = 2 AND b = 3), should not infer a guarantee + // TODO: for above case, we can infer a IN (2) AND b IN (3) + if common_cols.len() != termsets[0].len() { + return Vec::new(); + } + + // Intersect with columns from remaining termsets + for termset in termsets.iter().skip(1) { + let termset_cols: HashSet<_> = termset.iter().map(|term| term.col).collect(); + if termset_cols.len() != termset.len() { + return Vec::new(); + } + common_cols = common_cols.intersection(&termset_cols).cloned().collect(); + } + + common_cols.into_iter().collect() +} + #[cfg(test)] mod test { use std::sync::LazyLock; @@ -824,13 +894,87 @@ mod test { ); } + #[test] + fn test_disjunction_and_conjunction_multi_column() { + // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1, 2])], + ); + // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) OR (b = 3) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))) + .or(col("b").eq(lit(3))), + vec![in_guarantee("b", [1, 2, 3])], + ); + // (a = "foo" AND b = 1) OR (a = "bar" AND b = 2) OR (c = 3) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))) + .or(col("c").eq(lit(3))), + vec![], + ); + // (a = "foo" AND b = 1) OR (a != "bar" AND b = 2) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").not_eq(lit("bar")).and(col("b").eq(lit(2)))), + vec![in_guarantee("b", [1, 2])], + ); + // 
(a = "foo" AND b > 1) OR (a = "bar" AND b = 2) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").gt(lit(1)))) + .or(col("a").eq(lit("bar")).and(col("b").eq(lit(2)))), + vec![in_guarantee("a", ["foo", "bar"])], + ); + // (a = "foo" AND b = 1) OR (b = 1 AND c = 2) OR (c = 3 AND a = "bar") + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("b").eq(lit(1)).and(col("c").eq(lit(2)))) + .or(col("c").eq(lit(3)).and(col("a").eq(lit("bar")))), + vec![], + ); + // (a = "foo" AND a = "bar") OR (a = "good" AND b = 1) + // TODO: this should be `a IN ("good") AND b IN (1)` + test_analyze( + (col("a").eq(lit("foo")).and(col("a").eq(lit("bar")))) + .or(col("a").eq(lit("good")).and(col("b").eq(lit(1)))), + vec![], + ); + // (a = "foo" AND a = "foo") OR (a = "good" AND b = 1) + // TODO: this should be `a IN ("foo", "good")` + test_analyze( + (col("a").eq(lit("foo")).and(col("a").eq(lit("foo")))) + .or(col("a").eq(lit("good")).and(col("b").eq(lit(1)))), + vec![], + ); + // (a = "foo" AND b = 3) OR (b = 4 AND b = 1) OR (b = 2 AND a = "bar") + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(3)))) + .or(col("b").eq(lit(4)).and(col("b").eq(lit(1)))) + .or(col("b").eq(lit(2)).and(col("a").eq(lit("bar")))), + vec![], + ); + // (b = 1 AND b > 3) OR (a = "foo" AND b = 4) + test_analyze( + (col("b").eq(lit(1)).and(col("b").gt(lit(3)))) + .or(col("a").eq(lit("foo")).and(col("b").eq(lit(4)))), + // if b isn't 1 or 4, it can not be true (though the expression actually can never be true) + vec![in_guarantee("b", [1, 4])], + ); + } + /// Tests that analyzing expr results in the expected guarantees fn test_analyze(expr: Expr, expected: Vec) { println!("Begin analyze of {expr}"); let schema = schema(); let physical_expr = logical2physical(&expr, &schema); - let actual = LiteralGuarantee::analyze(&physical_expr); + let actual = LiteralGuarantee::analyze(&physical_expr) + .into_iter() + .sorted_by_key(|g| g.column.name().to_string()) + .collect::>(); 
assert_eq!( expected, actual, "expr: {expr}\ @@ -867,6 +1011,7 @@ mod test { Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), ])) }); Arc::clone(&SCHEMA) From 89dc6be99882013b92b634d40557731a9dceffd9 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Tue, 15 Jul 2025 09:54:42 +0800 Subject: [PATCH 2/3] support inlist --- .../physical-expr/src/utils/guarantee.rs | 196 +++++++++++++----- 1 file changed, 149 insertions(+), 47 deletions(-) diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 61f78795cb70..43b0eab27be9 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -129,35 +129,15 @@ impl LiteralGuarantee { .as_any() .downcast_ref::() { - // Only support single-column inlist currently, multi-column inlist is not supported - let col = inlist - .expr() - .as_any() - .downcast_ref::(); - let Some(col) = col else { - return builder; - }; - - let literals = inlist - .list() - .iter() - .map(|e| e.as_any().downcast_ref::()) - .collect::>>(); - let Some(literals) = literals else { - return builder; - }; - - let guarantee = if inlist.negated() { - Guarantee::NotIn + if let Some(inlist) = ColInList::try_new(inlist) { + builder.aggregate_multi_conjunct( + inlist.col, + inlist.guarantee, + inlist.list.iter().map(|lit| lit.value()), + ) } else { - Guarantee::In - }; - - builder.aggregate_multi_conjunct( - col, - guarantee, - literals.iter().map(|e| e.value()), - ) + builder + } } else { // split disjunction: OR OR ... let disjunctions = split_disjunction(expr); @@ -213,13 +193,13 @@ impl LiteralGuarantee { // we can infer a guarantee for the column // e.g. 
(a = 1 AND b = 2) OR (a = 2 AND b = 3) is `a IN (1, 2) AND b IN (2, 3)` // otherwise, we can't infer a guarantee - let termsets: Vec> = disjunctions + let termsets: Vec> = disjunctions .iter() .map(|expr| { split_conjunction(expr) .into_iter() - .filter_map(ColOpLit::try_new) - .filter(|term| term.guarantee == Guarantee::In) + .filter_map(ColOpLitOrInList::try_new) + .filter(|term| term.guarantee() == Guarantee::In) .collect() }) .collect(); @@ -241,11 +221,13 @@ impl LiteralGuarantee { let literals: Vec<_> = termsets .iter() .filter_map(|terms| { - terms - .iter() - .find(|term| term.col == col) - .map(|term| term.lit.value()) + terms.iter().find(|term| term.col() == col).map( + |term| { + term.lits().into_iter().map(|lit| lit.value()) + }, + ) }) + .flatten() .collect(); builder = builder.aggregate_multi_conjunct( @@ -402,7 +384,7 @@ struct ColOpLit<'a> { } impl<'a> ColOpLit<'a> { - /// Returns Some(ColEqLit) if the expression is either: + /// Returns Some(ColOpLit) if the expression is either: /// 1. `col literal` /// 2. `literal col` /// 3. operator is `=` or `!=` @@ -450,16 +432,101 @@ impl<'a> ColOpLit<'a> { } } +/// Represents a single `col [not]in literal` expression +struct ColInList<'a> { + col: &'a crate::expressions::Column, + guarantee: Guarantee, + list: Vec<&'a crate::expressions::Literal>, +} + +impl<'a> ColInList<'a> { + /// Returns Some(ColInList) if the expression is either: + /// 1. `col (literal1, literal2, ...)` + /// 2. 
operator is `in` or `not in` + /// + /// Returns None otherwise + fn try_new(inlist: &'a crate::expressions::InListExpr) -> Option { + // Only support single-column inlist currently, multi-column inlist is not supported + let col = inlist + .expr() + .as_any() + .downcast_ref::(); + let Some(col) = col else { + return None; + }; + + let literals = inlist + .list() + .iter() + .map(|e| e.as_any().downcast_ref::()) + .collect::>>(); + let Some(literals) = literals else { + return None; + }; + + let guarantee = if inlist.negated() { + Guarantee::NotIn + } else { + Guarantee::In + }; + + Some(Self { + col, + guarantee, + list: literals, + }) + } +} + +/// Represents a single `col [not]in literal` expression or a single `col literal` expression +enum ColOpLitOrInList<'a> { + ColOpLit(ColOpLit<'a>), + ColInList(ColInList<'a>), +} + +impl<'a> ColOpLitOrInList<'a> { + fn try_new(expr: &'a Arc) -> Option { + match expr + .as_any() + .downcast_ref::() + { + Some(inlist) => Some(Self::ColInList(ColInList::try_new(inlist)?)), + None => ColOpLit::try_new(expr).map(Self::ColOpLit), + } + } + + fn guarantee(&self) -> Guarantee { + match self { + Self::ColOpLit(col_op_lit) => col_op_lit.guarantee, + Self::ColInList(col_in_list) => col_in_list.guarantee, + } + } + + fn col(&self) -> &'a crate::expressions::Column { + match self { + Self::ColOpLit(col_op_lit) => col_op_lit.col, + Self::ColInList(col_in_list) => col_in_list.col, + } + } + + fn lits(&self) -> Vec<&'a crate::expressions::Literal> { + match self { + Self::ColOpLit(col_op_lit) => vec![col_op_lit.lit], + Self::ColInList(col_in_list) => col_in_list.list.clone(), + } + } +} + /// Find columns that appear in all termsets fn find_common_columns<'a>( - termsets: &[Vec>], + termsets: &[Vec>], ) -> Vec<&'a crate::expressions::Column> { if termsets.is_empty() { return Vec::new(); } // Start with columns from the first termset - let mut common_cols: HashSet<_> = termsets[0].iter().map(|term| term.col).collect(); + let mut 
common_cols: HashSet<_> = termsets[0].iter().map(|term| term.col()).collect(); // check if any common_col in one termset occur many times // e.g. (a = 1 AND a = 2) OR (a = 2 AND b = 3), should not infer a guarantee @@ -470,7 +537,7 @@ fn find_common_columns<'a>( // Intersect with columns from remaining termsets for termset in termsets.iter().skip(1) { - let termset_cols: HashSet<_> = termset.iter().map(|term| term.col).collect(); + let termset_cols: HashSet<_> = termset.iter().map(|term| term.col()).collect(); if termset_cols.len() != termset.len() { return Vec::new(); } @@ -878,12 +945,11 @@ mod test { vec![not_in_guarantee("b", [1, 2, 3]), in_guarantee("b", [3, 4])], ); // b IN (1, 2, 3) OR b = 2 - // TODO this should be in_guarantee("b", [1, 2, 3]) but currently we don't support to analyze this kind of disjunction. Only `ColOpLit OR ColOpLit` is supported. test_analyze( col("b") .in_list(vec![lit(1), lit(2), lit(3)], false) .or(col("b").eq(lit(2))), - vec![], + vec![in_guarantee("b", [1, 2, 3])], ); // b IN (1, 2, 3) OR b != 3 test_analyze( @@ -916,12 +982,6 @@ mod test { .or(col("c").eq(lit(3))), vec![], ); - // (a = "foo" AND b = 1) OR (a != "bar" AND b = 2) - test_analyze( - (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) - .or(col("a").not_eq(lit("bar")).and(col("b").eq(lit(2)))), - vec![in_guarantee("b", [1, 2])], - ); // (a = "foo" AND b > 1) OR (a = "bar" AND b = 2) test_analyze( (col("a").eq(lit("foo")).and(col("b").gt(lit(1)))) @@ -963,6 +1023,48 @@ mod test { // if b isn't 1 or 4, it can not be true (though the expression actually can never be true) vec![in_guarantee("b", [1, 4])], ); + // (a = "foo" AND b = 1) OR (a != "bar" AND b = 2) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + .or(col("a").not_eq(lit("bar")).and(col("b").eq(lit(2)))), + vec![in_guarantee("b", [1, 2])], + ); + // (a = "foo" AND b = 1) OR (a LIKE "%bar" AND b = 2) + test_analyze( + (col("a").eq(lit("foo")).and(col("b").eq(lit(1)))) + 
.or(col("a").like(lit("%bar")).and(col("b").eq(lit(2)))), + vec![in_guarantee("b", [1, 2])], + ); + // (a IN ("foo", "bar") AND b = 5) OR (a IN ("foo", "bar") AND b = 6) + test_analyze( + (col("a") + .in_list(vec![lit("foo"), lit("bar")], false) + .and(col("b").eq(lit(5)))) + .or(col("a") + .in_list(vec![lit("foo"), lit("bar")], false) + .and(col("b").eq(lit(6)))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [5, 6])], + ); + // (a IN ("foo", "bar") AND b = 5) OR (a IN ("foo") AND b = 6) + test_analyze( + (col("a") + .in_list(vec![lit("foo"), lit("bar")], false) + .and(col("b").eq(lit(5)))) + .or(col("a") + .in_list(vec![lit("foo")], false) + .and(col("b").eq(lit(6)))), + vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [5, 6])], + ); + // (a NOT IN ("foo", "bar") AND b = 5) OR (a NOT IN ("foo") AND b = 6) + test_analyze( + (col("a") + .in_list(vec![lit("foo"), lit("bar")], true) + .and(col("b").eq(lit(5)))) + .or(col("a") + .in_list(vec![lit("foo")], true) + .and(col("b").eq(lit(6)))), + vec![in_guarantee("b", [5, 6])], + ); } /// Tests that analyzing expr results in the expected guarantees From add2ae0f31babd5fa09426eb4f8d2b8dfdf7eaf4 Mon Sep 17 00:00:00 2001 From: Huaijin Date: Tue, 15 Jul 2025 10:05:44 +0800 Subject: [PATCH 3/3] fmt and clippy --- datafusion/physical-expr/src/utils/guarantee.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index 43b0eab27be9..8a57cc7b7c15 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -450,19 +450,13 @@ impl<'a> ColInList<'a> { let col = inlist .expr() .as_any() - .downcast_ref::(); - let Some(col) = col else { - return None; - }; + .downcast_ref::()?; let literals = inlist .list() .iter() .map(|e| e.as_any().downcast_ref::()) - .collect::>>(); - let Some(literals) = literals else { - return None; - }; + 
.collect::>>()?; let guarantee = if inlist.negated() { Guarantee::NotIn