diff --git a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs index 1f47412caf2a..4cf6609fd441 100644 --- a/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_query_fuzz.rs @@ -31,7 +31,6 @@ use datafusion_execution::memory_pool::{ }; use datafusion_expr::display_schema; use datafusion_physical_plan::spill::get_record_batch_memory_size; -use itertools::Itertools; use std::time::Duration; use datafusion_execution::{memory_pool::FairSpillPool, runtime_env::RuntimeEnvBuilder}; @@ -73,43 +72,6 @@ async fn sort_query_fuzzer_runner() { fuzzer.run().await.unwrap(); } -/// Reproduce the bug with specific seeds from the -/// [failing test case](https://github.com/apache/datafusion/issues/16452). -#[tokio::test(flavor = "multi_thread")] -async fn test_reproduce_sort_query_issue_16452() { - // Seeds from the failing test case - let init_seed = 10313160656544581998u64; - let query_seed = 15004039071976572201u64; - let config_seed_1 = 11807432710583113300u64; - let config_seed_2 = 759937414670321802u64; - - let random_seed = 1u64; // Use a fixed seed to ensure consistent behavior - - let mut test_generator = SortFuzzerTestGenerator::new( - 2000, - 3, - "sort_fuzz_table".to_string(), - get_supported_types_columns(random_seed), - false, - random_seed, - ); - - let mut results = vec![]; - - for config_seed in [config_seed_1, config_seed_2] { - let r = test_generator - .fuzzer_run(init_seed, query_seed, config_seed) - .await - .unwrap(); - - results.push(r); - } - - for (lhs, rhs) in results.iter().tuple_windows() { - check_equality_of_batches(lhs, rhs).unwrap(); - } -} - /// SortQueryFuzzer holds the runner configuration for executing sort query fuzz tests. The fuzzing details are managed inside `SortFuzzerTestGenerator`. /// /// It defines: @@ -466,7 +428,7 @@ impl SortFuzzerTestGenerator { .collect(); let mut order_by_clauses = Vec::new(); - for col in selected_columns { + for col in &selected_columns { let mut clause = col.name.clone(); if rng.random_bool(0.5) { let order = if rng.random_bool(0.5) { "ASC" } else { "DESC" }; @@ -501,7 +463,12 @@ impl SortFuzzerTestGenerator { let limit_clause = limit.map_or(String::new(), |l| format!(" LIMIT {l}")); let query = format!( - "SELECT * FROM {} ORDER BY {}{}", + "SELECT {} FROM {} ORDER BY {}{}", + selected_columns + .iter() + .map(|col| col.name.clone()) + .collect::>() + .join(", "), self.table_name, order_by_clauses.join(", "), limit_clause diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 71029662f5f5..8d06fa73ce8e 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -18,8 +18,8 @@ //! TopK: Combination of Sort / LIMIT use arrow::{ - array::Array, - compute::interleave_record_batch, + array::{Array, AsArray}, + compute::{interleave_record_batch, prep_null_mask_filter, FilterBuilder}, row::{RowConverter, Rows, SortField}, }; use datafusion_expr::{ColumnarValue, Operator}; @@ -203,7 +203,7 @@ impl TopK { let baseline = self.metrics.baseline.clone(); let _timer = baseline.elapsed_compute().timer(); - let sort_keys: Vec = self + let mut sort_keys: Vec = self .expr .iter() .map(|expr| { @@ -212,6 +212,43 @@ impl TopK { }) .collect::>>()?; + let mut selected_rows = None; + + if let Some(filter) = self.filter.as_ref() { + // If a filter is provided, update it with the new rows + let filter = filter.current()?; + let filtered = filter.evaluate(&batch)?; + let num_rows = batch.num_rows(); + let array = filtered.into_array(num_rows)?; + let mut filter = array.as_boolean().clone(); + let true_count = filter.true_count(); + if true_count == 0 { + // nothing to filter, so no need to update + return Ok(()); + } + // only update the keys / rows if the filter does not match all rows + if true_count < num_rows { + // Indices in `set_indices` should be correct if filter contains nulls + // So we prepare the filter here. Note this is also done in the `FilterBuilder` + // so there is no overhead to do this here. + if filter.nulls().is_some() { + filter = prep_null_mask_filter(&filter); + } + + let filter_predicate = FilterBuilder::new(&filter); + let filter_predicate = if sort_keys.len() > 1 { + // Optimize filter when it has multiple sort keys + filter_predicate.optimize().build() + } else { + filter_predicate.build() + }; + selected_rows = Some(filter); + sort_keys = sort_keys + .iter() + .map(|key| filter_predicate.filter(key).map_err(|x| x.into())) + .collect::>>()?; + } + }; // reuse existing `Rows` to avoid reallocations let rows = &mut self.scratch_rows; rows.clear(); @@ -219,8 +256,12 @@ impl TopK { let mut batch_entry = self.heap.register_batch(batch.clone()); - let replacements = - self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry); + let replacements = match selected_rows { + Some(filter) => { + self.find_new_topk_items(filter.values().set_indices(), &mut batch_entry) + } + None => self.find_new_topk_items(0..sort_keys[0].len(), &mut batch_entry), + }; if replacements > 0 { self.metrics.row_replacements.add(replacements);