@@ -29,7 +29,10 @@ use arrow::{
29
29
} ,
30
30
record_batch:: RecordBatch ,
31
31
} ;
32
- use arrow_array:: { Array , Float32Array , Float64Array , UnionArray } ;
32
+ use arrow_array:: {
33
+ Array , BooleanArray , DictionaryArray , Float32Array , Float64Array , Int8Array ,
34
+ UnionArray ,
35
+ } ;
33
36
use arrow_buffer:: ScalarBuffer ;
34
37
use arrow_schema:: { ArrowError , UnionFields , UnionMode } ;
35
38
use datafusion_functions_aggregate:: count:: count_udaf;
@@ -2363,3 +2366,105 @@ async fn dense_union_is_null() {
2363
2366
] ;
2364
2367
assert_batches_sorted_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
2365
2368
}
2369
+
2370
+ #[ tokio:: test]
2371
+ async fn boolean_dictionary_as_filter ( ) {
2372
+ let values = vec ! [ Some ( true ) , Some ( false ) , None , Some ( true ) ] ;
2373
+ let keys = vec ! [ 0 , 0 , 1 , 2 , 1 , 3 , 1 ] ;
2374
+ let values_array = BooleanArray :: from ( values) ;
2375
+ let keys_array = Int8Array :: from ( keys) ;
2376
+ let array =
2377
+ DictionaryArray :: new ( keys_array, Arc :: new ( values_array) as Arc < dyn Array > ) ;
2378
+ let array = Arc :: new ( array) ;
2379
+
2380
+ let field = Field :: new (
2381
+ "my_dict" ,
2382
+ DataType :: Dictionary ( Box :: new ( DataType :: Int8 ) , Box :: new ( DataType :: Boolean ) ) ,
2383
+ true ,
2384
+ ) ;
2385
+ let schema = Arc :: new ( Schema :: new ( vec ! [ field] ) ) ;
2386
+
2387
+ let batch = RecordBatch :: try_new ( schema, vec ! [ array. clone( ) ] ) . unwrap ( ) ;
2388
+
2389
+ let ctx = SessionContext :: new ( ) ;
2390
+
2391
+ ctx. register_batch ( "dict_batch" , batch) . unwrap ( ) ;
2392
+
2393
+ let df = ctx. table ( "dict_batch" ) . await . unwrap ( ) ;
2394
+
2395
+ // view_all
2396
+ let expected = [
2397
+ "+---------+" ,
2398
+ "| my_dict |" ,
2399
+ "+---------+" ,
2400
+ "| true |" ,
2401
+ "| true |" ,
2402
+ "| false |" ,
2403
+ "| |" ,
2404
+ "| false |" ,
2405
+ "| true |" ,
2406
+ "| false |" ,
2407
+ "+---------+" ,
2408
+ ] ;
2409
+ assert_batches_eq ! ( expected, & df. clone( ) . collect( ) . await . unwrap( ) ) ;
2410
+
2411
+ let result_df = df. clone ( ) . filter ( col ( "my_dict" ) ) . unwrap ( ) ;
2412
+ let expected = [
2413
+ "+---------+" ,
2414
+ "| my_dict |" ,
2415
+ "+---------+" ,
2416
+ "| true |" ,
2417
+ "| true |" ,
2418
+ "| true |" ,
2419
+ "+---------+" ,
2420
+ ] ;
2421
+ assert_batches_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
2422
+
2423
+ // test nested dictionary
2424
+ let keys = vec ! [ 0 , 2 ] ; // 0 -> true, 2 -> false
2425
+ let keys_array = Int8Array :: from ( keys) ;
2426
+ let nested_array = DictionaryArray :: new ( keys_array, array) ;
2427
+
2428
+ let field = Field :: new (
2429
+ "my_nested_dict" ,
2430
+ DataType :: Dictionary (
2431
+ Box :: new ( DataType :: Int8 ) ,
2432
+ Box :: new ( DataType :: Dictionary (
2433
+ Box :: new ( DataType :: Int8 ) ,
2434
+ Box :: new ( DataType :: Boolean ) ,
2435
+ ) ) ,
2436
+ ) ,
2437
+ true ,
2438
+ ) ;
2439
+
2440
+ let schema = Arc :: new ( Schema :: new ( vec ! [ field] ) ) ;
2441
+
2442
+ let batch = RecordBatch :: try_new ( schema, vec ! [ Arc :: new( nested_array) ] ) . unwrap ( ) ;
2443
+
2444
+ ctx. register_batch ( "nested_dict_batch" , batch) . unwrap ( ) ;
2445
+
2446
+ let df = ctx. table ( "nested_dict_batch" ) . await . unwrap ( ) ;
2447
+
2448
+ // view_all
2449
+ let expected = [
2450
+ "+----------------+" ,
2451
+ "| my_nested_dict |" ,
2452
+ "+----------------+" ,
2453
+ "| true |" ,
2454
+ "| false |" ,
2455
+ "+----------------+" ,
2456
+ ] ;
2457
+
2458
+ assert_batches_eq ! ( expected, & df. clone( ) . collect( ) . await . unwrap( ) ) ;
2459
+
2460
+ let result_df = df. clone ( ) . filter ( col ( "my_nested_dict" ) ) . unwrap ( ) ;
2461
+ let expected = [
2462
+ "+----------------+" ,
2463
+ "| my_nested_dict |" ,
2464
+ "+----------------+" ,
2465
+ "| true |" ,
2466
+ "+----------------+" ,
2467
+ ] ;
2468
+
2469
+ assert_batches_eq ! ( expected, & result_df. collect( ) . await . unwrap( ) ) ;
2470
+ }
0 commit comments