Skip to content

Commit f1f3b66

Browse files
committed
Refactor InListExpr to store arrays and support structs.
Changes:
- Enhance InListExpr to efficiently store homogeneous lists as arrays, avoiding a conversion to Vec<PhysicalExpr>, by adding an internal InListStorage enum with Array and Exprs variants
- Re-use existing hashing and comparison utilities to support Struct arrays and other complex types
- Add public function `in_list_from_array(expr, list_array, negated)` for creating InList from arrays
1 parent 4eb87cd commit f1f3b66

File tree

8 files changed

+1209
-238
lines changed

8 files changed

+1209
-238
lines changed

datafusion/physical-expr/src/expressions/in_list.rs

Lines changed: 1140 additions & 189 deletions
Large diffs are not rendered by default.

datafusion/physical-expr/src/expressions/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ pub use cast_column::CastColumnExpr;
4646
pub use column::{col, with_new_schema, Column};
4747
pub use datafusion_expr::utils::format_state_name;
4848
pub use dynamic_filters::DynamicFilterPhysicalExpr;
49-
pub use in_list::{in_list, InListExpr};
49+
pub use in_list::{in_list, in_list_from_array, InListExpr};
5050
pub use is_not_null::{is_not_null, IsNotNullExpr};
5151
pub use is_null::{is_null, IsNullExpr};
5252
pub use like::{like, LikeExpr};

datafusion/physical-expr/src/utils/guarantee.rs

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,12 @@ impl LiteralGuarantee {
9393
/// Create a new instance of the guarantee if the provided operator is
9494
/// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to
9595
/// create these structures from a predicate (boolean expression).
96-
fn new<'a>(
96+
fn new(
9797
column_name: impl Into<String>,
9898
guarantee: Guarantee,
99-
literals: impl IntoIterator<Item = &'a ScalarValue>,
99+
literals: impl IntoIterator<Item = ScalarValue>,
100100
) -> Self {
101-
let literals: HashSet<_> = literals.into_iter().cloned().collect();
101+
let literals: HashSet<_> = literals.into_iter().collect();
102102

103103
Self {
104104
column: Column::from_name(column_name),
@@ -130,11 +130,18 @@ impl LiteralGuarantee {
130130
.downcast_ref::<crate::expressions::InListExpr>()
131131
{
132132
if let Some(inlist) = ColInList::try_new(inlist) {
133-
builder.aggregate_multi_conjunct(
134-
inlist.col,
135-
inlist.guarantee,
136-
inlist.list.iter().map(|lit| lit.value()),
137-
)
133+
match inlist.inlist.list() {
134+
Ok(list) => builder.aggregate_multi_conjunct(
135+
inlist.col,
136+
inlist.guarantee,
137+
list.iter().filter_map(|expr| {
138+
expr.as_any()
139+
.downcast_ref::<crate::expressions::Literal>()
140+
.map(|lit| lit.value().clone())
141+
}),
142+
),
143+
Err(_) => builder,
144+
}
138145
} else {
139146
builder
140147
}
@@ -184,7 +191,7 @@ impl LiteralGuarantee {
184191
builder.aggregate_multi_conjunct(
185192
first_term.unwrap().col,
186193
Guarantee::In,
187-
terms.iter().map(|term| term.lit.value()),
194+
terms.iter().map(|term| term.lit.value().clone()),
188195
)
189196
} else {
190197
// Handle disjunctions with conjunctions like (a = 1 AND b = 2) OR (a = 2 AND b = 3)
@@ -221,19 +228,18 @@ impl LiteralGuarantee {
221228
let literals: Vec<_> = termsets
222229
.iter()
223230
.filter_map(|terms| {
224-
terms.iter().find(|term| term.col() == col).map(
225-
|term| {
226-
term.lits().into_iter().map(|lit| lit.value())
227-
},
228-
)
231+
terms
232+
.iter()
233+
.find(|term| term.col() == col)
234+
.and_then(|term| term.lit_values().ok())
229235
})
230236
.flatten()
231237
.collect();
232238

233239
builder = builder.aggregate_multi_conjunct(
234240
col,
235241
Guarantee::In,
236-
literals.into_iter(),
242+
literals,
237243
);
238244
}
239245

@@ -296,7 +302,7 @@ impl<'a> GuaranteeBuilder<'a> {
296302
self.aggregate_multi_conjunct(
297303
col_op_lit.col,
298304
col_op_lit.guarantee,
299-
[col_op_lit.lit.value()],
305+
[col_op_lit.lit.value().clone()],
300306
)
301307
}
302308

@@ -313,7 +319,7 @@ impl<'a> GuaranteeBuilder<'a> {
313319
mut self,
314320
col: &'a crate::expressions::Column,
315321
guarantee: Guarantee,
316-
new_values: impl IntoIterator<Item = &'a ScalarValue>,
322+
new_values: impl IntoIterator<Item = ScalarValue>,
317323
) -> Self {
318324
let key = (col, guarantee);
319325
if let Some(index) = self.map.get(&key) {
@@ -336,20 +342,19 @@ impl<'a> GuaranteeBuilder<'a> {
336342
// another `AND a != 6` we know that a must not be either 5 or 6
337343
// for the expression to be true
338344
Guarantee::NotIn => {
339-
let new_values: HashSet<_> = new_values.into_iter().collect();
340-
existing.literals.extend(new_values.into_iter().cloned());
345+
existing.literals.extend(new_values);
341346
}
342347
Guarantee::In => {
343348
let intersection = new_values
344349
.into_iter()
345-
.filter(|new_value| existing.literals.contains(*new_value))
350+
.filter(|new_value| existing.literals.contains(new_value))
346351
.collect::<Vec<_>>();
347352
// for an In guarantee, if the intersection is not empty, we can narrow the guarantee to the intersection
348353
// e.g. `a IN (1,2,3) AND a IN (2,3,4)` is `a IN (2,3)`
349354
// otherwise, we invalidate the guarantee
350355
// e.g. `a IN (1,2,3) AND a IN (4,5,6)` is `a IN ()`, which is invalid
351356
if !intersection.is_empty() {
352-
existing.literals = intersection.into_iter().cloned().collect();
357+
existing.literals = intersection.into_iter().collect();
353358
} else {
354359
// at least one was not, so invalidate the guarantee
355360
*entry = None;
@@ -436,7 +441,7 @@ impl<'a> ColOpLit<'a> {
436441
struct ColInList<'a> {
437442
col: &'a crate::expressions::Column,
438443
guarantee: Guarantee,
439-
list: Vec<&'a crate::expressions::Literal>,
444+
inlist: &'a crate::expressions::InListExpr,
440445
}
441446

442447
impl<'a> ColInList<'a> {
@@ -452,11 +457,12 @@ impl<'a> ColInList<'a> {
452457
.as_any()
453458
.downcast_ref::<crate::expressions::Column>()?;
454459

455-
let literals = inlist
456-
.list()
457-
.iter()
458-
.map(|e| e.as_any().downcast_ref::<crate::expressions::Literal>())
459-
.collect::<Option<Vec<_>>>()?;
460+
let list = inlist.list().ok()?;
461+
// Verify all items are literals
462+
for expr in &list {
463+
expr.as_any()
464+
.downcast_ref::<crate::expressions::Literal>()?;
465+
}
460466

461467
let guarantee = if inlist.negated() {
462468
Guarantee::NotIn
@@ -467,7 +473,7 @@ impl<'a> ColInList<'a> {
467473
Some(Self {
468474
col,
469475
guarantee,
470-
list: literals,
476+
inlist,
471477
})
472478
}
473479
}
@@ -503,10 +509,20 @@ impl<'a> ColOpLitOrInList<'a> {
503509
}
504510
}
505511

506-
fn lits(&self) -> Vec<&'a crate::expressions::Literal> {
512+
fn lit_values(&self) -> datafusion_common::Result<Vec<ScalarValue>> {
507513
match self {
508-
Self::ColOpLit(col_op_lit) => vec![col_op_lit.lit],
509-
Self::ColInList(col_in_list) => col_in_list.list.clone(),
514+
Self::ColOpLit(col_op_lit) => Ok(vec![col_op_lit.lit.value().clone()]),
515+
Self::ColInList(col_in_list) => {
516+
let list = col_in_list.inlist.list()?;
517+
Ok(list
518+
.iter()
519+
.filter_map(|expr| {
520+
expr.as_any()
521+
.downcast_ref::<crate::expressions::Literal>()
522+
.map(|lit| lit.value().clone())
523+
})
524+
.collect())
525+
}
510526
}
511527
}
512528
}
@@ -1088,7 +1104,7 @@ mod test {
10881104
S: Into<ScalarValue> + 'a,
10891105
{
10901106
let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect();
1091-
LiteralGuarantee::new(column, Guarantee::In, literals.iter())
1107+
LiteralGuarantee::new(column, Guarantee::In, literals)
10921108
}
10931109

10941110
/// Guarantee that the expression is true if the column is NOT any of the specified values
@@ -1098,7 +1114,7 @@ mod test {
10981114
S: Into<ScalarValue> + 'a,
10991115
{
11001116
let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect();
1101-
LiteralGuarantee::new(column, Guarantee::NotIn, literals.iter())
1117+
LiteralGuarantee::new(column, Guarantee::NotIn, literals)
11021118
}
11031119

11041120
// Schema for testing

datafusion/proto/src/physical_plan/to_proto.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,10 @@ pub fn serialize_physical_expr(
315315
expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new(
316316
protobuf::PhysicalInListNode {
317317
expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)),
318-
list: serialize_physical_exprs(expr.list(), codec)?,
318+
// TODO: serialize the inner ArrayRef directly to avoid materialization into literals
319+
// by extending the protobuf definition to support both representations and adding a public
320+
// accessor method to InListExpr to get the inner ArrayRef
321+
list: serialize_physical_exprs(&expr.list()?, codec)?,
319322
negated: expr.negated(),
320323
},
321324
))),

datafusion/pruning/src/pruning_predicate.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1424,9 +1424,11 @@ fn build_predicate_expression(
14241424
}
14251425
}
14261426
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
1427-
if !in_list.list().is_empty()
1428-
&& in_list.list().len() <= MAX_LIST_VALUE_SIZE_REWRITE
1429-
{
1427+
if !in_list.is_empty() && in_list.len() <= MAX_LIST_VALUE_SIZE_REWRITE {
1428+
let list = match in_list.list() {
1429+
Ok(list) => list,
1430+
Err(_) => return unhandled_hook.handle(expr),
1431+
};
14301432
let eq_op = if in_list.negated() {
14311433
Operator::NotEq
14321434
} else {
@@ -1437,8 +1439,7 @@ fn build_predicate_expression(
14371439
} else {
14381440
Operator::Or
14391441
};
1440-
let change_expr = in_list
1441-
.list()
1442+
let change_expr = list
14421443
.iter()
14431444
.map(|e| {
14441445
Arc::new(phys_expr::BinaryExpr::new(

datafusion/sqllogictest/test_files/array.slt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6408,7 +6408,7 @@ physical_plan
64086408
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
64096409
05)--------ProjectionExec: expr=[]
64106410
06)----------CoalesceBatchesExec: target_batch_size=8192
6411-
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
6411+
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
64126412
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
64136413
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
64146414

@@ -6437,7 +6437,7 @@ physical_plan
64376437
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
64386438
05)--------ProjectionExec: expr=[]
64396439
06)----------CoalesceBatchesExec: target_batch_size=8192
6440-
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
6440+
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
64416441
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
64426442
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
64436443

@@ -6466,7 +6466,7 @@ physical_plan
64666466
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
64676467
05)--------ProjectionExec: expr=[]
64686468
06)----------CoalesceBatchesExec: target_batch_size=8192
6469-
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
6469+
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
64706470
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
64716471
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
64726472

@@ -6495,7 +6495,7 @@ physical_plan
64956495
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
64966496
05)--------ProjectionExec: expr=[]
64976497
06)----------CoalesceBatchesExec: target_batch_size=8192
6498-
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
6498+
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
64996499
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
65006500
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
65016501

@@ -6524,7 +6524,7 @@ physical_plan
65246524
04)------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
65256525
05)--------ProjectionExec: expr=[]
65266526
06)----------CoalesceBatchesExec: target_batch_size=8192
6527-
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
6527+
07)------------FilterExec: substr(md5(CAST(value@0 AS Utf8View)), 1, 32) IN (SET) ([7f4b18de3cfeb9b4ac78c381ee2ad278, a, b, c])
65286528
08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
65296529
09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192]
65306530

datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ physical_plan
6969
03)----CoalescePartitionsExec
7070
04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
7171
05)--------CoalesceBatchesExec: target_batch_size=8192
72-
06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN ([SM CASE, SM BOX, SM PACK, SM PKG]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([MED BAG, MED BOX, MED PKG, MED PACK]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([LG CASE, LG BOX, LG PACK, LG PKG]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3]
72+
06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3]
7373
07)------------CoalesceBatchesExec: target_batch_size=8192
7474
08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4
7575
09)----------------CoalesceBatchesExec: target_batch_size=8192
@@ -78,6 +78,6 @@ physical_plan
7878
12)------------CoalesceBatchesExec: target_batch_size=8192
7979
13)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
8080
14)----------------CoalesceBatchesExec: target_batch_size=8192
81-
15)------------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15) AND p_size@2 >= 1
81+
15)------------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15) AND p_size@2 >= 1
8282
16)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
8383
17)----------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], file_type=csv, has_header=false

datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ physical_plan
9191
15)----------------------------CoalesceBatchesExec: target_batch_size=8192
9292
16)------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4
9393
17)--------------------------------CoalesceBatchesExec: target_batch_size=8192
94-
18)----------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([13, 31, 23, 29, 30, 18, 17])
94+
18)----------------------------------FilterExec: substr(c_phone@1, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17])
9595
19)------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
9696
20)--------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false
9797
21)----------------------------CoalesceBatchesExec: target_batch_size=8192
@@ -101,6 +101,6 @@ physical_plan
101101
25)----------------------CoalescePartitionsExec
102102
26)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)]
103103
27)--------------------------CoalesceBatchesExec: target_batch_size=8192
104-
28)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1]
104+
28)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN (SET) ([13, 31, 23, 29, 30, 18, 17]), projection=[c_acctbal@1]
105105
29)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
106106
30)--------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false

0 commit comments

Comments
 (0)