
Commit da3d90a

synchronize partition bounds reporting in HashJoin (#17452)
1 parent 0c7d830 commit da3d90a

4 files changed: +89, -34 lines


datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs

Lines changed: 9 additions & 2 deletions
@@ -1065,7 +1065,7 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() {
     ];
     let probe_repartition = Arc::new(
         RepartitionExec::try_new(
-            probe_scan,
+            Arc::clone(&probe_scan),
            Partitioning::Hash(probe_hash_exprs, partition_count),
         )
         .unwrap(),
@@ -1199,6 +1199,13 @@ async fn test_hashjoin_dynamic_filter_pushdown_partitioned() {
 
     let result = format!("{}", pretty_format_batches(&batches).unwrap());
 
+    let probe_scan_metrics = probe_scan.metrics().unwrap();
+
+    // The probe side had 4 rows, but after applying the dynamic filter only 2 rows should remain.
+    // The number of output rows from the probe side scan should stay consistent across executions.
+    // Issue: https://github.com/apache/datafusion/issues/17451
+    assert_eq!(probe_scan_metrics.output_rows().unwrap(), 2);
+
     insta::assert_snapshot!(
         result,
         @r"
@@ -1355,7 +1362,7 @@ async fn test_nested_hashjoin_dynamic_filter_pushdown() {
       - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, x], file_type=test, pushdown_supported=true
       - HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, d@0)]
         - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[b, c, y], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ b@0 >= aa AND b@0 <= ab ]
-        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ d@0 >= ca AND d@0 <= ce ]
+        - DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, z], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ d@0 >= ca AND d@0 <= cb ]
     "
     );
 }

datafusion/core/tests/physical_optimizer/filter_pushdown/util.rs

Lines changed: 9 additions & 0 deletions
@@ -27,6 +27,7 @@ use datafusion_datasource::{
 };
 use datafusion_physical_expr_common::physical_expr::fmt_sql;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
+use datafusion_physical_plan::filter::batch_filter;
 use datafusion_physical_plan::filter_pushdown::{FilterPushdownPhase, PushedDown};
 use datafusion_physical_plan::{
     displayable,
@@ -53,6 +54,7 @@ pub struct TestOpener {
     batch_size: Option<usize>,
     schema: Option<SchemaRef>,
     projection: Option<Vec<usize>>,
+    predicate: Option<Arc<dyn PhysicalExpr>>,
 }
 
 impl FileOpener for TestOpener {
@@ -77,6 +79,12 @@ impl FileOpener for TestOpener {
         let (mapper, projection) = factory.map_schema(&batches[0].schema()).unwrap();
         let mut new_batches = Vec::new();
         for batch in batches {
+            let batch = if let Some(predicate) = &self.predicate {
+                batch_filter(&batch, predicate)?
+            } else {
+                batch
+            };
+
             let batch = batch.project(&projection).unwrap();
             let batch = mapper.map_batch(batch).unwrap();
             new_batches.push(batch);
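With the new `predicate` field, the test opener filters each batch before projection via `batch_filter`, mimicking what a real data source does once a pushed-down filter becomes usable. A rough, self-contained sketch of that call under assumed crate paths (the schema, column `a`, and the `a >= 3` threshold are invented for illustration; the `batch_filter` signature is inferred from its use above):

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{col, lit, BinaryExpr};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_plan::filter::batch_filter;

fn main() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as ArrayRef],
    )?;

    // Build a physical predicate `a >= 3` and apply it row-by-row.
    let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        col("a", &schema)?,
        Operator::GtEq,
        lit(3i32),
    ));
    let filtered = batch_filter(&batch, &predicate)?;
    assert_eq!(filtered.num_rows(), 2);
    Ok(())
}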
@@ -133,6 +141,7 @@ impl FileSource for TestSource {
133141
batch_size: self.batch_size,
134142
schema: self.schema.clone(),
135143
projection: self.projection.clone(),
144+
predicate: self.predicate.clone(),
136145
})
137146
}
138147

datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs

Lines changed: 31 additions & 28 deletions
@@ -32,6 +32,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
 
 use itertools::Itertools;
 use parking_lot::Mutex;
+use tokio::sync::Barrier;
 
 /// Represents the minimum and maximum values for a specific column.
 /// Used in dynamic filter pushdown to establish value boundaries.
@@ -86,9 +87,9 @@ impl PartitionBounds {
 /// ## Synchronization Strategy
 ///
 /// 1. Each partition computes bounds from its build-side data
-/// 2. Bounds are stored in the shared HashMap (indexed by partition_id)
-/// 3. A counter tracks how many partitions have reported their bounds
-/// 4. When the last partition reports (completed == total), bounds are merged and filter is updated
+/// 2. Bounds are stored in the shared vector
+/// 3. A barrier tracks how many partitions have reported their bounds
+/// 4. When the last partition reports, bounds are merged and the filter is updated exactly once
 ///
 /// ## Partition Counting
 ///
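The revised strategy leans on `tokio::sync::Barrier`: every partition awaits the same barrier, and exactly one waiter comes back flagged as the leader, which is the natural place to do the one-time merge. A minimal, standalone sketch of that pattern (the partition count and the printed "merge" step are placeholders, not DataFusion code):

use std::sync::Arc;

use tokio::sync::Barrier;

#[tokio::main]
async fn main() {
    let partitions = 4;
    // One barrier shared by all partitions; `wait()` resolves only once all have arrived.
    let barrier = Arc::new(Barrier::new(partitions));

    let mut handles = Vec::new();
    for partition in 0..partitions {
        let barrier = Arc::clone(&barrier);
        handles.push(tokio::spawn(async move {
            // ... each partition would record its bounds here ...
            let result = barrier.wait().await;
            // Exactly one waiter observes `is_leader() == true` and performs the
            // one-time action (merging bounds / updating the dynamic filter).
            if result.is_leader() {
                println!("partition {partition} is the leader: merge bounds and update filter once");
            }
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
}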
@@ -103,10 +104,7 @@ impl PartitionBounds {
 pub(crate) struct SharedBoundsAccumulator {
     /// Shared state protected by a single mutex to avoid ordering concerns
     inner: Mutex<SharedBoundsState>,
-    /// Total number of partitions.
-    /// Need to know this so that we can update the dynamic filter once we are done
-    /// building *all* of the hash tables.
-    total_partitions: usize,
+    barrier: Barrier,
     /// Dynamic filter for pushdown to probe side
     dynamic_filter: Arc<DynamicFilterPhysicalExpr>,
     /// Right side join expressions needed for creating filter bounds
@@ -118,8 +116,6 @@ struct SharedBoundsState {
     /// Bounds from completed partitions.
     /// Each element represents the column bounds computed by one partition.
     bounds: Vec<PartitionBounds>,
-    /// Number of partitions that have reported completion.
-    completed_partitions: usize,
 }
 
 impl SharedBoundsAccumulator {
@@ -170,9 +166,8 @@ impl SharedBoundsAccumulator {
         Self {
             inner: Mutex::new(SharedBoundsState {
                 bounds: Vec::with_capacity(expected_calls),
-                completed_partitions: 0,
             }),
-            total_partitions: expected_calls,
+            barrier: Barrier::new(expected_calls),
             dynamic_filter,
             on_right,
         }
@@ -253,36 +248,44 @@ impl SharedBoundsAccumulator {
     /// bounds from the current partition, increments the completion counter, and when all
     /// partitions have reported, creates an OR'd filter from individual partition bounds.
     ///
+    /// This method is async and uses a [`tokio::sync::Barrier`] to wait for all partitions
+    /// to report their bounds. Once that occurs, the method will resolve for all callers and the
+    /// dynamic filter will be updated exactly once.
+    ///
+    /// # Note
+    ///
+    /// As barriers are reusable, it is likely an error to call this method more times than the
+    /// total number of partitions - as it can lead to pending futures that never resolve. We rely
+    /// on correct usage from the caller rather than imposing additional checks here. If this is a concern,
+    /// consider making the resulting future shared so the ready result can be reused.
+    ///
     /// # Arguments
+    /// * `partition` - The partition identifier reporting its bounds
     /// * `partition_bounds` - The bounds computed by this partition (if any)
     ///
     /// # Returns
     /// * `Result<()>` - Ok if successful, Err if filter update failed
-    pub(crate) fn report_partition_bounds(
+    pub(crate) async fn report_partition_bounds(
         &self,
         partition: usize,
         partition_bounds: Option<Vec<ColumnBounds>>,
     ) -> Result<()> {
-        let mut inner = self.inner.lock();
-
         // Store bounds in the accumulator - this runs once per partition
         if let Some(bounds) = partition_bounds {
-            // Only push actual bounds if they exist
-            inner.bounds.push(PartitionBounds::new(partition, bounds));
+            self.inner
+                .lock()
+                .bounds
+                .push(PartitionBounds::new(partition, bounds));
         }
 
-        // Increment the completion counter
-        // Even empty partitions must report to ensure proper termination
-        inner.completed_partitions += 1;
-        let completed = inner.completed_partitions;
-        let total_partitions = self.total_partitions;
-
-        // Critical synchronization point: Only update the filter when ALL partitions are complete
-        // Troubleshooting: If you see "completed > total_partitions", check partition
-        // count calculation in new_from_partition_mode() - it may not match actual execution calls
-        if completed == total_partitions && !inner.bounds.is_empty() {
-            let filter_expr = self.create_filter_from_partition_bounds(&inner.bounds)?;
-            self.dynamic_filter.update(filter_expr)?;
+        if self.barrier.wait().await.is_leader() {
+            // All partitions have reported, so we can update the filter
+            let inner = self.inner.lock();
+            if !inner.bounds.is_empty() {
+                let filter_expr =
+                    self.create_filter_from_partition_bounds(&inner.bounds)?;
+                self.dynamic_filter.update(filter_expr)?;
+            }
         }
 
         Ok(())
datafusion/physical-plan/src/joins/hash_join/stream.rs

Lines changed: 40 additions & 4 deletions
@@ -120,6 +120,8 @@ impl BuildSide {
 pub(super) enum HashJoinStreamState {
     /// Initial state for HashJoinStream indicating that build-side data not collected yet
     WaitBuildSide,
+    /// Waiting for bounds to be reported by all partitions
+    WaitPartitionBoundsReport,
     /// Indicates that build-side has been collected, and stream is ready for fetching probe-side
     FetchProbeBatch,
     /// Indicates that non-empty batch has been fetched from probe-side, and is ready to be processed
@@ -205,6 +207,9 @@ pub(super) struct HashJoinStream {
     right_side_ordered: bool,
     /// Shared bounds accumulator for coordinating dynamic filter updates (optional)
     bounds_accumulator: Option<Arc<SharedBoundsAccumulator>>,
+    /// Optional future to signal when bounds have been reported by all partitions
+    /// and the dynamic filter has been updated
+    bounds_waiter: Option<OnceFut<()>>,
 }
 
 impl RecordBatchStream for HashJoinStream {
@@ -325,6 +330,7 @@ impl HashJoinStream {
             hashes_buffer,
             right_side_ordered,
             bounds_accumulator,
+            bounds_waiter: None,
         }
     }
 
@@ -339,6 +345,9 @@ impl HashJoinStream {
                 HashJoinStreamState::WaitBuildSide => {
                     handle_state!(ready!(self.collect_build_side(cx)))
                 }
+                HashJoinStreamState::WaitPartitionBoundsReport => {
+                    handle_state!(ready!(self.wait_for_partition_bounds_report(cx)))
+                }
                 HashJoinStreamState::FetchProbeBatch => {
                     handle_state!(ready!(self.fetch_probe_batch(cx)))
                 }
@@ -355,6 +364,26 @@ impl HashJoinStream {
         }
     }
 
+    /// Optional step to wait until bounds have been reported by all partitions.
+    /// This state is only entered if a bounds accumulator is present.
+    ///
+    /// ## Why wait?
+    ///
+    /// The dynamic filter is only built once all partitions have reported their bounds.
+    /// If we do not wait here, the probe-side scan may start before the filter is ready.
+    /// This can lead to the probe-side scan missing the opportunity to apply the filter
+    /// and skip reading unnecessary data.
+    fn wait_for_partition_bounds_report(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<StatefulStreamResult<Option<RecordBatch>>>> {
+        if let Some(ref mut fut) = self.bounds_waiter {
+            ready!(fut.get_shared(cx))?;
+        }
+        self.state = HashJoinStreamState::FetchProbeBatch;
+        Poll::Ready(Ok(StatefulStreamResult::Continue))
+    }
+
     /// Collects build-side data by polling `OnceFut` future from initialized build-side
     ///
     /// Updates build-side to `Ready`, and state to `FetchProbeSide`
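The "Why wait?" note boils down to gating the state machine on a future that completes only after every partition has reported. A simplified, standalone sketch of that gating pattern (the types here are stand-ins built on `futures::BoxFuture`, not DataFusion's `HashJoinStream` or `OnceFut`):

use std::future::Future;
use std::task::{ready, Context, Poll};

use futures::future::{BoxFuture, FutureExt};

// Two states standing in for the relevant part of the join stream.
enum State {
    WaitPartitionBoundsReport,
    FetchProbeBatch,
}

struct ToyStream {
    state: State,
    // Set when bounds reporting was kicked off; `None` means there is nothing to wait for.
    bounds_waiter: Option<BoxFuture<'static, ()>>,
}

impl ToyStream {
    fn wait_for_partition_bounds_report(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        if let Some(fut) = self.bounds_waiter.as_mut() {
            // Suspend until all partitions have reported and the filter is updated.
            ready!(fut.as_mut().poll(cx));
        }
        // Only now move on, so the probe-side scan sees the completed filter.
        self.state = State::FetchProbeBatch;
        Poll::Ready(())
    }
}

fn waiter_for(report: impl Future<Output = ()> + Send + 'static) -> BoxFuture<'static, ()> {
    // e.g. `report` could wrap a call like `accumulator.report_partition_bounds(partition, bounds)`.
    report.boxed()
}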
@@ -376,13 +405,20 @@ impl HashJoinStream {
         // Dynamic filter coordination between partitions:
         // Report bounds to the accumulator which will handle synchronization and filter updates
         if let Some(ref bounds_accumulator) = self.bounds_accumulator {
-            bounds_accumulator
-                .report_partition_bounds(self.partition, left_data.bounds.clone())?;
+            let bounds_accumulator = Arc::clone(bounds_accumulator);
+            let partition = self.partition;
+            let left_data_bounds = left_data.bounds.clone();
+            self.bounds_waiter = Some(OnceFut::new(async move {
+                bounds_accumulator
+                    .report_partition_bounds(partition, left_data_bounds)
+                    .await
+            }));
+            self.state = HashJoinStreamState::WaitPartitionBoundsReport;
+        } else {
+            self.state = HashJoinStreamState::FetchProbeBatch;
         }
 
-        self.state = HashJoinStreamState::FetchProbeBatch;
         self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });
-
         Poll::Ready(Ok(StatefulStreamResult::Continue))
     }
 