Skip to content

Commit 8d2b240

Browse files
authored
Allow using dictionary arrays as filters (#12382)
* Allow using dictionaries as filters
* revert, nested
* fmt
1 parent c575bbf commit 8d2b240

File tree

2 files changed

+118
-3
lines changed

2 files changed

+118
-3
lines changed

datafusion/core/tests/dataframe/mod.rs

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ use arrow::{
2929
},
3030
record_batch::RecordBatch,
3131
};
32-
use arrow_array::{Array, Float32Array, Float64Array, UnionArray};
32+
use arrow_array::{
33+
Array, BooleanArray, DictionaryArray, Float32Array, Float64Array, Int8Array,
34+
UnionArray,
35+
};
3336
use arrow_buffer::ScalarBuffer;
3437
use arrow_schema::{ArrowError, UnionFields, UnionMode};
3538
use datafusion_functions_aggregate::count::count_udaf;
@@ -2363,3 +2366,105 @@ async fn dense_union_is_null() {
23632366
];
23642367
assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap());
23652368
}
2369+
2370+
#[tokio::test]
2371+
async fn boolean_dictionary_as_filter() {
2372+
let values = vec![Some(true), Some(false), None, Some(true)];
2373+
let keys = vec![0, 0, 1, 2, 1, 3, 1];
2374+
let values_array = BooleanArray::from(values);
2375+
let keys_array = Int8Array::from(keys);
2376+
let array =
2377+
DictionaryArray::new(keys_array, Arc::new(values_array) as Arc<dyn Array>);
2378+
let array = Arc::new(array);
2379+
2380+
let field = Field::new(
2381+
"my_dict",
2382+
DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Boolean)),
2383+
true,
2384+
);
2385+
let schema = Arc::new(Schema::new(vec![field]));
2386+
2387+
let batch = RecordBatch::try_new(schema, vec![array.clone()]).unwrap();
2388+
2389+
let ctx = SessionContext::new();
2390+
2391+
ctx.register_batch("dict_batch", batch).unwrap();
2392+
2393+
let df = ctx.table("dict_batch").await.unwrap();
2394+
2395+
// view_all
2396+
let expected = [
2397+
"+---------+",
2398+
"| my_dict |",
2399+
"+---------+",
2400+
"| true |",
2401+
"| true |",
2402+
"| false |",
2403+
"| |",
2404+
"| false |",
2405+
"| true |",
2406+
"| false |",
2407+
"+---------+",
2408+
];
2409+
assert_batches_eq!(expected, &df.clone().collect().await.unwrap());
2410+
2411+
let result_df = df.clone().filter(col("my_dict")).unwrap();
2412+
let expected = [
2413+
"+---------+",
2414+
"| my_dict |",
2415+
"+---------+",
2416+
"| true |",
2417+
"| true |",
2418+
"| true |",
2419+
"+---------+",
2420+
];
2421+
assert_batches_eq!(expected, &result_df.collect().await.unwrap());
2422+
2423+
// test nested dictionary
2424+
let keys = vec![0, 2]; // 0 -> true, 2 -> false
2425+
let keys_array = Int8Array::from(keys);
2426+
let nested_array = DictionaryArray::new(keys_array, array);
2427+
2428+
let field = Field::new(
2429+
"my_nested_dict",
2430+
DataType::Dictionary(
2431+
Box::new(DataType::Int8),
2432+
Box::new(DataType::Dictionary(
2433+
Box::new(DataType::Int8),
2434+
Box::new(DataType::Boolean),
2435+
)),
2436+
),
2437+
true,
2438+
);
2439+
2440+
let schema = Arc::new(Schema::new(vec![field]));
2441+
2442+
let batch = RecordBatch::try_new(schema, vec![Arc::new(nested_array)]).unwrap();
2443+
2444+
ctx.register_batch("nested_dict_batch", batch).unwrap();
2445+
2446+
let df = ctx.table("nested_dict_batch").await.unwrap();
2447+
2448+
// view_all
2449+
let expected = [
2450+
"+----------------+",
2451+
"| my_nested_dict |",
2452+
"+----------------+",
2453+
"| true |",
2454+
"| false |",
2455+
"+----------------+",
2456+
];
2457+
2458+
assert_batches_eq!(expected, &df.clone().collect().await.unwrap());
2459+
2460+
let result_df = df.clone().filter(col("my_nested_dict")).unwrap();
2461+
let expected = [
2462+
"+----------------+",
2463+
"| my_nested_dict |",
2464+
"+----------------+",
2465+
"| true |",
2466+
"+----------------+",
2467+
];
2468+
2469+
assert_batches_eq!(expected, &result_df.collect().await.unwrap());
2470+
}

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2207,6 +2207,17 @@ impl Filter {
22072207
Self::try_new_internal(predicate, input, true)
22082208
}
22092209

2210+
fn is_allowed_filter_type(data_type: &DataType) -> bool {
2211+
match data_type {
2212+
// Interpret NULL as a missing boolean value.
2213+
DataType::Boolean | DataType::Null => true,
2214+
DataType::Dictionary(_, value_type) => {
2215+
Filter::is_allowed_filter_type(value_type.as_ref())
2216+
}
2217+
_ => false,
2218+
}
2219+
}
2220+
22102221
fn try_new_internal(
22112222
predicate: Expr,
22122223
input: Arc<LogicalPlan>,
@@ -2217,8 +2228,7 @@ impl Filter {
22172228
// construction (such as with correlated subqueries) so we make a best effort here and
22182229
// ignore errors resolving the expression against the schema.
22192230
if let Ok(predicate_type) = predicate.get_type(input.schema()) {
2220-
// Interpret NULL as a missing boolean value.
2221-
if predicate_type != DataType::Boolean && predicate_type != DataType::Null {
2231+
if !Filter::is_allowed_filter_type(&predicate_type) {
22222232
return plan_err!(
22232233
"Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}"
22242234
);

0 commit comments

Comments
 (0)