Skip to content

Commit cfc7e3b

Browse files
committed
Store threshold as Row bytes instead of Vec<ScalarValue>. Move filter expression updates inside the threshold lock to ensure atomic updates of both threshold and filter. This prevents race conditions where multiple threads could update the filter expression out of order, leading to less selective filters overwriting more selective ones.
1 parent 3974a4d commit cfc7e3b

File tree

1 file changed

+63
-75
lines changed
  • datafusion/physical-plan/src/topk

1 file changed

+63
-75
lines changed

datafusion/physical-plan/src/topk/mod.rs

Lines changed: 63 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,8 @@ pub struct TopK {
133133
pub struct TopKDynamicFilters {
134134
/// The current *global* threshold for the dynamic filter.
135135
/// This is shared across all partitions and is updated by any of them.
136-
thresholds: Arc<RwLock<Option<Vec<ScalarValue>>>>,
136+
/// Stored as row bytes for efficient comparison.
137+
threshold_row: Arc<RwLock<Option<Vec<u8>>>>,
137138
/// The expression used to evaluate the dynamic filter
138139
expr: Arc<DynamicFilterPhysicalExpr>,
139140
}
@@ -142,7 +143,7 @@ impl TopKDynamicFilters {
142143
/// Create a new `TopKDynamicFilters` with the given expression
143144
pub fn new(expr: Arc<DynamicFilterPhysicalExpr>) -> Self {
144145
Self {
145-
thresholds: Arc::new(RwLock::new(None)),
146+
threshold_row: Arc::new(RwLock::new(None)),
146147
expr,
147148
}
148149
}
@@ -341,94 +342,72 @@ impl TopK {
341342
/// (a > 2 OR (a = 2 AND b < 3))
342343
/// ```
343344
fn update_filter(&mut self) -> Result<()> {
344-
let Some(thresholds) = self.heap.get_threshold_values(&self.expr)? else {
345+
let Some(new_threshold_row) = self.heap.get_threshold_row() else {
345346
return Ok(());
346347
};
347348

348-
// Are the new thresholds more selective than our existing ones?
349-
let should_update = {
350-
if let Some(current) = self.filter.thresholds.write().as_mut() {
351-
assert!(current.len() == thresholds.len());
352-
// Check if new thresholds are more selective than current ones
353-
let mut more_selective = false;
354-
for ((current_value, new_value), sort_expr) in
355-
current.iter().zip(thresholds.iter()).zip(self.expr.iter())
356-
{
357-
// Handle null cases
358-
let (current_is_null, new_is_null) =
359-
(current_value.is_null(), new_value.is_null());
360-
361-
match (current_is_null, new_is_null) {
362-
(true, true) => {
363-
// Both null, continue checking next values
364-
}
365-
(true, false) => {
366-
// Current is null, new is not null
367-
// For nulls_first: null < non-null, so new value is less selective
368-
// For nulls_last: null > non-null, so new value is more selective
369-
more_selective = !sort_expr.options.nulls_first;
370-
break;
371-
}
372-
(false, true) => {
373-
// Current is not null, new is null
374-
// For nulls_first: non-null > null, so new value is more selective
375-
// For nulls_last: non-null < null, so new value is less selective
376-
more_selective = sort_expr.options.nulls_first;
377-
break;
378-
}
379-
(false, false) => {
380-
// Neither is null, compare values
381-
match current_value.partial_cmp(new_value) {
382-
Some(ordering) => {
383-
match ordering {
384-
Ordering::Equal => {
385-
// Continue checking next values
386-
}
387-
Ordering::Less => {
388-
// For descending sort: new > current means more selective
389-
// For ascending sort: new > current means less selective
390-
more_selective = sort_expr.options.descending;
391-
break;
392-
}
393-
Ordering::Greater => {
394-
// For descending sort: new < current means less selective
395-
// For ascending sort: new < current means more selective
396-
more_selective =
397-
!sort_expr.options.descending;
398-
break;
399-
}
400-
}
401-
}
402-
None => {
403-
// If values can't be compared, don't update
404-
more_selective = false;
405-
break;
406-
}
407-
}
408-
}
349+
// Extract filter expression reference before entering critical section
350+
let filter_expr = Arc::clone(&self.filter.expr);
351+
352+
// Check if we need to update and do both threshold and filter update atomically
353+
{
354+
let mut threshold_guard = self.filter.threshold_row.write();
355+
if let Some(current_row) = threshold_guard.as_ref() {
356+
match current_row.as_slice().cmp(new_threshold_row) {
357+
Ordering::Less => {
358+
// new > current, so new threshold is more selective
359+
// Update threshold and filter atomically to prevent race conditions
360+
*threshold_guard = Some(new_threshold_row.to_vec());
361+
362+
// Extract scalar values for filter expression creation
363+
let thresholds =
364+
match self.heap.get_threshold_values(&self.expr)? {
365+
Some(t) => t,
366+
None => return Ok(()),
367+
};
368+
369+
// Update the filter expression while still holding the lock
370+
Self::update_filter_expression(
371+
&filter_expr,
372+
&self.expr,
373+
thresholds,
374+
)?;
375+
}
376+
Ordering::Equal | Ordering::Greater => {
377+
// Same threshold or current is more selective, no need to update
409378
}
410379
}
411-
// If the new thresholds are more selective, update the current ones
412-
if more_selective {
413-
*current = thresholds.clone();
414-
}
415-
more_selective
416380
} else {
417381
// No current thresholds, so update with the new ones
418-
true
382+
*threshold_guard = Some(new_threshold_row.to_vec());
383+
384+
// Extract scalar values for filter expression creation
385+
let thresholds = match self.heap.get_threshold_values(&self.expr)? {
386+
Some(t) => t,
387+
None => return Ok(()),
388+
};
389+
390+
// Update the filter expression while still holding the lock
391+
Self::update_filter_expression(&filter_expr, &self.expr, thresholds)?;
419392
}
420393
};
421394

422-
if !should_update {
423-
return Ok(());
424-
}
395+
Ok(())
396+
}
425397

398+
/// Update the filter expression with the given thresholds.
399+
/// This should only be called while holding the threshold lock.
400+
fn update_filter_expression(
401+
filter_expr: &DynamicFilterPhysicalExpr,
402+
sort_exprs: &[PhysicalSortExpr],
403+
thresholds: Vec<ScalarValue>,
404+
) -> Result<()> {
426405
// Create filter expressions for each threshold
427406
let mut filters: Vec<Arc<dyn PhysicalExpr>> =
428407
Vec::with_capacity(thresholds.len());
429408

430409
let mut prev_sort_expr: Option<Arc<dyn PhysicalExpr>> = None;
431-
for (sort_expr, value) in self.expr.iter().zip(thresholds.iter()) {
410+
for (sort_expr, value) in sort_exprs.iter().zip(thresholds.iter()) {
432411
// Create the appropriate operator based on sort order
433412
let op = if sort_expr.options.descending {
434413
// For descending sort, we want col > threshold (exclude smaller values)
@@ -502,7 +481,7 @@ impl TopK {
502481

503482
if let Some(predicate) = dynamic_predicate {
504483
if !predicate.eq(&lit(true)) {
505-
self.filter.expr.update(predicate)?;
484+
filter_expr.update(predicate)?;
506485
}
507486
}
508487

@@ -842,6 +821,15 @@ impl TopKHeap {
842821
+ self.owned_bytes
843822
}
844823

824+
fn get_threshold_row(&self) -> Option<&[u8]> {
825+
// If the heap doesn't have k elements yet, we can't create thresholds
826+
let max_row = self.max()?;
827+
828+
// Return the row bytes directly - this is much more efficient
829+
// than extracting ScalarValues and comparing them
830+
Some(&max_row.row)
831+
}
832+
845833
fn get_threshold_values(
846834
&self,
847835
sort_exprs: &[PhysicalSortExpr],

0 commit comments

Comments (0)