From ac2b88c697e6d5ddf4c0783b330d999c5db2f11f Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 2 Jul 2025 15:37:24 -0400 Subject: [PATCH 01/24] feat: Support `PiecewiseMergeJoin` for single range join filters --- .../physical-plan/src/joins/hash_join.rs | 1 + datafusion/physical-plan/src/joins/mod.rs | 4 +- .../src/joins/nested_loop_join.rs | 7 +- .../src/joins/piecewise_merge_join.rs | 2070 +++++++++++++++++ .../src/joins/sort_merge_join.rs | 2 +- datafusion/physical-plan/src/joins/utils.rs | 13 +- 6 files changed, 2090 insertions(+), 7 deletions(-) create mode 100644 datafusion/physical-plan/src/joins/piecewise_merge_join.rs diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 770399290dca..7b63522c1006 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -1621,6 +1621,7 @@ impl HashJoinStream { let (left_side, right_side) = get_final_indices_from_shared_bitmap( build_side.left_data.visited_indices_bitmap(), self.join_type, + false, ); let empty_right_batch = RecordBatch::new_empty(self.right.schema()); // use the left and right indices to produce the batch result diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 1d36db996434..7eb64ed5b313 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -23,12 +23,14 @@ use datafusion_physical_expr::PhysicalExprRef; pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; use parking_lot::Mutex; +pub use piecewise_merge_join::PiecewiseMergeJoinExec; // Note: SortMergeJoin is not used in plans yet pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; mod hash_join; mod nested_loop_join; +mod piecewise_merge_join; mod sort_merge_join; mod stream_join_utils; mod symmetric_hash_join; @@ -67,5 +69,5 @@ pub enum StreamJoinPartitionMode { SinglePartition, } -/// Shared bitmap for visited left-side indices +/// Shared bitmap for visited indices type SharedBitmapBuilder = Mutex; diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index fcc1107a0e26..75d17d44f593 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -943,8 +943,11 @@ impl NestedLoopJoinStream { // Only setting up timer, input is exhausted let timer = self.join_metrics.join_time.timer(); // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap(visited_left_side, self.join_type); + let (left_side, right_side) = get_final_indices_from_shared_bitmap( + visited_left_side, + self.join_type, + false, + ); let empty_right_batch = RecordBatch::new_empty(self.outer_table.schema()); // use the left and right indices to produce the batch result let result = build_batch_from_indices( diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs new file mode 100644 index 000000000000..1b37c414e978 --- /dev/null +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -0,0 +1,2070 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + new_null_array, Array, BooleanArray, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, RecordBatchOptions, UInt16Array, UInt8Array, +}; +use arrow::compute::take; +use arrow::{ + array::{ + ArrayRef, BooleanBufferBuilder, RecordBatch, UInt32Array, UInt32Builder, + UInt64Array, UInt64Builder, + }, + compute::{concat_batches, sort_to_indices, take_record_batch}, + util::bit_util, +}; +use arrow_schema::{ArrowError, DataType, Schema, SchemaRef, SortOptions}; +use datafusion_common::{ + exec_err, internal_err, plan_err, utils::compare_rows, JoinSide, Result, ScalarValue, +}; +use datafusion_common::{not_impl_err, NullEquality}; +use datafusion_execution::{ + memory_pool::{MemoryConsumer, MemoryReservation}, + RecordBatchStream, SendableRecordBatchStream, +}; +use datafusion_expr::{JoinType, Operator}; +use datafusion_physical_expr::equivalence::join_equivalence_properties; +use datafusion_physical_expr::{ + LexOrdering, OrderingRequirements, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, +}; +use datafusion_physical_expr_common::physical_expr::fmt_sql; +use futures::{Stream, StreamExt, TryStreamExt}; +use parking_lot::Mutex; +use std::fmt::Formatter; +use std::{cmp::Ordering, task::ready}; +use std::{sync::Arc, task::Poll}; + +use crate::execution_plan::{boundedness_from_children, EmissionType}; + +use crate::joins::sort_merge_join::compare_join_arrays; +use crate::joins::utils::{ + get_final_indices_from_shared_bitmap, symmetric_join_output_partitioning, +}; +use crate::{handle_state, DisplayAs, DisplayFormatType, ExecutionPlanProperties}; +use crate::{ + joins::{ + utils::{ + build_join_schema, BuildProbeJoinMetrics, OnceAsync, OnceFut, + StatefulStreamResult, + }, + SharedBitmapBuilder, + }, + metrics::ExecutionPlanMetricsSet, + spill::get_record_batch_memory_size, + ExecutionPlan, PlanProperties, +}; + +/// `PiecewiseMergeJoinExec` is a join execution plan that only evaluates single range filter. +/// +/// The physical planner will choose to evalute this join when there is only one range predicate. This +/// is a binary expression which contains [`Operator::Lt`], [`Operator::LtEq`], [`Operator::Gt`], and +/// [`Operator::GtEq`].: +/// Examples: +/// - `col0` < `colb`, `col0` <= `colb`, `col0` > `colb`, `col0` >= `colb` +/// +/// Since the join only support range predicates, equijoins are not supported in `PiecewiseMergeJoinExec`, +/// however you can first evaluate another join and run `PiecewiseMergeJoinExec` if left with one range +/// predicate. +/// +/// # Execution Plan Inputs +/// For `PiecewiseMergeJoin` we label all left inputs as the `streamed' side and the right outputs as the +/// 'buffered' side. +/// +/// `PiecewiseMergeJoin` takes a sorted input for the side to be buffered and is able to sort streamed record +/// batches during processing. Sorted input must specifically be ascending/descending based on the operator. +/// +/// # Algorithms +/// Classic joins are processed differently compared to existence joins. +/// +/// ## Classic Joins (Inner, Full, Left, Right) +/// For classic joins we buffer the right side (buffered), and incrementally process the left side (streamed). +/// Every streamed batch is sorted so we can perform a sort merge algorithm. For the buffered side we want to +/// have it already sorted either ascending or descending based on the operator as this allows us to emit all +/// the rows from a given point to the end as matches. Sorting the streamed side allows us to start the pointer +/// from the previous row's match on the buffered side. +/// +/// Here is an example: +/// +/// We perform a `JoinType::Left` with these two batches and the operator being `Operator::Lt`(<). For each +/// row on the streamed side we move a pointer on the buffered until it matches the condition. Once we reach +/// the row which matches (in this case with row 1 on streamed will have its first match on row 2 on +/// buffered; 100 < 200 is true), we can emit all rows after that match. We can emit the rows like this because +/// if the batch is sorted in ascending order, every subsequent row will also satisfy the condition as they will +/// all be larger values. +/// +/// ```text +/// Processing Row 1: +/// +/// Sorted Streamed Side Sorted Buffered Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ 2 │ 200 │ ─┐ +/// ├──────────────────┤ ├──────────────────┤ │ For row 1 on streamed side with +/// 3 │ 500 │ 3 │ 200 │ │ value 100, we emit rows 2 - 5 +/// └──────────────────┘ ├──────────────────┤ │ as matches when the operator is +/// 4 │ 300 │ │ `Operator::Lt` (<) Emitting all +/// ├──────────────────┤ │ rows after the first match (row 2 +/// 5 │ 400 │ ─┘ buffered side; 100 < 200) +/// └──────────────────┘ +/// +/// Processing Row 2: +/// By sorting the streamed side we know +/// +/// Sorted Streamed Side Sorted Buffered Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ 2 │ 200 │ <- Start here when probing for the streamed +/// ├──────────────────┤ ├──────────────────┤ side row 2. +/// 3 │ 500 │ 3 │ 200 │ +/// └──────────────────┘ ├──────────────────┤ +/// 4 │ 300 │ +/// ├──────────────────┤ +/// 5 │ 400 | +/// └──────────────────┘ +/// ``` +/// +/// ## Existence Joins (Semi, Anti, Mark) +/// Existence joins are made magnitudes of times faster with a `PiecewiseMergeJoin` as we only need to find +/// the min/max value of the streamed side to be able to emit all matches on the buffered side. By putting +/// the side we need to mark onto the sorted buffer side, we can emit all these matches at once. +/// +/// For Left Semi, Anti, and Mark joins we swap the inputs so that the marked side is on the buffered side. +/// +/// Here is an example: +/// We perform a `JoinType::Left` with these two batches and the operator being `Operator::Lt`(<). Because +/// the operator is `Operator::Lt` we can find the minimum value in the streamed side; in this case it is 200. +/// We can then advance a pointer from the start of the buffer side until we find the first value that satisfies +/// the predicate. All rows after that first matched value satisfy the condition 200 < x so we can mark all of +/// those rows as matched. +/// +/// ```text +/// Unsorted Streamed Side Sorted Buffered Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 500 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ 2 │ 200 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 3 │ 300 │ 3 │ 200 │ +/// └──────────────────┘ ├──────────────────┤ +/// min value: 200 4 │ 300 │ ─┐ +/// ├──────────────────┤ | We emit matches for row 4 - 5 on the +/// 5 │ 400 | ─┘ buffered side. +/// └──────────────────┘ +/// ``` +/// +/// For both types of joins, the buffered side must be sorted ascending for `Operator::Lt`(<) or +/// `Operator::LtEq`(<=) and descending for `Operator::Gt`(>) or `Operator::GtEq`(>=). +/// +/// # Further Reference Material +/// DuckDB blog on Range Joins: [Range Joins in DuckDB](https://duckdb.org/2022/05/27/iejoin.html) +#[derive(Debug)] +pub struct PiecewiseMergeJoinExec { + /// Left sorted joining execution plan + pub streamed: Arc, + /// Right sorting joining execution plan + pub buffered: Arc, + /// The two expressions being compared + pub on: (Arc, Arc), + /// Comparison operator in the range predicate + pub operator: Operator, + /// How the join is performed + pub join_type: JoinType, + /// The schema once the join is applied + schema: SchemaRef, + buffered_fut: OnceAsync, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// The left SortExpr + left_sort_exprs: LexOrdering, + /// The right SortExpr + right_sort_exprs: LexOrdering, + /// Sort options of join columns used in sorting the stream and buffered execution plans + sort_options: SortOptions, + /// Cache holding plan properties like equivalences, output partitioning etc. + cache: PlanProperties, +} + +impl PiecewiseMergeJoinExec { + pub fn try_new( + streamed: Arc, + buffered: Arc, + on: (Arc, Arc), + operator: Operator, + join_type: JoinType, + ) -> Result { + // TODO: Implement mark joins for PiecewiseMergeJoin + if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) { + return plan_err!( + "Mark Joins are currently not supported for PiecewiseMergeJoin" + ); + } + + // We take the operator and enforce a sort order on the streamed + buffered side based on + // the operator type. + let sort_options = match operator { + Operator::Lt | Operator::LtEq => SortOptions::new(false, false), + Operator::Gt | Operator::GtEq => SortOptions::new(true, false), + _ => { + return plan_err!( + "Cannot contain non-range operator in PiecewiseMergeJoinExec" + ) + } + }; + + let left_sort_exprs = vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; + let right_sort_exprs = vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; + let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { + return plan_err!( + "PiecewiseMergeJoinExec requires valid sort expressions for its left side" + ); + }; + let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else { + return plan_err!( + "PiecewiseMergeJoinExec requires valid sort expressions for its right side" + ); + }; + + let streamed_schema = streamed.schema(); + let buffered_schema = buffered.schema(); + + // Create output schema for the join + let schema = + Arc::new(build_join_schema(&streamed_schema, &buffered_schema, &join_type).0); + let cache = Self::compute_properties( + &streamed, + &buffered, + Arc::clone(&schema), + join_type, + &on, + )?; + + Ok(Self { + streamed, + buffered, + on, + operator, + join_type, + schema, + buffered_fut: Default::default(), + metrics: ExecutionPlanMetricsSet::new(), + left_sort_exprs, + right_sort_exprs, + sort_options, + cache, + }) + } + + /// Refeference to streamed side execution plan + pub fn streamed(&self) -> &Arc { + &self.streamed + } + + /// Refeerence to buffered side execution plan + pub fn buffered(&self) -> &Arc { + &self.buffered + } + + /// Join type + pub fn join_type(&self) -> JoinType { + self.join_type + } + + /// Reference to sort options + pub fn sort_options(&self) -> &SortOptions { + &self.sort_options + } + + /// Get probe side (buffered side) for the PiecewiseMergeJoin + /// In current implementation, probe side is determined according to join type. + pub fn probe_side(join_type: &JoinType) -> JoinSide { + match join_type { + JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => JoinSide::Right, + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark => JoinSide::Left, + } + } + + // TODO: fix compute properties to work specifically for PiecewiseMergeJoin + // This is currently just a filler implementation so that it actually returns + // a PlanProperties + pub fn compute_properties( + streamed: &Arc, + buffered: &Arc, + schema: SchemaRef, + join_type: JoinType, + join_on: &(PhysicalExprRef, PhysicalExprRef), + ) -> Result { + let eq_properties = join_equivalence_properties( + streamed.equivalence_properties().clone(), + buffered.equivalence_properties().clone(), + &join_type, + schema, + &[false], + Some(Self::probe_side(&join_type)), + &[join_on.clone()], + )?; + + let output_partitioning = + symmetric_join_output_partitioning(streamed, buffered, &join_type)?; + + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + EmissionType::Incremental, + boundedness_from_children([streamed, buffered]), + )) + } + + pub fn swap_inputs(&self) -> Result> { + todo!() + } +} + +impl ExecutionPlan for PiecewiseMergeJoinExec { + fn name(&self) -> &str { + "PiecewiseMergeJoinExec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + todo!() + } + + fn required_input_ordering(&self) -> Vec> { + vec![ + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + Some(OrderingRequirements::from(self.right_sort_exprs.clone())), + ] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + match &children[..] { + [left, right] => Ok(Arc::new(PiecewiseMergeJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + self.on.clone(), + self.operator, + self.join_type, + )?)), + _ => internal_err!( + "PiecewiseMergeJoin should have 2 children, found {}", + children.len() + ), + } + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result { + let on_streamed = Arc::clone(&self.on.0); + let on_buffered = Arc::clone(&self.on.1); + + // If the join type is either LeftSemi, LeftAnti, or LeftMark we will swap the inputs + // and sort ordering because we want the mark side to be the buffered side. + let (streamed, buffered, on_streamed, on_buffered, operator, sort_options) = if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark + ) { + ( + Arc::clone(&self.buffered), + Arc::clone(&self.streamed), + on_buffered, + on_streamed, + self.operator.swap().unwrap(), + SortOptions::new(!self.sort_options.descending, false), + ) + } else { + ( + Arc::clone(&self.streamed), + Arc::clone(&self.buffered), + on_streamed, + on_buffered, + self.operator, + self.sort_options, + ) + }; + + let metrics = BuildProbeJoinMetrics::new(0, &self.metrics); + let buffered_fut = self.buffered_fut.try_once(|| { + let reservation = MemoryConsumer::new("PiecewiseMergeJoinInput") + .register(context.memory_pool()); + let buffered_stream = buffered.execute(partition, Arc::clone(&context))?; + Ok(build_buffered_data( + buffered_stream, + Arc::clone(&on_buffered), + metrics, + reservation, + build_visited_indices_map(self.join_type), + )) + })?; + + let streamed = streamed.execute(partition, Arc::clone(&context))?; + let existence_join = is_existence_join(self.join_type()); + + Ok(Box::pin(PiecewiseMergeJoinStream::try_new( + Arc::clone(&self.schema), + on_streamed, + self.join_type, + operator, + streamed, + BufferedSide::Initial(BufferedSideInitialState { buffered_fut }), + if existence_join { + PiecewiseMergeJoinStreamState::FetchStreamBatch + } else { + PiecewiseMergeJoinStreamState::WaitBufferedSide + }, + existence_join, + sort_options, + ))) + } +} + +impl DisplayAs for PiecewiseMergeJoinExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + let on_str = format!( + "({} {} {})", + fmt_sql(self.on.0.as_ref()), + self.operator, + fmt_sql(self.on.1.as_ref()) + ); + + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "PiecewiseMergeJoin: operator={:?}, join_type={:?}, on={}", + self.operator, self.join_type, on_str + ) + } + + DisplayFormatType::TreeRender => { + writeln!(f, "operator={:?}", self.operator)?; + if self.join_type != JoinType::Inner { + writeln!(f, "join_type={:?}", self.join_type)?; + } + writeln!(f, "on={on_str}") + } + } + } +} + +// Returns boolean for whether the join is an existence join +fn is_existence_join(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark + ) +} + +// Returns boolean to check if the join type needs to record +// buffered side matches for classic joins +fn need_produce_result_in_final(join_type: JoinType) -> bool { + matches!(join_type, JoinType::Full | JoinType::Right) +} + +// Returns boolean for whether or not we need to build the buffered side +// bitmap for marking matched rows on the buffered side. +fn build_visited_indices_map(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::Full + | JoinType::Right + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::LeftSemi + | JoinType::RightSemi + | JoinType::LeftMark + | JoinType::RightMark + ) +} + +async fn build_buffered_data( + buffered: SendableRecordBatchStream, + on_buffered: PhysicalExprRef, + metrics: BuildProbeJoinMetrics, + reservation: MemoryReservation, + build_map: bool, +) -> Result { + let schema = buffered.schema(); + + // Combine batches and record number of rows + let initial = (Vec::new(), 0, metrics, reservation); + let (batches, num_rows, metrics, mut reservation) = buffered + .try_fold(initial, |mut acc, batch| async { + let batch_size = get_record_batch_memory_size(&batch); + acc.3.try_grow(batch_size)?; + acc.2.build_mem_used.add(batch_size); + acc.2.build_input_batches.add(1); + acc.2.build_input_rows.add(batch.num_rows()); + // Update row count + acc.1 += batch.num_rows(); + // Push batch to output + acc.0.push(batch); + Ok(acc) + }) + .await?; + + let batches_iter = batches.iter().rev(); + let single_batch = concat_batches(&schema, batches_iter)?; + + // Evaluate physical expression on the buffered side. + let buffered_values = on_buffered + .evaluate(&single_batch)? + .into_array(single_batch.num_rows())?; + + // Created visited indices bitmap only if the join type requires it + let visited_indices_bitmap = if build_map { + let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); + reservation.try_grow(bitmap_size)?; + metrics.build_mem_used.add(bitmap_size); + + let mut bitmap_buffer = BooleanBufferBuilder::new(single_batch.num_rows()); + bitmap_buffer.append_n(num_rows, false); + bitmap_buffer + } else { + BooleanBufferBuilder::new(0) + }; + + let buffered_data = BufferedSideData::new( + single_batch, + buffered_values, + Mutex::new(visited_indices_bitmap), + reservation, + ); + + Ok(buffered_data) +} + +struct BufferedSideData { + batch: RecordBatch, + values: ArrayRef, + visited_indices_bitmap: SharedBitmapBuilder, + _reservation: MemoryReservation, +} + +impl BufferedSideData { + fn new( + batch: RecordBatch, + values: ArrayRef, + visited_indices_bitmap: SharedBitmapBuilder, + reservation: MemoryReservation, + ) -> Self { + Self { + batch, + values, + visited_indices_bitmap, + _reservation: reservation, + } + } + + fn batch(&self) -> &RecordBatch { + &self.batch + } + + fn values(&self) -> &ArrayRef { + &self.values + } +} + +enum BufferedSide { + /// Indicates that build-side not collected yet + Initial(BufferedSideInitialState), + /// Indicates that build-side data has been collected + Ready(BufferedSideReadyState), +} + +impl BufferedSide { + // Takes a mutable state of the buffered row batches + fn try_as_initial_mut(&mut self) -> Result<&mut BufferedSideInitialState> { + match self { + BufferedSide::Initial(state) => Ok(state), + _ => internal_err!("Expected build side in initial state"), + } + } + + fn try_as_ready(&self) -> Result<&BufferedSideReadyState> { + match self { + BufferedSide::Ready(state) => Ok(state), + _ => { + internal_err!("Expected build side in ready state") + } + } + } + + /// Tries to extract BuildSideReadyState from BuildSide enum. + /// Returns an error if state is not Ready. + fn try_as_ready_mut(&mut self) -> Result<&mut BufferedSideReadyState> { + match self { + BufferedSide::Ready(state) => Ok(state), + _ => internal_err!("Expected build side in ready state"), + } + } +} + +struct BufferedSideInitialState { + buffered_fut: OnceFut, +} + +struct BufferedSideReadyState { + /// Collected build-side data + buffered_data: Arc, +} + +enum PiecewiseMergeJoinStreamState { + WaitBufferedSide, + FetchStreamBatch, + ProcessStreamBatch(StreamedBatch), + ExhaustedStreamSide, + Completed, +} + +impl PiecewiseMergeJoinStreamState { + // Grab mutable reference to the current stream batch + fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut StreamedBatch> { + match self { + PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state), + _ => internal_err!("Expected streamed batch in StreamBatch"), + } + } +} + +struct StreamedBatch { + pub batch: RecordBatch, + values: Vec, +} + +impl StreamedBatch { + fn new(batch: RecordBatch, values: Vec) -> Self { + Self { batch, values } + } + + fn values(&self) -> &Vec { + &self.values + } +} + +struct PiecewiseMergeJoinStream { + // Output schema of the `PiecewiseMergeJoin` + pub schema: Arc, + + // Physical expression that is evaluated on the streamed side + // We do not need on_buffered as this is already evaluated when + // creating the buffered side which happens before initializing + // `PiecewiseMergeJoinStream` + pub on_streamed: PhysicalExprRef, + // Type of join + pub join_type: JoinType, + // Comparison operator + pub operator: Operator, + // Streamed batch + pub streamed: SendableRecordBatchStream, + // Streamed schema + streamed_schema: SchemaRef, + // Buffered side data + buffered_side: BufferedSide, + // Stores the min max value for the streamed side, only needed + // for existence joins. + streamed_global_min_max: Mutex>, + // Tracks the state of the `PiecewiseMergeJoin` + state: PiecewiseMergeJoinStreamState, + // Flag for whehter or not the join_type is an existence join. + existence_join: bool, + // Sort option for buffered and streamed side (specifies whether + // the sort is ascending or descending) + sort_option: SortOptions, +} + +impl RecordBatchStream for PiecewiseMergeJoinStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +// `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`, +// `ProcessStreamBatch`, `ExhaustedStreamSide` and `Completed`. Classic joins and existence +// joins have a different processing order and behaviour for these states. +// +// Classic Joins +// 1. `WaitBufferedSide` - Load in the buffered side data into memory. +// 2. `FetchStreamBatch` - Fetch + sort incoming stream batches. We switch the state to +// `ExhaustedStreamBatch` once stream batches are exhausted. +// 3. `ProcessStreamBatch` - Compare stream batch row values against the buffered side data. +// 4. `ExhaustedStreamBatch` - If the join type is Left or Inner we will return state as +// `Completed` however for Full and Right we will need to process the matched/unmatched rows. +// +// Existence Joins +// 1. `FetchStreamBatch` - Fetch incoming stream batch until exhausted. We find the min/max variable +// within this state. We switch the state to `WaitBufferedSide` once we have exhausted all stream +// batches. +// 3. `WaitBufferedSide` - Load buffered side data into memory. +// 4. `ExhaustedStreamBatch` - Use the global minimum or maximum value to find the matches on +// the buffered side. +impl PiecewiseMergeJoinStream { + // Creates a new `PiecewiseMergeJoinStream` instance + pub fn try_new( + schema: Arc, + on_streamed: PhysicalExprRef, + join_type: JoinType, + operator: Operator, + streamed: SendableRecordBatchStream, + buffered_side: BufferedSide, + state: PiecewiseMergeJoinStreamState, + existence_join: bool, + sort_option: SortOptions, + ) -> Self { + let streamed_schema = streamed.schema(); + Self { + schema, + on_streamed, + join_type, + operator, + streamed_schema, + streamed, + buffered_side, + streamed_global_min_max: Mutex::new(None), + state, + existence_join, + sort_option, + } + } + + fn poll_next_impl( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + return match self.state { + PiecewiseMergeJoinStreamState::WaitBufferedSide => { + handle_state!(ready!(self.collect_buffered_side(cx))) + } + PiecewiseMergeJoinStreamState::FetchStreamBatch => { + handle_state!(ready!(self.fetch_stream_batch(cx))) + } + PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => { + handle_state!(self.process_stream_batch()) + } + PiecewiseMergeJoinStreamState::ExhaustedStreamSide => { + handle_state!(self.process_unmatched_buffered_batch()) + } + PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + // Collects buffered side data + fn collect_buffered_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + let buffered_data = ready!(self + .buffered_side + .try_as_initial_mut()? + .buffered_fut + .get_shared(cx))?; + + self.state = if self.existence_join { + // For existence joins we will start to compare the buffered + // side to the global max min to get matches. + PiecewiseMergeJoinStreamState::ExhaustedStreamSide + } else { + // We will start fetching stream batches for classic joins + PiecewiseMergeJoinStreamState::FetchStreamBatch + }; + + self.buffered_side = + BufferedSide::Ready(BufferedSideReadyState { buffered_data }); + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + // Fetches incoming stream batches + fn fetch_stream_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.streamed.poll_next_unpin(cx)) { + None => { + if self.existence_join { + self.state = PiecewiseMergeJoinStreamState::WaitBufferedSide; + } else { + self.state = PiecewiseMergeJoinStreamState::ExhaustedStreamSide; + } + } + Some(Ok(batch)) => { + // Evaluate the streamed physical expression on the stream batch + let stream_values: ArrayRef = self + .on_streamed + .evaluate(&batch)? + .into_array(batch.num_rows())?; + + // For existence joins we do not need to sort the output, and we only need to + // find the min or max value (depending on the operator) of all the stream batches + if self.existence_join { + let mut global_min_max = self.streamed_global_min_max.lock(); + let streamed_batch = StreamedBatch::new(batch, vec![stream_values]); + + // Finds the min/max value of the streamed batch and compares it against the global + // min/max + resolve_existence_join( + &streamed_batch, + &mut global_min_max, + self.operator, + ) + .unwrap(); + + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + + // Sort stream values and change the streamed record batch accordingly + let indices = sort_to_indices( + stream_values.as_ref(), + Some(self.sort_option), + None, + )?; + let stream_batch = take_record_batch(&batch, &indices)?; + let stream_values = take(stream_values.as_ref(), &indices, None)?; + + self.state = + PiecewiseMergeJoinStreamState::ProcessStreamBatch(StreamedBatch { + batch: stream_batch, + values: vec![stream_values], + }); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + // Only classic join will call this, we process stream batches and evaluate against + // the buffered side data. + fn process_stream_batch( + &mut self, + ) -> Result>> { + let stream_batch = self.state.try_as_process_stream_batch_mut()?; + let buffered_side = self.buffered_side.try_as_ready_mut()?; + + let result = resolve_classic_join( + stream_batch, + buffered_side, + Arc::clone(&self.schema), + self.operator, + self.sort_option, + self.join_type, + )?; + + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; + Ok(StatefulStreamResult::Ready(Some(result))) + } + + // Process remaining unmatched rows + fn process_unmatched_buffered_batch( + &mut self, + ) -> Result>> { + // Return early for `JoinType::Left` and `JoinType::Inner` + if matches!(self.join_type, JoinType::Left | JoinType::Inner) { + self.state = PiecewiseMergeJoinStreamState::Completed; + + return Ok(StatefulStreamResult::Ready(None)); + } + + let buffered_data = Arc::clone(&self + .buffered_side + .try_as_ready() + .unwrap() + .buffered_data + ); + + if matches!( + self.join_type, + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark + | JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark + ) { + let global_min_max = self.streamed_global_min_max.lock(); + let threshold = match &*global_min_max { + Some(v) => v.clone(), + None => return exec_err!("Stream batch was empty."), + }; + + let buffered_values = buffered_data.values(); + let mut threshold_idx: Option = None; + + // We iterate the buffered size values while comparing the threshold value (min/max) + // and record our first match + for buffered_idx in 0..buffered_data.values.len() { + let buffered_value = + ScalarValue::try_from_array(&buffered_values, buffered_idx)?; + let ord = compare_rows( + &[threshold.clone()], + &[buffered_value.clone()], + &[self.sort_option], + )?; + + // Decide “past the threshold” by operator + let keep = match self.operator { + Operator::Gt | Operator::Lt => ord == Ordering::Less, + Operator::GtEq | Operator::LtEq => { + ord == Ordering::Less || ord == Ordering::Equal + } + _ => false, + }; + + // Record match + if keep { + threshold_idx = Some(buffered_idx); + break; + } + } + + let mut buffered_indices = UInt64Builder::default(); + + // If we found a match then we will append all indices from the threshold index + // to the end of the buffered size rows + if let Some(threshold_idx) = threshold_idx { + let buffered_range: Vec = + (threshold_idx as u64..buffered_data.values.len() as u64).collect(); + buffered_indices.append_slice(&buffered_range); + } + let buffered_indices_array = buffered_indices.finish(); + + // The visited bitmap hasn't been marked yet for existence joins + let mut bitmap = buffered_data.visited_indices_bitmap.lock(); + buffered_indices_array.iter().flatten().for_each(|x| { + bitmap.set_bit(x as usize, true); + }); + } + + let (buffered_indices, streamed_indices) = get_final_indices_from_shared_bitmap( + &buffered_data.visited_indices_bitmap, + self.join_type, + true, + ); + + let buffered_batch = buffered_data.batch(); + let empty_stream_batch = RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); + + let batch = build_matched_indices( + self.join_type, + Arc::clone(&self.schema), + &empty_stream_batch, + buffered_batch, + streamed_indices, + buffered_indices, + )?; + + self.state = PiecewiseMergeJoinStreamState::Completed; + + Ok(StatefulStreamResult::Ready(Some(batch))) + } +} + +impl Stream for PiecewiseMergeJoinStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.poll_next_impl(cx) + } +} + +// For Semi, Anti, and Mark joins, we are only required to have to return the marked side, so +// we just need to find the min/max value of either side. +fn resolve_existence_join( + stream_batch: &StreamedBatch, + global_min_max: &mut Option, + operator: Operator, +) -> Result<()> { + let min_max_value = global_min_max; + + // Based on the operator we will find the minimum or maximum value. + match operator { + Operator::Gt | Operator::GtEq => { + let max_value = min_max(&stream_batch.values[0], true)?; + let new_max = if let Some(prev) = (*min_max_value).clone() { + if max_value.partial_cmp(&prev).unwrap() == Ordering::Greater { + max_value + } else { + prev + } + } else { + max_value + }; + *min_max_value = Some(new_max); + } + + Operator::Lt | Operator::LtEq => { + let min_value = min_max(&stream_batch.values[0], false)?; + let new_min = if let Some(prev) = (*min_max_value).clone() { + if min_value.partial_cmp(&prev).unwrap() == Ordering::Less { + min_value + } else { + prev + } + } else { + min_value + }; + *min_max_value = Some(new_min); + } + _ => { + return exec_err!( + "PiecewiseMergeJoin should not contain operator, {}", + operator + ) + } + }; + + Ok(()) +} + +// For Left, Right, Full, and Inner joins, incoming stream batches will already be sorted. +fn resolve_classic_join( + stream_batch: &StreamedBatch, + buffered_side: &BufferedSideReadyState, + join_schema: Arc, + operator: Operator, + sort_options: SortOptions, + join_type: JoinType, +) -> Result { + let stream_values = stream_batch.values(); + let buffered_values = buffered_side.buffered_data.values(); + let buffered_len = buffered_values.len(); + + let mut stream_indices = UInt32Builder::default(); + let mut buffered_indices = UInt64Builder::default(); + + // Our pivot variable allows us to start probing on the buffered side where we last matched + // in the previous stream row. + let mut pivot = 0; + for row_idx in 0..stream_values[0].len() { + let mut found = false; + while pivot < buffered_values.len() { + let compare = compare_join_arrays( + &[Arc::clone(&stream_values[0])], + row_idx, + &[Arc::clone(buffered_values)], + pivot, + &[sort_options], + NullEquality::NullEqualsNothing, + )?; + + // If we find a match we append all indices and move to the next stream row index + match operator { + Operator::Gt | Operator::Lt => { + if matches!(compare, Ordering::Less) { + let count = buffered_values.len() - pivot; + let stream_repeated = vec![row_idx as u32; count]; + let buffered_range: Vec = + (pivot as u64..buffered_len as u64).collect(); + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + found = true; + + break; + } + pivot += 1; + } + Operator::GtEq | Operator::LtEq => { + if matches!(compare, Ordering::Equal | Ordering::Less) { + let count = buffered_values.len() - pivot; + let stream_repeated = vec![row_idx as u32; count]; + let buffered_range: Vec = + (pivot as u64..buffered_len as u64).collect(); + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + found = true; + + break; + } + pivot += 1; + } + _ => { + return exec_err!( + "PiecewiseMergeJoin should not contain operator, {}", + operator + ) + } + }; + } + + // If not found we append a null value for `JoinType::Left` and `JoinType::Full` + if !found && matches!(join_type, JoinType::Left | JoinType::Full) { + stream_indices.append_value(row_idx as u32); + buffered_indices.append_null(); + } + } + + let stream_indices_array = stream_indices.finish(); + let buffered_indices_array = buffered_indices.finish(); + + // We need to mark the buffered side matched indices for `JoinType::Full` and `JoinType::Right` + if need_produce_result_in_final(join_type) { + let mut bitmap = buffered_side.buffered_data.visited_indices_bitmap.lock(); + + buffered_indices_array.iter().flatten().for_each(|x| { + bitmap.set_bit(x as usize, true); + }); + } + + let batch = build_matched_indices( + join_type, + join_schema, + &stream_batch.batch, + &buffered_side.buffered_data.batch, + stream_indices_array, + buffered_indices_array, + )?; + + Ok(batch) +} + +fn build_matched_indices( + join_type: JoinType, + schema: Arc, + streamed_batch: &RecordBatch, + buffered_batch: &RecordBatch, + streamed_indices: UInt32Array, + buffered_indices: UInt64Array, +) -> Result { + if schema.fields().is_empty() { + // Build an “empty” RecordBatch with just row‐count metadata + let options = RecordBatchOptions::new() + .with_match_field_names(true) + .with_row_count(Some(streamed_indices.len())); + return Ok(RecordBatch::try_new_with_options( + Arc::new((*schema).clone()), + vec![], + &options, + )?); + } + + // Gather stream columns after applying filter specified with stream indices + let mut streamed_columns = if !is_existence_join(join_type) { + streamed_batch + .columns() + .iter() + .map(|column_array| { + if column_array.is_empty() + || streamed_indices.null_count() == streamed_indices.len() + { + assert_eq!(streamed_indices.null_count(), streamed_indices.len()); + Ok(new_null_array( + column_array.data_type(), + streamed_indices.len(), + )) + } else { + take(column_array, &streamed_indices, None) + } + }) + .collect::, ArrowError>>()? + } else { + vec![] + }; + + let buffered_columns = buffered_batch + .columns() + .iter() + .map(|column_array| take(column_array, &buffered_indices, None)) + .collect::, ArrowError>>()?; + + streamed_columns.extend(buffered_columns); + + Ok(RecordBatch::try_new( + Arc::new((*schema).clone()), + streamed_columns, + )?) +} + +pub fn min_max(array: &ArrayRef, find_max: bool) -> Result { + macro_rules! find_min_max { + ($ARR:ty, $SCALAR:ident) => {{ + let arr = array.as_any().downcast_ref::<$ARR>().unwrap(); + let mut extreme: Option<_> = None; + for i in 0..arr.len() { + if arr.is_valid(i) { + let v = arr.value(i); + extreme = Some(match extreme { + Some(cur) => { + if find_max { + if v > cur { + v + } else { + cur + } + } else { + if v < cur { + v + } else { + cur + } + } + } + None => v, + }); + } + } + ScalarValue::$SCALAR(extreme) + }}; + } + + let result = match array.data_type() { + DataType::Int8 => find_min_max!(Int8Array, Int8), + DataType::Int16 => find_min_max!(Int16Array, Int16), + DataType::Int32 => find_min_max!(Int32Array, Int32), + DataType::Int64 => find_min_max!(Int64Array, Int64), + DataType::UInt8 => find_min_max!(UInt8Array, UInt8), + DataType::UInt16 => find_min_max!(UInt16Array, UInt16), + DataType::UInt32 => find_min_max!(UInt32Array, UInt32), + DataType::UInt64 => find_min_max!(UInt64Array, UInt64), + DataType::Float32 => find_min_max!(Float32Array, Float32), + DataType::Float64 => find_min_max!(Float64Array, Float64), + + DataType::Boolean => { + let arr = array.as_any().downcast_ref::().unwrap(); + let mut extreme: Option = None; + for i in 0..arr.len() { + if arr.is_valid(i) { + let v = arr.value(i); + extreme = Some(match extreme { + Some(cur) => { + if find_max { + cur || v // max: true if either is true + } else { + cur && v // min: false if either is false + } + } + None => v, + }); + } + } + ScalarValue::Boolean(extreme) + } + + dt => { + return not_impl_err!( + "Unsupported data type in PiecewiseMergeJoin min/max function: {}", + dt + ); + } + }; + + Ok(result) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + common, + test::{build_table_i32, TestMemoryExec}, + ExecutionPlan, + }; + use arrow::array::{Date32Array, Date64Array}; + use arrow_schema::Field; + use datafusion_common::test_util::batches_to_string; + use datafusion_execution::TaskContext; + use datafusion_expr::JoinType; + use datafusion_physical_expr::expressions::Column; + use insta::assert_snapshot; + use std::sync::Arc; + + fn columns(schema: &Schema) -> Vec { + schema.fields().iter().map(|f| f.name().clone()).collect() + } + + fn build_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_date_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date32, false), + Field::new(b.0, DataType::Date32, false), + Field::new(c.0, DataType::Date32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date32Array::from(a.1.clone())), + Arc::new(Date32Array::from(b.1.clone())), + Arc::new(Date32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn build_date64_table( + a: (&str, &Vec), + b: (&str, &Vec), + c: (&str, &Vec), + ) -> Arc { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date64, false), + Field::new(b.0, DataType::Date64, false), + Field::new(c.0, DataType::Date64, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date64Array::from(a.1.clone())), + Arc::new(Date64Array::from(b.1.clone())), + Arc::new(Date64Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap() + } + + fn join( + left: Arc, + right: Arc, + on: (Arc, Arc), + operator: Operator, + join_type: JoinType, + ) -> Result { + Ok(PiecewiseMergeJoinExec::try_new( + left, right, on, operator, join_type, + )?) + } + + async fn join_collect( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + join_collect_with_options(left, right, on, operator, join_type).await + } + + async fn join_collect_with_options( + left: Arc, + right: Arc, + on: (PhysicalExprRef, PhysicalExprRef), + operator: Operator, + join_type: JoinType, + ) -> Result<(Vec, Vec)> { + let task_ctx = Arc::new(TaskContext::default()); + let join = join(left, right, on, operator, join_type)?; + let columns = columns(&join.schema()); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_inner_less_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 2, 3]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | 7 | 10 | 2 | 70 | + | 1 | 1 | 7 | 20 | 3 | 80 | + | 1 | 1 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 20 | 3 | 80 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 3 | 9 | 30 | 4 | 90 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_less_than_unsorted() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 1, 2]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 1 | 8 | 10 | 2 | 70 | + | 2 | 1 | 8 | 20 | 3 | 80 | + | 2 | 1 | 8 | 30 | 4 | 90 | + | 3 | 2 | 9 | 20 | 3 | 80 | + | 3 | 2 | 9 | 30 | 4 | 90 | + | 1 | 3 | 7 | 30 | 4 | 90 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_greater_than_equal_to() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![2, 3, 4]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 1]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 3 | 4 | 9 | 10 | 3 | 70 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 2 | 3 | 8 | 10 | 3 | 70 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 2 | 3 | 8 | 30 | 1 | 90 | + | 1 | 2 | 7 | 20 | 2 | 80 | + | 1 | 2 | 7 | 30 | 1 | 90 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_inner_empty_left() -> Result<()> { + let left = build_table( + ("a1", &Vec::::new()), + ("b1", &Vec::::new()), + ("c1", &Vec::::new()), + ); + let right = build_table( + ("a2", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c2", &vec![1, 2]), + ); + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_full_greater_than_equal_to() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1, 2]), + ("c1", &vec![100, 200]), + ); + let right = build_table( + ("a2", &vec![10, 20]), + ("b1", &vec![3, 2]), + ("c2", &vec![300, 400]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+-----+----+----+-----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+-----+----+----+-----+ + | 2 | 2 | 200 | 20 | 2 | 400 | + | 1 | 1 | 100 | | | | + | | | | 10 | 3 | 300 | + +----+----+-----+----+----+-----+ + "#); + + Ok(()) + } + + #[tokio::test] + async fn join_left_greater_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![3, 2, 1]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Left).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 3 | 4 | 9 | 10 | 3 | 70 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 2 | 3 | 8 | 30 | 1 | 90 | + | 1 | 1 | 7 | | | | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_greater_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![5, 3, 2]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 3 | 4 | 9 | 20 | 3 | 80 | + | 3 | 4 | 9 | 30 | 2 | 90 | + | 2 | 3 | 8 | 30 | 2 | 90 | + | | | | 10 | 5 | 70 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_less_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 1 | 7 | 10 | 2 | 70 | + | 1 | 1 | 7 | 20 | 3 | 80 | + | 1 | 1 | 7 | 30 | 5 | 90 | + | 2 | 3 | 8 | 30 | 5 | 90 | + | 3 | 4 | 9 | 30 | 5 | 90 | + +----+----+----+----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_less_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::RightSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + | 10 | 2 | 70 | + | 20 | 3 | 80 | + | 30 | 5 | 90 | + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_left_semi_less_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![5, 4, 1]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 2 | 4 | 8 | + | 3 | 1 | 9 | + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_left_semi_greater_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 4, 5]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 2 | 4 | 8 | + | 3 | 5 | 9 | + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_left_anti_greater_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 4, 5]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Gt, JoinType::LeftAnti).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 1 | 7 | + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_left_semi_greater_than_equal_to() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1, 3, 4]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1, 2, 3]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::GtEq, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 1 | 7 | + | 2 | 3 | 8 | + | 3 | 4 | 9 | + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_left_anti_less_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![5, 4, 1]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::LeftAnti).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 5 | 7 | + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_left_anti_less_than_equal_to() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![5, 4, 1]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::LeftAnti).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_semi_unsorted() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![3, 1, 2]), // unsorted + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 4]), + ("c2", &vec![70, 80, 90]), + ); + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::RightSemi).await?; + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + | 10 | 2 | 70 | + | 20 | 3 | 80 | + | 30 | 4 | 90 | + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_less_than_equal_to() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![5, 4, 1]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![2, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::LtEq, JoinType::RightAnti).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_right_anti_less_than() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![5, 4, 1]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1, 3, 5]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::RightAnti).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + | 10 | 1 | 70 | + +----+----+----+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date32_inner_less_than() -> Result<()> { + let left = build_date_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![19105, 19107, 19107]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_date_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![19105, 19103, 19107]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +------------+------------+------------+------------+------------+------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +------------+------------+------------+------------+------------+------------+ + | 1970-01-02 | 2022-04-23 | 1970-01-08 | 1970-01-31 | 2022-04-25 | 1970-04-01 | + +------------+------------+------------+------------+------------+------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_inner_less_than() -> Result<()> { + let left = build_date64_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_date64_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), + ("c2", &vec![70, 80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + "#); + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index a8c209a492ba..68d53aa7a47a 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -2433,7 +2433,7 @@ fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) -> Vec bool { pub(crate) fn get_final_indices_from_shared_bitmap( shared_bitmap: &SharedBitmapBuilder, join_type: JoinType, + piecewise: bool, ) -> (UInt64Array, UInt32Array) { let bitmap = shared_bitmap.lock(); - get_final_indices_from_bit_map(&bitmap, join_type) + get_final_indices_from_bit_map(&bitmap, join_type, piecewise) } /// In the end of join execution, need to use bit map of the matched @@ -808,16 +809,22 @@ pub(crate) fn get_final_indices_from_shared_bitmap( pub(crate) fn get_final_indices_from_bit_map( left_bit_map: &BooleanBufferBuilder, join_type: JoinType, + // We add a flag for whether this is being passed from the `PiecewiseMergeJoin` + // because the bitmap can be for left + right `JoinType`s + piecewise: bool, ) -> (UInt64Array, UInt32Array) { let left_size = left_bit_map.len(); - if join_type == JoinType::LeftMark { + if join_type == JoinType::LeftMark || (join_type == JoinType::RightMark && piecewise) + { let left_indices = (0..left_size as u64).collect::(); let right_indices = (0..left_size) .map(|idx| left_bit_map.get_bit(idx).then_some(0)) .collect::(); return (left_indices, right_indices); } - let left_indices = if join_type == JoinType::LeftSemi { + let left_indices = if join_type == JoinType::LeftSemi + || (join_type == JoinType::RightSemi && piecewise) + { (0..left_size) .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) .collect::() From 6af253b350a80cbfde398dd12073f902c3119ab7 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 2 Jul 2025 15:57:52 -0400 Subject: [PATCH 02/24] fix --- datafusion/physical-plan/src/joins/piecewise_merge_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 1b37c414e978..072446a417ac 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -323,7 +323,7 @@ impl PiecewiseMergeJoinExec { buffered.equivalence_properties().clone(), &join_type, schema, - &[false], + &[false, false], Some(Self::probe_side(&join_type)), &[join_on.clone()], )?; From 60047ba397ca45411ddeb71141471a46167215ce Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 2 Jul 2025 16:00:19 -0400 Subject: [PATCH 03/24] add children() --- datafusion/physical-plan/src/joins/piecewise_merge_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 072446a417ac..c461fa3a9972 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -358,7 +358,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { } fn children(&self) -> Vec<&Arc> { - todo!() + vec![&self.streamed, &self.buffered] } fn required_input_ordering(&self) -> Vec> { From 6199b6106c800c7c4d0f9897413f20af7489c6dc Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 2 Jul 2025 16:18:19 -0400 Subject: [PATCH 04/24] fmt --- .../src/joins/piecewise_merge_join.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index c461fa3a9972..8ef3cdda2f86 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -228,8 +228,10 @@ impl PiecewiseMergeJoinExec { } }; - let left_sort_exprs = vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; - let right_sort_exprs = vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; + let left_sort_exprs = + vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; + let right_sort_exprs = + vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { return plan_err!( "PiecewiseMergeJoinExec requires valid sort expressions for its left side" @@ -309,7 +311,7 @@ impl PiecewiseMergeJoinExec { } // TODO: fix compute properties to work specifically for PiecewiseMergeJoin - // This is currently just a filler implementation so that it actually returns + // This is currently just a filler implementation so that it actually returns // a PlanProperties pub fn compute_properties( streamed: &Arc, @@ -910,12 +912,8 @@ impl PiecewiseMergeJoinStream { return Ok(StatefulStreamResult::Ready(None)); } - let buffered_data = Arc::clone(&self - .buffered_side - .try_as_ready() - .unwrap() - .buffered_data - ); + let buffered_data = + Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); if matches!( self.join_type, @@ -987,7 +985,8 @@ impl PiecewiseMergeJoinStream { ); let buffered_batch = buffered_data.batch(); - let empty_stream_batch = RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); + let empty_stream_batch = + RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); let batch = build_matched_indices( self.join_type, From e7cf488541137311ee38dd6d5d20e5be828a6593 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 2 Jul 2025 16:36:07 -0400 Subject: [PATCH 05/24] tma --- datafusion/physical-plan/src/joins/piecewise_merge_join.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 8ef3cdda2f86..336520a39908 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -744,6 +744,7 @@ impl RecordBatchStream for PiecewiseMergeJoinStream { // the buffered side. impl PiecewiseMergeJoinStream { // Creates a new `PiecewiseMergeJoinStream` instance + #[allow(clippy::too_many_arguments)] pub fn try_new( schema: Arc, on_streamed: PhysicalExprRef, From e77bd5ad4825b97a2d4fa0b9d79b70d17e544582 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 2 Jul 2025 17:11:30 -0400 Subject: [PATCH 06/24] clippy --- datafusion/physical-plan/src/joins/piecewise_merge_join.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 336520a39908..7ce9443b8803 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -1386,9 +1386,7 @@ mod tests { operator: Operator, join_type: JoinType, ) -> Result { - Ok(PiecewiseMergeJoinExec::try_new( - left, right, on, operator, join_type, - )?) + PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type) } async fn join_collect( From 7212242e65239809c24b0be7f9c5b808cfbeaf9f Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 2 Jul 2025 21:29:39 -0400 Subject: [PATCH 07/24] fix: Required input ordering --- .../src/joins/piecewise_merge_join.rs | 91 +++++++++++++------ 1 file changed, 63 insertions(+), 28 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 7ce9443b8803..4a35a88388d8 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -219,7 +219,15 @@ impl PiecewiseMergeJoinExec { // We take the operator and enforce a sort order on the streamed + buffered side based on // the operator type. let sort_options = match operator { - Operator::Lt | Operator::LtEq => SortOptions::new(false, false), + Operator::Lt | Operator::LtEq => { + // For the Left existence joins the inputs will be swapped so we need to switch the sort + // options. + if is_left_existence_join(join_type) { + SortOptions::new(true, false) + } else { + SortOptions::new(false, false) + } + } Operator::Gt | Operator::GtEq => SortOptions::new(true, false), _ => { return plan_err!( @@ -364,10 +372,25 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { } fn required_input_ordering(&self) -> Vec> { - vec![ - Some(OrderingRequirements::from(self.left_sort_exprs.clone())), - Some(OrderingRequirements::from(self.right_sort_exprs.clone())), - ] + // Existence joins don't need to be sorted on one side. + if is_left_existence_join(self.join_type) { + // Left side needs to be sorted because this will be swapped to the + // buffered side + vec![ + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + None, + ] + } else if is_right_existence_join(self.join_type) { + vec![ + None, + Some(OrderingRequirements::from(self.right_sort_exprs.clone())), + ] + } else { + vec![ + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + Some(OrderingRequirements::from(self.right_sort_exprs.clone())), + ] + } } fn with_new_children( @@ -399,28 +422,24 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { // If the join type is either LeftSemi, LeftAnti, or LeftMark we will swap the inputs // and sort ordering because we want the mark side to be the buffered side. - let (streamed, buffered, on_streamed, on_buffered, operator, sort_options) = if matches!( - self.join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark - ) { - ( - Arc::clone(&self.buffered), - Arc::clone(&self.streamed), - on_buffered, - on_streamed, - self.operator.swap().unwrap(), - SortOptions::new(!self.sort_options.descending, false), - ) - } else { - ( - Arc::clone(&self.streamed), - Arc::clone(&self.buffered), - on_streamed, - on_buffered, - self.operator, - self.sort_options, - ) - }; + let (streamed, buffered, on_streamed, on_buffered, operator) = + if is_left_existence_join(self.join_type) { + ( + Arc::clone(&self.buffered), + Arc::clone(&self.streamed), + on_buffered, + on_streamed, + self.operator.swap().unwrap(), + ) + } else { + ( + Arc::clone(&self.streamed), + Arc::clone(&self.buffered), + on_streamed, + on_buffered, + self.operator, + ) + }; let metrics = BuildProbeJoinMetrics::new(0, &self.metrics); let buffered_fut = self.buffered_fut.try_once(|| { @@ -452,7 +471,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { PiecewiseMergeJoinStreamState::WaitBufferedSide }, existence_join, - sort_options, + self.sort_options, ))) } } @@ -499,6 +518,22 @@ fn is_existence_join(join_type: JoinType) -> bool { ) } +// Returns boolean for whether the join is a left existence join +fn is_left_existence_join(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::LeftAnti | JoinType::LeftSemi | JoinType::LeftMark + ) +} + +// Returns boolean for whether the join is a right existence join +fn is_right_existence_join(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark + ) +} + // Returns boolean to check if the join type needs to record // buffered side matches for classic joins fn need_produce_result_in_final(join_type: JoinType) -> bool { From 30e73e45c1359e90c9ead77de04d1e3b291473e0 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 2 Jul 2025 22:22:33 -0400 Subject: [PATCH 08/24] fix --- .../physical-plan/src/joins/piecewise_merge_join.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 4a35a88388d8..762ed4b30748 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -228,7 +228,13 @@ impl PiecewiseMergeJoinExec { SortOptions::new(false, false) } } - Operator::Gt | Operator::GtEq => SortOptions::new(true, false), + Operator::Gt | Operator::GtEq => { + if is_left_existence_join(join_type) { + SortOptions::new(false, false) + } else { + SortOptions::new(true, false) + } + } _ => { return plan_err!( "Cannot contain non-range operator in PiecewiseMergeJoinExec" From feaee9a7216b3015ee24c6d284c6a821ab8afa9d Mon Sep 17 00:00:00 2001 From: Jonathan Date: Sat, 5 Jul 2025 20:12:53 -0400 Subject: [PATCH 09/24] fix sorting --- .../physical-plan/src/joins/piecewise_merge_join.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 762ed4b30748..e328654e3b58 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -386,14 +386,10 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { Some(OrderingRequirements::from(self.left_sort_exprs.clone())), None, ] - } else if is_right_existence_join(self.join_type) { - vec![ - None, - Some(OrderingRequirements::from(self.right_sort_exprs.clone())), - ] } else { + // We sort the left side in memory, so we do not need to enforce any sorting vec![ - Some(OrderingRequirements::from(self.left_sort_exprs.clone())), + None, Some(OrderingRequirements::from(self.right_sort_exprs.clone())), ] } From 75a60b297dc2556e7abbed13e7ebe5bee7ee245e Mon Sep 17 00:00:00 2001 From: Jonathan Date: Sat, 5 Jul 2025 21:00:32 -0400 Subject: [PATCH 10/24] update --- .../physical-plan/src/joins/piecewise_merge_join.rs | 8 -------- 1 file changed, 8 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index e328654e3b58..9ecfe4edd7af 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -528,14 +528,6 @@ fn is_left_existence_join(join_type: JoinType) -> bool { ) } -// Returns boolean for whether the join is a right existence join -fn is_right_existence_join(join_type: JoinType) -> bool { - matches!( - join_type, - JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark - ) -} - // Returns boolean to check if the join type needs to record // buffered side matches for classic joins fn need_produce_result_in_final(join_type: JoinType) -> bool { From 08456003087ab4aa62b03d3c182c2de0411f3508 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Mon, 7 Jul 2025 12:42:46 -0400 Subject: [PATCH 11/24] Update datafusion/physical-plan/src/joins/mod.rs Co-authored-by: Oleks V --- datafusion/physical-plan/src/joins/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 7eb64ed5b313..92b851e9ed15 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -24,7 +24,6 @@ pub use hash_join::HashJoinExec; pub use nested_loop_join::NestedLoopJoinExec; use parking_lot::Mutex; pub use piecewise_merge_join::PiecewiseMergeJoinExec; -// Note: SortMergeJoin is not used in plans yet pub use sort_merge_join::SortMergeJoinExec; pub use symmetric_hash_join::SymmetricHashJoinExec; mod cross_join; From eb9040f24eb80ffba269b0b0588bda5191b61cf2 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 7 Jul 2025 15:58:40 -0400 Subject: [PATCH 12/24] feat: Add `compute_properties` + add comments --- .../src/joins/piecewise_merge_join.rs | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 9ecfe4edd7af..24bf784d12b6 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -307,26 +307,23 @@ impl PiecewiseMergeJoinExec { &self.sort_options } - /// Get probe side (buffered side) for the PiecewiseMergeJoin + /// Get probe side (streameded side) for the PiecewiseMergeJoin /// In current implementation, probe side is determined according to join type. pub fn probe_side(join_type: &JoinType) -> JoinSide { match join_type { JoinType::Right | JoinType::RightSemi | JoinType::RightAnti - | JoinType::RightMark => JoinSide::Right, + | JoinType::RightMark => JoinSide::Left, JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::LeftAnti | JoinType::LeftSemi - | JoinType::LeftMark => JoinSide::Left, + | JoinType::LeftMark => JoinSide::Right, } } - // TODO: fix compute properties to work specifically for PiecewiseMergeJoin - // This is currently just a filler implementation so that it actually returns - // a PlanProperties pub fn compute_properties( streamed: &Arc, buffered: &Arc, @@ -339,7 +336,7 @@ impl PiecewiseMergeJoinExec { buffered.equivalence_properties().clone(), &join_type, schema, - &[false, false], + &Self::maintains_input_order(join_type), Some(Self::probe_side(&join_type)), &[join_on.clone()], )?; @@ -355,6 +352,25 @@ impl PiecewiseMergeJoinExec { )) } + fn maintains_input_order(join_type: JoinType) -> Vec { + match join_type { + // The existence side is expected to come in sorted + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => vec![true, false], + JoinType::RightSemi + | JoinType::RightAnti + | JoinType::RightMark => { + vec![false, true] + } + // Left, Right, Full, Inner Join is not guaranteed to maintain + // input order as the streamed side will be sorted during + // execution for `PiecewiseMergeJoin` + _ => vec![false, false], + } + } + + // TODO: We implement this with the physical planner. pub fn swap_inputs(&self) -> Result> { todo!() } @@ -945,6 +961,8 @@ impl PiecewiseMergeJoinStream { let buffered_data = Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); + // For Semi/Anti/Mark joins we mark indices on the buffered side, and retrieve final indices from + // `get_final_indices_bitmap` if matches!( self.join_type, JoinType::LeftSemi @@ -1001,13 +1019,14 @@ impl PiecewiseMergeJoinStream { } let buffered_indices_array = buffered_indices.finish(); - // The visited bitmap hasn't been marked yet for existence joins + // Mark bitmap here because the visited bitmap hasn't been marked yet for existence joins let mut bitmap = buffered_data.visited_indices_bitmap.lock(); buffered_indices_array.iter().flatten().for_each(|x| { bitmap.set_bit(x as usize, true); }); } + // Pass in piecewise flag to allow Right Semi/Anti/Mark joins to also be processed let (buffered_indices, streamed_indices) = get_final_indices_from_shared_bitmap( &buffered_data.visited_indices_bitmap, self.join_type, From f33b8d84d9eff3e07717f004991dd322ca3d48f2 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 7 Jul 2025 15:59:17 -0400 Subject: [PATCH 13/24] fmt! --- .../src/joins/piecewise_merge_join.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 24bf784d12b6..90444356637e 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -355,16 +355,14 @@ impl PiecewiseMergeJoinExec { fn maintains_input_order(join_type: JoinType) -> Vec { match join_type { // The existence side is expected to come in sorted - JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::LeftMark => vec![true, false], - JoinType::RightSemi - | JoinType::RightAnti - | JoinType::RightMark => { + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + vec![true, false] + } + JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { vec![false, true] } - // Left, Right, Full, Inner Join is not guaranteed to maintain - // input order as the streamed side will be sorted during + // Left, Right, Full, Inner Join is not guaranteed to maintain + // input order as the streamed side will be sorted during // execution for `PiecewiseMergeJoin` _ => vec![false, false], } @@ -961,7 +959,7 @@ impl PiecewiseMergeJoinStream { let buffered_data = Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); - // For Semi/Anti/Mark joins we mark indices on the buffered side, and retrieve final indices from + // For Semi/Anti/Mark joins we mark indices on the buffered side, and retrieve final indices from // `get_final_indices_bitmap` if matches!( self.join_type, From 250a8a506441f5864d2e9ee184a9e61fbf6cdd92 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 7 Jul 2025 17:31:33 -0400 Subject: [PATCH 14/24] feat: Add metrics + memory reservation --- .../src/joins/piecewise_merge_join.rs | 42 +++++++++++++++---- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 90444356637e..42c5f77604e4 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -16,7 +16,7 @@ // under the License. use arrow::array::{ - new_null_array, Array, BooleanArray, Float32Array, Float64Array, Int16Array, + new_null_array, Array, AsArray, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, RecordBatchOptions, UInt16Array, UInt8Array, }; use arrow::compute::take; @@ -465,7 +465,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { Ok(build_buffered_data( buffered_stream, Arc::clone(&on_buffered), - metrics, + metrics.clone(), reservation, build_visited_indices_map(self.join_type), )) @@ -488,6 +488,7 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { }, existence_join, self.sort_options, + metrics, ))) } } @@ -598,6 +599,13 @@ async fn build_buffered_data( .evaluate(&single_batch)? .into_array(single_batch.num_rows())?; + // We add the single batch size + the memory of the join keys + // size of the size estimation + let size_estimation = get_record_batch_memory_size(&single_batch) + + buffered_values.get_array_memory_size(); + reservation.try_grow(size_estimation)?; + metrics.build_mem_used.add(size_estimation); + // Created visited indices bitmap only if the join type requires it let visited_indices_bitmap = if build_map { let bitmap_size = bit_util::ceil(single_batch.num_rows(), 8); @@ -758,6 +766,8 @@ struct PiecewiseMergeJoinStream { // Sort option for buffered and streamed side (specifies whether // the sort is ascending or descending) sort_option: SortOptions, + // Metrics for build + probe joins + join_metrics: BuildProbeJoinMetrics, } impl RecordBatchStream for PiecewiseMergeJoinStream { @@ -798,6 +808,7 @@ impl PiecewiseMergeJoinStream { state: PiecewiseMergeJoinStreamState, existence_join: bool, sort_option: SortOptions, + join_metrics: BuildProbeJoinMetrics, ) -> Self { let streamed_schema = streamed.schema(); Self { @@ -812,6 +823,7 @@ impl PiecewiseMergeJoinStream { state, existence_join, sort_option, + join_metrics, } } @@ -843,11 +855,13 @@ impl PiecewiseMergeJoinStream { &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>>> { + let build_timer = self.join_metrics.build_time.timer(); let buffered_data = ready!(self .buffered_side .try_as_initial_mut()? .buffered_fut .get_shared(cx))?; + build_timer.done(); self.state = if self.existence_join { // For existence joins we will start to compare the buffered @@ -884,9 +898,14 @@ impl PiecewiseMergeJoinStream { .evaluate(&batch)? .into_array(batch.num_rows())?; + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + // For existence joins we do not need to sort the output, and we only need to // find the min or max value (depending on the operator) of all the stream batches if self.existence_join { + // Run timer during this phase as finding the min/max on streamed side is considered join time. + let timer = self.join_metrics.join_time.timer(); let mut global_min_max = self.streamed_global_min_max.lock(); let streamed_batch = StreamedBatch::new(batch, vec![stream_values]); @@ -898,6 +917,7 @@ impl PiecewiseMergeJoinStream { self.operator, ) .unwrap(); + timer.done(); self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; return Poll::Ready(Ok(StatefulStreamResult::Continue)); @@ -924,7 +944,7 @@ impl PiecewiseMergeJoinStream { Poll::Ready(Ok(StatefulStreamResult::Continue)) } - // Only classic join will call this, we process stream batches and evaluate against + // Only classic join will call. This function will process stream batches and evaluate against // the buffered side data. fn process_stream_batch( &mut self, @@ -932,7 +952,7 @@ impl PiecewiseMergeJoinStream { let stream_batch = self.state.try_as_process_stream_batch_mut()?; let buffered_side = self.buffered_side.try_as_ready_mut()?; - let result = resolve_classic_join( + let batch = resolve_classic_join( stream_batch, buffered_side, Arc::clone(&self.schema), @@ -942,7 +962,7 @@ impl PiecewiseMergeJoinStream { )?; self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; - Ok(StatefulStreamResult::Ready(Some(result))) + Ok(StatefulStreamResult::Ready(Some(batch))) } // Process remaining unmatched rows @@ -952,14 +972,15 @@ impl PiecewiseMergeJoinStream { // Return early for `JoinType::Left` and `JoinType::Inner` if matches!(self.join_type, JoinType::Left | JoinType::Inner) { self.state = PiecewiseMergeJoinStreamState::Completed; - return Ok(StatefulStreamResult::Ready(None)); } + let timer = self.join_metrics.join_time.timer(); + let buffered_data = Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); - // For Semi/Anti/Mark joins we mark indices on the buffered side, and retrieve final indices from + // For Semi/Anti/Mark joins that mark indices on the buffered side, and retrieve final indices from // `get_final_indices_bitmap` if matches!( self.join_type, @@ -979,7 +1000,7 @@ impl PiecewiseMergeJoinStream { let buffered_values = buffered_data.values(); let mut threshold_idx: Option = None; - // We iterate the buffered size values while comparing the threshold value (min/max) + // Iterate the buffered size values while comparing the threshold value (min/max) // and record our first match for buffered_idx in 0..buffered_data.values.len() { let buffered_value = @@ -1008,7 +1029,7 @@ impl PiecewiseMergeJoinStream { let mut buffered_indices = UInt64Builder::default(); - // If we found a match then we will append all indices from the threshold index + // If a match is found then append all indices from the threshold index // to the end of the buffered size rows if let Some(threshold_idx) = threshold_idx { let buffered_range: Vec = @@ -1044,6 +1065,9 @@ impl PiecewiseMergeJoinStream { buffered_indices, )?; + timer.done(); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); self.state = PiecewiseMergeJoinStreamState::Completed; Ok(StatefulStreamResult::Ready(Some(batch))) From a2ec52bb345b89ecb6fcbb2f1471275d606a9fd7 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 7 Jul 2025 18:05:39 -0400 Subject: [PATCH 15/24] min/max refactor --- Cargo.lock | 1 + datafusion/physical-plan/Cargo.toml | 1 + .../src/joins/piecewise_merge_join.rs | 92 ++----------------- 3 files changed, 9 insertions(+), 85 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bf0d19db8413..969a2ca50d47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2511,6 +2511,7 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", "datafusion-functions-window", "datafusion-functions-window-common", "datafusion-physical-expr", diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 095ee78cd0d6..9889b45cc5a5 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -53,6 +53,7 @@ datafusion-common = { workspace = true, default-features = true } datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-functions-window-common = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } datafusion-physical-expr-common = { workspace = true } diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 42c5f77604e4..737a884ccb06 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -15,10 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ - new_null_array, Array, AsArray, BooleanArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, RecordBatchOptions, UInt16Array, UInt8Array, -}; +use arrow::array::{new_null_array, Array, RecordBatchOptions}; use arrow::compute::take; use arrow::{ array::{ @@ -28,16 +25,17 @@ use arrow::{ compute::{concat_batches, sort_to_indices, take_record_batch}, util::bit_util, }; -use arrow_schema::{ArrowError, DataType, Schema, SchemaRef, SortOptions}; +use arrow_schema::{ArrowError, Schema, SchemaRef, SortOptions}; +use datafusion_common::NullEquality; use datafusion_common::{ exec_err, internal_err, plan_err, utils::compare_rows, JoinSide, Result, ScalarValue, }; -use datafusion_common::{not_impl_err, NullEquality}; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, RecordBatchStream, SendableRecordBatchStream, }; use datafusion_expr::{JoinType, Operator}; +use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch}; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{ LexOrdering, OrderingRequirements, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, @@ -1097,7 +1095,7 @@ fn resolve_existence_join( // Based on the operator we will find the minimum or maximum value. match operator { Operator::Gt | Operator::GtEq => { - let max_value = min_max(&stream_batch.values[0], true)?; + let max_value = max_batch(&stream_batch.values[0])?; let new_max = if let Some(prev) = (*min_max_value).clone() { if max_value.partial_cmp(&prev).unwrap() == Ordering::Greater { max_value @@ -1111,7 +1109,7 @@ fn resolve_existence_join( } Operator::Lt | Operator::LtEq => { - let min_value = min_max(&stream_batch.values[0], false)?; + let min_value = min_batch(&stream_batch.values[0])?; let new_min = if let Some(prev) = (*min_max_value).clone() { if min_value.partial_cmp(&prev).unwrap() == Ordering::Less { min_value @@ -1292,82 +1290,6 @@ fn build_matched_indices( )?) } -pub fn min_max(array: &ArrayRef, find_max: bool) -> Result { - macro_rules! find_min_max { - ($ARR:ty, $SCALAR:ident) => {{ - let arr = array.as_any().downcast_ref::<$ARR>().unwrap(); - let mut extreme: Option<_> = None; - for i in 0..arr.len() { - if arr.is_valid(i) { - let v = arr.value(i); - extreme = Some(match extreme { - Some(cur) => { - if find_max { - if v > cur { - v - } else { - cur - } - } else { - if v < cur { - v - } else { - cur - } - } - } - None => v, - }); - } - } - ScalarValue::$SCALAR(extreme) - }}; - } - - let result = match array.data_type() { - DataType::Int8 => find_min_max!(Int8Array, Int8), - DataType::Int16 => find_min_max!(Int16Array, Int16), - DataType::Int32 => find_min_max!(Int32Array, Int32), - DataType::Int64 => find_min_max!(Int64Array, Int64), - DataType::UInt8 => find_min_max!(UInt8Array, UInt8), - DataType::UInt16 => find_min_max!(UInt16Array, UInt16), - DataType::UInt32 => find_min_max!(UInt32Array, UInt32), - DataType::UInt64 => find_min_max!(UInt64Array, UInt64), - DataType::Float32 => find_min_max!(Float32Array, Float32), - DataType::Float64 => find_min_max!(Float64Array, Float64), - - DataType::Boolean => { - let arr = array.as_any().downcast_ref::().unwrap(); - let mut extreme: Option = None; - for i in 0..arr.len() { - if arr.is_valid(i) { - let v = arr.value(i); - extreme = Some(match extreme { - Some(cur) => { - if find_max { - cur || v // max: true if either is true - } else { - cur && v // min: false if either is false - } - } - None => v, - }); - } - } - ScalarValue::Boolean(extreme) - } - - dt => { - return not_impl_err!( - "Unsupported data type in PiecewiseMergeJoin min/max function: {}", - dt - ); - } - }; - - Ok(result) -} - #[cfg(test)] mod tests { use super::*; @@ -1377,7 +1299,7 @@ mod tests { ExecutionPlan, }; use arrow::array::{Date32Array, Date64Array}; - use arrow_schema::Field; + use arrow_schema::{DataType, Field}; use datafusion_common::test_util::batches_to_string; use datafusion_execution::TaskContext; use datafusion_expr::JoinType; From 23298f1f9b2dfac743c8e09d8a76a494797f7d7c Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 7 Jul 2025 19:08:09 -0400 Subject: [PATCH 16/24] rm output_rows --- datafusion/physical-plan/src/joins/piecewise_merge_join.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 737a884ccb06..27ce9b6f58e9 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -1065,7 +1065,6 @@ impl PiecewiseMergeJoinStream { timer.done(); self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); self.state = PiecewiseMergeJoinStreamState::Completed; Ok(StatefulStreamResult::Ready(Some(batch))) From 3e3a8b6ec297539349810d63c19a0fa078ed3352 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 4 Aug 2025 23:00:24 -0400 Subject: [PATCH 17/24] new join --- .../physical-plan/src/joins/nested_loop_join.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 5da8291cbc2f..b84197a031c0 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -1096,7 +1096,7 @@ impl NestedLoopJoinStream { }; // Only setting up timer, input is exhausted - let _timer = self.join_metrics.join_time.timer(); + let timer = self.join_metrics.join_time.timer(); // use the global left bitmap to produce the left indices and right indices let (left_side, right_side) = get_final_indices_from_shared_bitmap( @@ -1121,9 +1121,12 @@ impl NestedLoopJoinStream { if result.is_ok() { timer.done(); } - - let (left_side, right_side) = - get_final_indices_from_shared_bitmap(visited_left_side, self.join_type); + + let (left_side, right_side) = get_final_indices_from_shared_bitmap( + visited_left_side, + self.join_type, + true, + ); self.join_result_status = Some(JoinResultProgress { build_indices: left_side, From 9016ce7225646c456f25da94928d0182f8215cb6 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 11 Aug 2025 17:20:47 -0400 Subject: [PATCH 18/24] add sql examples --- .../src/joins/piecewise_merge_join.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 27ce9b6f58e9..dd264b5c992e 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -106,6 +106,12 @@ use crate::{ /// all be larger values. /// /// ```text +/// SQL statement: +/// SELECT * +/// FROM (VALUES (100), (200), (500)) AS streamed(a) +/// LEFT JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b) +/// ON streamed.a < buffered.b; +/// /// Processing Row 1: /// /// Sorted Streamed Side Sorted Buffered Side @@ -146,13 +152,19 @@ use crate::{ /// For Left Semi, Anti, and Mark joins we swap the inputs so that the marked side is on the buffered side. /// /// Here is an example: -/// We perform a `JoinType::Left` with these two batches and the operator being `Operator::Lt`(<). Because +/// We perform a `JoinType::LeftSemi` with these two batches and the operator being `Operator::Lt`(<). Because /// the operator is `Operator::Lt` we can find the minimum value in the streamed side; in this case it is 200. /// We can then advance a pointer from the start of the buffer side until we find the first value that satisfies /// the predicate. All rows after that first matched value satisfy the condition 200 < x so we can mark all of /// those rows as matched. /// /// ```text +/// SQL statement: +/// SELECT * +/// FROM (VALUES (100), (200), (500)) AS streamed(a) +/// LEFT SEMI JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b) +/// ON streamed.a < buffered.b; +/// /// Unsorted Streamed Side Sorted Buffered Side /// ┌──────────────────┐ ┌──────────────────┐ /// 1 │ 500 │ 1 │ 100 │ From e865e27c3f523f8a5dac8602c808c66bdeedc207 Mon Sep 17 00:00:00 2001 From: Jonathan Chen Date: Wed, 13 Aug 2025 16:17:18 -0400 Subject: [PATCH 19/24] Update datafusion/physical-plan/src/joins/piecewise_merge_join.rs Co-authored-by: Oleks V --- datafusion/physical-plan/src/joins/piecewise_merge_join.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index dd264b5c992e..4ee95ac7faa4 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -221,7 +221,7 @@ impl PiecewiseMergeJoinExec { ) -> Result { // TODO: Implement mark joins for PiecewiseMergeJoin if matches!(join_type, JoinType::LeftMark | JoinType::RightMark) { - return plan_err!( + return not_impl_err!( "Mark Joins are currently not supported for PiecewiseMergeJoin" ); } From 22e423dc874a1f1bba9aaa91f6cf44e470367510 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Wed, 13 Aug 2025 19:07:25 -0400 Subject: [PATCH 20/24] fix --- datafusion/physical-plan/src/joins/piecewise_merge_join.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 4ee95ac7faa4..a6fb29bc5445 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -26,10 +26,10 @@ use arrow::{ util::bit_util, }; use arrow_schema::{ArrowError, Schema, SchemaRef, SortOptions}; -use datafusion_common::NullEquality; use datafusion_common::{ exec_err, internal_err, plan_err, utils::compare_rows, JoinSide, Result, ScalarValue, }; +use datafusion_common::{not_impl_err, NullEquality}; use datafusion_execution::{ memory_pool::{MemoryConsumer, MemoryReservation}, RecordBatchStream, SendableRecordBatchStream, @@ -111,7 +111,7 @@ use crate::{ /// FROM (VALUES (100), (200), (500)) AS streamed(a) /// LEFT JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b) /// ON streamed.a < buffered.b; -/// +/// /// Processing Row 1: /// /// Sorted Streamed Side Sorted Buffered Side @@ -164,7 +164,7 @@ use crate::{ /// FROM (VALUES (100), (200), (500)) AS streamed(a) /// LEFT SEMI JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b) /// ON streamed.a < buffered.b; -/// +/// /// Unsorted Streamed Side Sorted Buffered Side /// ┌──────────────────┐ ┌──────────────────┐ /// 1 │ 500 │ 1 │ 100 │ @@ -198,6 +198,7 @@ pub struct PiecewiseMergeJoinExec { pub join_type: JoinType, /// The schema once the join is applied schema: SchemaRef, + /// Buffered data buffered_fut: OnceAsync, /// Execution metrics metrics: ExecutionPlanMetricsSet, From bd245f7af479d49a6010ff5f693737a146216940 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 19 Aug 2025 21:55:20 -0400 Subject: [PATCH 21/24] update --- .../src/joins/piecewise_merge_join.rs | 1458 +++++++++++++---- 1 file changed, 1148 insertions(+), 310 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index a6fb29bc5445..8e10911f900a 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{new_null_array, Array, RecordBatchOptions}; +use arrow::array::{ + new_null_array, Array, PrimitiveArray, PrimitiveBuilder, RecordBatchOptions, +}; use arrow::compute::take; +use arrow::datatypes::{UInt32Type, UInt64Type}; use arrow::{ array::{ ArrayRef, BooleanBufferBuilder, RecordBatch, UInt32Array, UInt32Builder, @@ -67,6 +70,9 @@ use crate::{ ExecutionPlan, PlanProperties, }; +/// Batch emits this number of rows when processing +pub const DEFAULT_INCREMENTAL_BATCH_VALUE: usize = 1; + /// `PiecewiseMergeJoinExec` is a join execution plan that only evaluates single range filter. /// /// The physical planner will choose to evalute this join when there is only one range predicate. This @@ -80,7 +86,7 @@ use crate::{ /// predicate. /// /// # Execution Plan Inputs -/// For `PiecewiseMergeJoin` we label all left inputs as the `streamed' side and the right outputs as the +/// For `PiecewiseMergeJoin` we label all right inputs as the `streamed' side and the left outputs as the /// 'buffered' side. /// /// `PiecewiseMergeJoin` takes a sorted input for the side to be buffered and is able to sort streamed record @@ -114,34 +120,35 @@ use crate::{ /// /// Processing Row 1: /// -/// Sorted Streamed Side Sorted Buffered Side -/// ┌──────────────────┐ ┌──────────────────┐ -/// 1 │ 100 │ 1 │ 100 │ -/// ├──────────────────┤ ├──────────────────┤ -/// 2 │ 200 │ 2 │ 200 │ ─┐ -/// ├──────────────────┤ ├──────────────────┤ │ For row 1 on streamed side with -/// 3 │ 500 │ 3 │ 200 │ │ value 100, we emit rows 2 - 5 -/// └──────────────────┘ ├──────────────────┤ │ as matches when the operator is -/// 4 │ 300 │ │ `Operator::Lt` (<) Emitting all -/// ├──────────────────┤ │ rows after the first match (row 2 -/// 5 │ 400 │ ─┘ buffered side; 100 < 200) -/// └──────────────────┘ +/// Sorted Buffered Side Sorted Streamed Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ ─┐ 2 │ 200 │ +/// ├──────────────────┤ │ For row 1 on streamed side with ├──────────────────┤ +/// 3 │ 200 │ │ value 100, we emit rows 2 - 5. 3 │ 500 │ +/// ├──────────────────┤ │ as matches when the operator is └──────────────────┘ +/// 4 │ 300 │ │ `Operator::Lt` (<) Emitting all +/// ├──────────────────┤ │ rows after the first match (row +/// 5 │ 400 │ ─┘ 2 buffered side; 100 < 200) +/// └──────────────────┘ /// /// Processing Row 2: /// By sorting the streamed side we know /// -/// Sorted Streamed Side Sorted Buffered Side -/// ┌──────────────────┐ ┌──────────────────┐ -/// 1 │ 100 │ 1 │ 100 │ -/// ├──────────────────┤ ├──────────────────┤ -/// 2 │ 200 │ 2 │ 200 │ <- Start here when probing for the streamed -/// ├──────────────────┤ ├──────────────────┤ side row 2. -/// 3 │ 500 │ 3 │ 200 │ -/// └──────────────────┘ ├──────────────────┤ -/// 4 │ 300 │ -/// ├──────────────────┤ -/// 5 │ 400 | -/// └──────────────────┘ +/// Sorted Buffered Side Sorted Streamed Side +/// ┌──────────────────┐ ┌──────────────────┐ +/// 1 │ 100 │ 1 │ 100 │ +/// ├──────────────────┤ ├──────────────────┤ +/// 2 │ 200 │ <- Start here when probing for the 2 │ 200 │ +/// ├──────────────────┤ streamed side row 2. ├──────────────────┤ +/// 3 │ 200 │ 3 │ 500 │ +/// ├──────────────────┤ └──────────────────┘ +/// 4 │ 300 │ +/// ├──────────────────┤ +/// 5 │ 400 │ +/// └──────────────────┘ +/// /// ``` /// /// ## Existence Joins (Semi, Anti, Mark) @@ -165,31 +172,60 @@ use crate::{ /// LEFT SEMI JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b) /// ON streamed.a < buffered.b; /// -/// Unsorted Streamed Side Sorted Buffered Side +/// Sorted Buffered Side Unsorted Streamed Side /// ┌──────────────────┐ ┌──────────────────┐ -/// 1 │ 500 │ 1 │ 100 │ +/// 1 │ 100 │ 1 │ 500 │ /// ├──────────────────┤ ├──────────────────┤ /// 2 │ 200 │ 2 │ 200 │ /// ├──────────────────┤ ├──────────────────┤ -/// 3 │ 300 │ 3 │ 200 │ -/// └──────────────────┘ ├──────────────────┤ -/// min value: 200 4 │ 300 │ ─┐ -/// ├──────────────────┤ | We emit matches for row 4 - 5 on the -/// 5 │ 400 | ─┘ buffered side. -/// └──────────────────┘ +/// 3 │ 200 │ 3 │ 300 │ +/// ├──────────────────┤ └──────────────────┘ +/// 4 │ 300 │ ─┐ +/// ├──────────────────┤ | We emit matches for row 4 - 5 +/// 5 │ 400 │ ─┘ on the buffered side. +/// └──────────────────┘ +/// min value: 200 /// ``` /// /// For both types of joins, the buffered side must be sorted ascending for `Operator::Lt`(<) or /// `Operator::LtEq`(<=) and descending for `Operator::Gt`(>) or `Operator::GtEq`(>=). /// +/// ## Assumptions / Notation +/// - [R], [S]: number of pages (blocks) of R and S +/// - |R|, |S|: number of tuples in R and S +/// - B: number of buffer pages +/// +/// # Performance (cost) +/// Piecewise Merge Join is used over Nested Loop Join due to its superior performance. Here is a breakdown +/// of the calculations: +/// +/// ## Piecewise Merge Join (PWMJ) +/// Intuition: Keep the buffered side (R) sorted and in memory (or scan it in sorted order), +/// sort the streamed side (S), then merge in order while advancing a pivot on R. +/// +/// Average I/O cost: +/// cost(PWMJ) = cost_to_sort(R) + cost_to_sort(S) + ([R] + [S]) +/// = sort(R) + sort(S) + [R] + [S] +/// +/// - If R (buffered) already sorted on the join key: cost(PWMJ) = sort(S) + [R] + [S] +/// - If S already sorted and R not: cost(PWMJ) = sort(R) + [R] + [S] +/// - If both already sorted: cost(PWMJ) = [R] + [S] +/// +/// ## Nested Loop Join +/// cost(NLJ) ≈ [R] + |R|·[S] +/// +/// Takeaway: +/// - When at least one side needs sorting, PWMJ ≈ sort(R) + sort(S) + [R] + [S] on average, +/// typically beating NLJ’s |R|·[S] (or its buffered variant) for nontrivial |R|, [S]. +/// /// # Further Reference Material /// DuckDB blog on Range Joins: [Range Joins in DuckDB](https://duckdb.org/2022/05/27/iejoin.html) #[derive(Debug)] pub struct PiecewiseMergeJoinExec { - /// Left sorted joining execution plan - pub streamed: Arc, - /// Right sorting joining execution plan + /// Left buffered execution plan pub buffered: Arc, + /// Right streamed execution plan + pub streamed: Arc, /// The two expressions being compared pub on: (Arc, Arc), /// Comparison operator in the range predicate @@ -214,8 +250,8 @@ pub struct PiecewiseMergeJoinExec { impl PiecewiseMergeJoinExec { pub fn try_new( - streamed: Arc, buffered: Arc, + streamed: Arc, on: (Arc, Arc), operator: Operator, join_type: JoinType, @@ -227,23 +263,23 @@ impl PiecewiseMergeJoinExec { ); } - // We take the operator and enforce a sort order on the streamed + buffered side based on + // Take the operator and enforce a sort order on the streamed + buffered side based on // the operator type. let sort_options = match operator { Operator::Lt | Operator::LtEq => { - // For the Left existence joins the inputs will be swapped so we need to switch the sort - // options. - if is_left_existence_join(join_type) { - SortOptions::new(true, false) - } else { + // For left existence joins the inputs will be swapped so the sort + // options are switched + if is_right_existence_join(join_type) { SortOptions::new(false, false) + } else { + SortOptions::new(true, false) } } Operator::Gt | Operator::GtEq => { - if is_left_existence_join(join_type) { - SortOptions::new(false, false) - } else { + if is_right_existence_join(join_type) { SortOptions::new(true, false) + } else { + SortOptions::new(false, false) } } _ => { @@ -253,10 +289,12 @@ impl PiecewiseMergeJoinExec { } }; + // Give the same `sort_option for comparison later` let left_sort_exprs = vec![PhysicalSortExpr::new(Arc::clone(&on.0), sort_options)]; let right_sort_exprs = vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; + let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { return plan_err!( "PiecewiseMergeJoinExec requires valid sort expressions for its left side" @@ -268,15 +306,15 @@ impl PiecewiseMergeJoinExec { ); }; - let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); + let streamed_schema = streamed.schema(); // Create output schema for the join let schema = - Arc::new(build_join_schema(&streamed_schema, &buffered_schema, &join_type).0); + Arc::new(build_join_schema(&buffered_schema, &streamed_schema, &join_type).0); let cache = Self::compute_properties( - &streamed, &buffered, + &streamed, Arc::clone(&schema), join_type, &on, @@ -298,16 +336,16 @@ impl PiecewiseMergeJoinExec { }) } - /// Refeference to streamed side execution plan - pub fn streamed(&self) -> &Arc { - &self.streamed - } - /// Refeerence to buffered side execution plan pub fn buffered(&self) -> &Arc { &self.buffered } + /// Refeference to streamed side execution plan + pub fn streamed(&self) -> &Arc { + &self.streamed + } + /// Join type pub fn join_type(&self) -> JoinType { self.join_type @@ -318,33 +356,33 @@ impl PiecewiseMergeJoinExec { &self.sort_options } - /// Get probe side (streameded side) for the PiecewiseMergeJoin + /// Get probe side (streamed side) for the PiecewiseMergeJoin /// In current implementation, probe side is determined according to join type. pub fn probe_side(join_type: &JoinType) -> JoinSide { match join_type { JoinType::Right + | JoinType::Inner + | JoinType::Full | JoinType::RightSemi | JoinType::RightAnti - | JoinType::RightMark => JoinSide::Left, - JoinType::Inner - | JoinType::Left - | JoinType::Full + | JoinType::RightMark => JoinSide::Right, + JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi - | JoinType::LeftMark => JoinSide::Right, + | JoinType::LeftMark => JoinSide::Left, } } pub fn compute_properties( - streamed: &Arc, buffered: &Arc, + streamed: &Arc, schema: SchemaRef, join_type: JoinType, join_on: &(PhysicalExprRef, PhysicalExprRef), ) -> Result { let eq_properties = join_equivalence_properties( - streamed.equivalence_properties().clone(), buffered.equivalence_properties().clone(), + streamed.equivalence_properties().clone(), &join_type, schema, &Self::maintains_input_order(join_type), @@ -353,13 +391,13 @@ impl PiecewiseMergeJoinExec { )?; let output_partitioning = - symmetric_join_output_partitioning(streamed, buffered, &join_type)?; + symmetric_join_output_partitioning(buffered, streamed, &join_type)?; Ok(PlanProperties::new( eq_properties, output_partitioning, EmissionType::Incremental, - boundedness_from_children([streamed, buffered]), + boundedness_from_children([buffered, streamed]), )) } @@ -367,10 +405,10 @@ impl PiecewiseMergeJoinExec { match join_type { // The existence side is expected to come in sorted JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { - vec![true, false] + vec![false, true] } JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { - vec![false, true] + vec![true, false] } // Left, Right, Full, Inner Join is not guaranteed to maintain // input order as the streamed side will be sorted during @@ -399,23 +437,23 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { } fn children(&self) -> Vec<&Arc> { - vec![&self.streamed, &self.buffered] + vec![&self.buffered, &self.streamed] } fn required_input_ordering(&self) -> Vec> { // Existence joins don't need to be sorted on one side. - if is_left_existence_join(self.join_type) { - // Left side needs to be sorted because this will be swapped to the + if is_right_existence_join(self.join_type) { + // Right side needs to be sorted because this will be swapped to the // buffered side vec![ - Some(OrderingRequirements::from(self.left_sort_exprs.clone())), None, + Some(OrderingRequirements::from(self.left_sort_exprs.clone())), ] } else { - // We sort the left side in memory, so we do not need to enforce any sorting + // Sort the right side in memory, so we do not need to enforce any sorting vec![ - None, Some(OrderingRequirements::from(self.right_sort_exprs.clone())), + None, ] } } @@ -444,26 +482,26 @@ impl ExecutionPlan for PiecewiseMergeJoinExec { partition: usize, context: Arc, ) -> Result { - let on_streamed = Arc::clone(&self.on.0); - let on_buffered = Arc::clone(&self.on.1); + let on_buffered = Arc::clone(&self.on.0); + let on_streamed = Arc::clone(&self.on.1); - // If the join type is either LeftSemi, LeftAnti, or LeftMark we will swap the inputs + // If the join type is either RightSemi, RightAnti, or RightMark we will swap the inputs // and sort ordering because we want the mark side to be the buffered side. - let (streamed, buffered, on_streamed, on_buffered, operator) = - if is_left_existence_join(self.join_type) { + let (buffered, streamed, on_buffered, on_streamed, operator) = + if is_right_existence_join(self.join_type) { ( - Arc::clone(&self.buffered), Arc::clone(&self.streamed), - on_buffered, + Arc::clone(&self.buffered), on_streamed, + on_buffered, self.operator.swap().unwrap(), ) } else { ( - Arc::clone(&self.streamed), Arc::clone(&self.buffered), - on_streamed, + Arc::clone(&self.streamed), on_buffered, + on_streamed, self.operator, ) }; @@ -546,18 +584,18 @@ fn is_existence_join(join_type: JoinType) -> bool { ) } -// Returns boolean for whether the join is a left existence join -fn is_left_existence_join(join_type: JoinType) -> bool { +// Returns boolean for whether the join is a right existence join +fn is_right_existence_join(join_type: JoinType) -> bool { matches!( join_type, - JoinType::LeftAnti | JoinType::LeftSemi | JoinType::LeftMark + JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark ) } // Returns boolean to check if the join type needs to record // buffered side matches for classic joins fn need_produce_result_in_final(join_type: JoinType) -> bool { - matches!(join_type, JoinType::Full | JoinType::Right) + matches!(join_type, JoinType::Full | JoinType::Left) } // Returns boolean for whether or not we need to build the buffered side @@ -566,7 +604,7 @@ fn build_visited_indices_map(join_type: JoinType) -> bool { matches!( join_type, JoinType::Full - | JoinType::Right + | JoinType::Left | JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftSemi @@ -779,6 +817,8 @@ struct PiecewiseMergeJoinStream { sort_option: SortOptions, // Metrics for build + probe joins join_metrics: BuildProbeJoinMetrics, + // Tracking incremental state for emitting record batches + batch_process_state: BatchProcessState, } impl RecordBatchStream for PiecewiseMergeJoinStream { @@ -835,6 +875,7 @@ impl PiecewiseMergeJoinStream { existence_join, sort_option, join_metrics, + batch_process_state: BatchProcessState::new(), } } @@ -960,18 +1001,23 @@ impl PiecewiseMergeJoinStream { fn process_stream_batch( &mut self, ) -> Result>> { - let stream_batch = self.state.try_as_process_stream_batch_mut()?; let buffered_side = self.buffered_side.try_as_ready_mut()?; + let stream_batch = self.state.try_as_process_stream_batch_mut()?; let batch = resolve_classic_join( - stream_batch, buffered_side, + stream_batch, Arc::clone(&self.schema), self.operator, self.sort_option, self.join_type, + &mut self.batch_process_state, )?; + if self.batch_process_state.continue_process { + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; Ok(StatefulStreamResult::Ready(Some(batch))) } @@ -980,8 +1026,8 @@ impl PiecewiseMergeJoinStream { fn process_unmatched_buffered_batch( &mut self, ) -> Result>> { - // Return early for `JoinType::Left` and `JoinType::Inner` - if matches!(self.join_type, JoinType::Left | JoinType::Inner) { + // Return early for `JoinType::Right` and `JoinType::Inner` + if matches!(self.join_type, JoinType::Right | JoinType::Inner) { self.state = PiecewiseMergeJoinStreamState::Completed; return Ok(StatefulStreamResult::Ready(None)); } @@ -991,6 +1037,78 @@ impl PiecewiseMergeJoinStream { let buffered_data = Arc::clone(&self.buffered_side.try_as_ready().unwrap().buffered_data); + // Check if the same batch needs to be checked for values again + if let Some(start_idx) = self.batch_process_state.process_rest { + if let Some(buffered_indices) = &self.batch_process_state.buffered_indices { + let remaining = buffered_indices.len() - start_idx; + + // Branch into this and return value if there are more rows to deal with + if remaining > DEFAULT_INCREMENTAL_BATCH_VALUE { + let buffered_batch = buffered_data.batch(); + let empty_stream_batch = + RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); + + let buffered_chunk_ref = buffered_indices + .slice(start_idx, DEFAULT_INCREMENTAL_BATCH_VALUE); + let new_buffered_indices = buffered_chunk_ref + .as_any() + .downcast_ref::() + .expect("downcast to UInt64Array after slice"); + + let streamed_indices: UInt32Array = + (0..new_buffered_indices.len() as u32).collect(); + + let batch = build_matched_indices( + self.join_type, + Arc::clone(&self.schema), + &empty_stream_batch, + buffered_batch, + streamed_indices, + new_buffered_indices.clone(), + )?; + + self.batch_process_state.set_process_rest(Some( + start_idx + DEFAULT_INCREMENTAL_BATCH_VALUE, + )); + self.batch_process_state.continue_process = true; + + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + let buffered_batch = buffered_data.batch(); + let empty_stream_batch = + RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); + + let buffered_chunk_ref = buffered_indices.slice(start_idx, remaining); + let new_buffered_indices = buffered_chunk_ref + .as_any() + .downcast_ref::() + .expect("downcast to UInt64Array after slice"); + + let streamed_indices: UInt32Array = + (0..new_buffered_indices.len() as u32).collect(); + + let batch = build_matched_indices( + self.join_type, + Arc::clone(&self.schema), + &empty_stream_batch, + buffered_batch, + streamed_indices, + new_buffered_indices.clone(), + )?; + + self.batch_process_state.reset(); + + timer.done(); + self.join_metrics.output_batches.add(1); + self.state = PiecewiseMergeJoinStreamState::Completed; + + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + + return exec_err!("Batch process state should hold buffered indices"); + } + // For Semi/Anti/Mark joins that mark indices on the buffered side, and retrieve final indices from // `get_final_indices_bitmap` if matches!( @@ -1005,6 +1123,7 @@ impl PiecewiseMergeJoinStream { let global_min_max = self.streamed_global_min_max.lock(); let threshold = match &*global_min_max { Some(v) => v.clone(), + // This shouldn't be possible None => return exec_err!("Stream batch was empty."), }; @@ -1047,6 +1166,7 @@ impl PiecewiseMergeJoinStream { (threshold_idx as u64..buffered_data.values.len() as u64).collect(); buffered_indices.append_slice(&buffered_range); } + let buffered_indices_array = buffered_indices.finish(); // Mark bitmap here because the visited bitmap hasn't been marked yet for existence joins @@ -1063,6 +1183,41 @@ impl PiecewiseMergeJoinStream { true, ); + // If the output indices is larger than the limit for the incremental batching then + // proceed to outputting all matches up to that index, return batch, and the matching + // will start next on the updated index (`process_rest`) + if buffered_indices.len() > DEFAULT_INCREMENTAL_BATCH_VALUE { + let buffered_batch = buffered_data.batch(); + let empty_stream_batch = + RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); + + let indices_chunk_ref = buffered_indices.slice( + self.batch_process_state.start_idx, + DEFAULT_INCREMENTAL_BATCH_VALUE, + ); + + let indices_chunk = indices_chunk_ref + .as_any() + .downcast_ref::() + .expect("downcast to UInt64Array after slice"); + + let batch = build_matched_indices( + self.join_type, + Arc::clone(&self.schema), + &empty_stream_batch, + buffered_batch, + streamed_indices, + indices_chunk.clone(), + )?; + + self.batch_process_state.buffered_indices = Some(buffered_indices); + self.batch_process_state + .set_process_rest(Some(DEFAULT_INCREMENTAL_BATCH_VALUE)); + self.batch_process_state.continue_process = true; + + return Ok(StatefulStreamResult::Ready(Some(batch))); + } + let buffered_batch = buffered_data.batch(); let empty_stream_batch = RecordBatch::new_empty(Arc::clone(&self.streamed_schema)); @@ -1084,6 +1239,68 @@ impl PiecewiseMergeJoinStream { } } +// Holds all information for processing incremental output +struct BatchProcessState { + // Used to pick up from the last index on the stream side + start_idx: usize, + // Used to pick up from the last index on the buffered side + pivot: usize, + // Tracks the number of rows processed; default starts at 0 + num_rows: usize, + // Processes the rest of the batch + process_rest: Option, + // Used to skip fully processing the row + not_found: bool, + // Signals whether to call `ProcessStreamBatch` again + continue_process: bool, + // Holding the buffered indices when processing the remaining marked rows. + buffered_indices: Option>, +} + +impl BatchProcessState { + pub fn new() -> Self { + Self { + start_idx: 0, + num_rows: 0, + pivot: 0, + process_rest: None, + not_found: false, + continue_process: false, + buffered_indices: None, + } + } + + pub fn reset(&mut self) { + self.start_idx = 0; + self.num_rows = 0; + self.pivot = 0; + self.process_rest = None; + self.not_found = false; + self.continue_process = false; + self.buffered_indices = None; + } + + pub fn pivot(&self) -> usize { + self.pivot + } + + pub fn set_pivot(&mut self, pivot: usize) { + self.pivot += pivot; + } + + pub fn set_start_idx(&mut self, start_idx: usize) { + self.start_idx = start_idx; + } + + pub fn set_rows(&mut self, num_rows: usize) { + self.num_rows = num_rows; + } + + pub fn set_process_rest(&mut self, process_rest: Option) { + self.process_rest = process_rest; + } +} + impl Stream for PiecewiseMergeJoinStream { type Item = Result; @@ -1106,7 +1323,7 @@ fn resolve_existence_join( // Based on the operator we will find the minimum or maximum value. match operator { - Operator::Gt | Operator::GtEq => { + Operator::Lt | Operator::LtEq => { let max_value = max_batch(&stream_batch.values[0])?; let new_max = if let Some(prev) = (*min_max_value).clone() { if max_value.partial_cmp(&prev).unwrap() == Ordering::Greater { @@ -1120,7 +1337,7 @@ fn resolve_existence_join( *min_max_value = Some(new_max); } - Operator::Lt | Operator::LtEq => { + Operator::Gt | Operator::GtEq => { let min_value = min_batch(&stream_batch.values[0])?; let new_min = if let Some(prev) = (*min_max_value).clone() { if min_value.partial_cmp(&prev).unwrap() == Ordering::Less { @@ -1146,90 +1363,282 @@ fn resolve_existence_join( // For Left, Right, Full, and Inner joins, incoming stream batches will already be sorted. fn resolve_classic_join( + buffered_side: &mut BufferedSideReadyState, stream_batch: &StreamedBatch, - buffered_side: &BufferedSideReadyState, join_schema: Arc, operator: Operator, sort_options: SortOptions, join_type: JoinType, + batch_process_state: &mut BatchProcessState, ) -> Result { - let stream_values = stream_batch.values(); let buffered_values = buffered_side.buffered_data.values(); let buffered_len = buffered_values.len(); + let stream_values = stream_batch.values(); - let mut stream_indices = UInt32Builder::default(); let mut buffered_indices = UInt64Builder::default(); + let mut stream_indices = UInt32Builder::default(); // Our pivot variable allows us to start probing on the buffered side where we last matched // in the previous stream row. - let mut pivot = 0; - for row_idx in 0..stream_values[0].len() { + let mut pivot = batch_process_state.pivot(); + for row_idx in batch_process_state.start_idx..stream_values[0].len() { let mut found = false; - while pivot < buffered_values.len() { - let compare = compare_join_arrays( - &[Arc::clone(&stream_values[0])], - row_idx, - &[Arc::clone(buffered_values)], - pivot, - &[sort_options], - NullEquality::NullEqualsNothing, - )?; - // If we find a match we append all indices and move to the next stream row index - match operator { - Operator::Gt | Operator::Lt => { - if matches!(compare, Ordering::Less) { - let count = buffered_values.len() - pivot; - let stream_repeated = vec![row_idx as u32; count]; - let buffered_range: Vec = - (pivot as u64..buffered_len as u64).collect(); + // Check once to see if it is a redo of a null value if not we do not try to process the batch + if !batch_process_state.not_found { + while pivot < buffered_values.len() + || batch_process_state.process_rest.is_some() + { + // If there is still data left in the batch to process, use the index and output + if let Some(start_idx) = batch_process_state.process_rest { + let count = buffered_values.len() - start_idx; + if count >= DEFAULT_INCREMENTAL_BATCH_VALUE { + let stream_repeated = + vec![row_idx as u32; DEFAULT_INCREMENTAL_BATCH_VALUE]; + batch_process_state.set_process_rest(Some( + start_idx + DEFAULT_INCREMENTAL_BATCH_VALUE, + )); + batch_process_state.set_rows( + batch_process_state.num_rows + + DEFAULT_INCREMENTAL_BATCH_VALUE, + ); + let buffered_range: Vec = (start_idx as u64 + ..((start_idx as u64) + + (DEFAULT_INCREMENTAL_BATCH_VALUE as u64))) + .collect(); stream_indices.append_slice(&stream_repeated); buffered_indices.append_slice(&buffered_range); - found = true; - break; + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + buffered_side, + join_type, + join_schema, + )?; + batch_process_state.continue_process = true; + batch_process_state.set_rows(0); + + return Ok(batch); } - pivot += 1; + + batch_process_state.set_rows(batch_process_state.num_rows + count); + let stream_repeated = vec![row_idx as u32; count]; + let buffered_range: Vec = + (start_idx as u64..buffered_len as u64).collect(); + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + batch_process_state.process_rest = None; + + found = true; + + break; } - Operator::GtEq | Operator::LtEq => { - if matches!(compare, Ordering::Equal | Ordering::Less) { - let count = buffered_values.len() - pivot; - let stream_repeated = vec![row_idx as u32; count]; - let buffered_range: Vec = - (pivot as u64..buffered_len as u64).collect(); - stream_indices.append_slice(&stream_repeated); - buffered_indices.append_slice(&buffered_range); - found = true; - break; + let compare = compare_join_arrays( + &[Arc::clone(&stream_values[0])], + row_idx, + &[Arc::clone(buffered_values)], + pivot, + &[sort_options], + NullEquality::NullEqualsNothing, + )?; + + // If we find a match we append all indices and move to the next stream row index + match operator { + Operator::Gt | Operator::Lt => { + if matches!(compare, Ordering::Less) { + let count = buffered_values.len() - pivot; + + // If the current output + new output is over our process value then we want to be + // able to change that + if batch_process_state.num_rows + count + >= DEFAULT_INCREMENTAL_BATCH_VALUE + { + let process_batch_size = DEFAULT_INCREMENTAL_BATCH_VALUE + - batch_process_state.num_rows; + let stream_repeated = + vec![row_idx as u32; process_batch_size]; + batch_process_state.set_rows( + batch_process_state.num_rows + process_batch_size, + ); + let buffered_range: Vec = (pivot as u64 + ..(pivot + process_batch_size) as u64) + .collect(); + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + buffered_side, + join_type, + join_schema, + )?; + + batch_process_state + .set_process_rest(Some(pivot + process_batch_size)); + batch_process_state.continue_process = true; + // Update the start index so it repeats the process + batch_process_state.set_start_idx(row_idx); + batch_process_state.set_pivot(pivot); + batch_process_state.set_rows(0); + + return Ok(batch); + } + + // Update the number of rows processed + batch_process_state + .set_rows(batch_process_state.num_rows + count); + let stream_repeated = vec![row_idx as u32; count]; + let buffered_range: Vec = + (pivot as u64..buffered_len as u64).collect(); + + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + found = true; + + break; + } } - pivot += 1; - } - _ => { - return exec_err!( - "PiecewiseMergeJoin should not contain operator, {}", - operator - ) - } - }; + Operator::GtEq | Operator::LtEq => { + if matches!(compare, Ordering::Equal | Ordering::Less) { + let count = buffered_values.len() - pivot; + + // If the current output + new output is over our process value then we want to be + // able to change that + if batch_process_state.num_rows + count + >= DEFAULT_INCREMENTAL_BATCH_VALUE + { + // Update the start index so it repeats the process + batch_process_state.set_start_idx(row_idx); + batch_process_state.set_pivot(pivot); + + let process_batch_size = DEFAULT_INCREMENTAL_BATCH_VALUE + - batch_process_state.num_rows; + let stream_repeated = + vec![row_idx as u32; process_batch_size]; + batch_process_state + .set_process_rest(Some(pivot + process_batch_size)); + batch_process_state.set_rows( + batch_process_state.num_rows + process_batch_size, + ); + let buffered_range: Vec = (pivot as u64 + ..(pivot + process_batch_size) as u64) + .collect(); + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + buffered_side, + join_type, + join_schema, + )?; + + batch_process_state.continue_process = true; + batch_process_state.set_rows(0); + + return Ok(batch); + } + + // Update the number of rows processed + batch_process_state + .set_rows(batch_process_state.num_rows + count); + let stream_repeated = vec![row_idx as u32; count]; + let buffered_range: Vec = + (pivot as u64..buffered_len as u64).collect(); + + stream_indices.append_slice(&stream_repeated); + buffered_indices.append_slice(&buffered_range); + found = true; + + break; + } + } + _ => { + return exec_err!( + "PiecewiseMergeJoin should not contain operator, {}", + operator + ) + } + }; + + // Increment pivot after every row + pivot += 1; + } } - // If not found we append a null value for `JoinType::Left` and `JoinType::Full` - if !found && matches!(join_type, JoinType::Left | JoinType::Full) { + // If not found we append a null value for `JoinType::Right` and `JoinType::Full` + if (!found || batch_process_state.not_found) + && matches!(join_type, JoinType::Right | JoinType::Full) + { + let remaining = DEFAULT_INCREMENTAL_BATCH_VALUE + .saturating_sub(batch_process_state.num_rows); + if remaining == 0 { + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + buffered_side, + join_type, + join_schema, + )?; + + // Update the start index so it repeats the process + batch_process_state.set_start_idx(row_idx); + batch_process_state.set_pivot(pivot); + batch_process_state.not_found = true; + batch_process_state.continue_process = true; + batch_process_state.set_rows(0); + + return Ok(batch); + } + + // Append right side value + null value for left stream_indices.append_value(row_idx as u32); buffered_indices.append_null(); + batch_process_state.set_rows(batch_process_state.num_rows + 1); + batch_process_state.not_found = false; } } + let batch = process_batch( + &mut buffered_indices, + &mut stream_indices, + stream_batch, + buffered_side, + join_type, + join_schema, + )?; + + // Resets batch process state for processing `Left` + `Full` join + batch_process_state.reset(); + + Ok(batch) +} + +fn process_batch( + buffered_indices: &mut PrimitiveBuilder, + stream_indices: &mut PrimitiveBuilder, + stream_batch: &StreamedBatch, + buffered_side: &mut BufferedSideReadyState, + join_type: JoinType, + join_schema: Arc, +) -> Result { let stream_indices_array = stream_indices.finish(); let buffered_indices_array = buffered_indices.finish(); - // We need to mark the buffered side matched indices for `JoinType::Full` and `JoinType::Right` + // We need to mark the buffered side matched indices for `JoinType::Full` and `JoinType::Left` if need_produce_result_in_final(join_type) { let mut bitmap = buffered_side.buffered_data.visited_indices_bitmap.lock(); - buffered_indices_array.iter().flatten().for_each(|x| { - bitmap.set_bit(x as usize, true); + buffered_indices_array.iter().flatten().for_each(|i| { + bitmap.set_bit(i as usize, true); }); } @@ -1266,7 +1675,7 @@ fn build_matched_indices( } // Gather stream columns after applying filter specified with stream indices - let mut streamed_columns = if !is_existence_join(join_type) { + let streamed_columns = if !is_existence_join(join_type) { streamed_batch .columns() .iter() @@ -1288,17 +1697,17 @@ fn build_matched_indices( vec![] }; - let buffered_columns = buffered_batch + let mut buffered_columns = buffered_batch .columns() .iter() .map(|column_array| take(column_array, &buffered_indices, None)) .collect::, ArrowError>>()?; - streamed_columns.extend(buffered_columns); + buffered_columns.extend(streamed_columns); Ok(RecordBatch::try_new( Arc::new((*schema).clone()), - streamed_columns, + buffered_columns, )?) } @@ -1421,11 +1830,26 @@ mod tests { #[tokio::test] async fn join_inner_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), - ("b1", &vec![1, 2, 3]), // this has a repetition + ("b1", &vec![3, 2, 1]), // this has a repetition ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 4]), @@ -1441,30 +1865,45 @@ mod tests { join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | 7 | 10 | 2 | 70 | - | 1 | 1 | 7 | 20 | 3 | 80 | - | 1 | 1 | 7 | 30 | 4 | 90 | - | 2 | 2 | 8 | 20 | 3 | 80 | - | 2 | 2 | 8 | 30 | 4 | 90 | - | 3 | 3 | 9 | 30 | 4 | 90 | - +----+----+----+----+----+----+ - "#); + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 3 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 1 | 9 | 30 | 4 | 90 | + | 2 | 2 | 8 | 20 | 3 | 80 | + | 3 | 1 | 9 | 20 | 3 | 80 | + | 3 | 1 | 9 | 10 | 2 | 70 | + +----+----+----+----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_inner_less_than_unsorted() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 2 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), - ("b1", &vec![3, 1, 2]), // this has a repetition + ("b1", &vec![3, 2, 1]), // this has a repetition ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), - ("b1", &vec![2, 3, 4]), + ("b1", &vec![3, 2, 4]), ("c2", &vec![70, 80, 90]), ); @@ -1477,27 +1916,42 @@ mod tests { join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 2 | 1 | 8 | 10 | 2 | 70 | - | 2 | 1 | 8 | 20 | 3 | 80 | - | 2 | 1 | 8 | 30 | 4 | 90 | - | 3 | 2 | 9 | 20 | 3 | 80 | - | 3 | 2 | 9 | 30 | 4 | 90 | - | 1 | 3 | 7 | 30 | 4 | 90 | - +----+----+----+----+----+----+ - "#); + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 3 | 7 | 30 | 4 | 90 | + | 2 | 2 | 8 | 30 | 4 | 90 | + | 3 | 1 | 9 | 30 | 4 | 90 | + | 2 | 2 | 8 | 10 | 3 | 70 | + | 3 | 1 | 9 | 10 | 3 | 70 | + | 3 | 1 | 9 | 20 | 2 | 80 | + +----+----+----+----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_inner_greater_than_equal_to() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 2 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), - ("b1", &vec![2, 3, 4]), // this has a repetition + ("b1", &vec![2, 3, 4]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 1 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![3, 2, 1]), @@ -1513,34 +1967,47 @@ mod tests { join_collect(left, right, on, Operator::GtEq, JoinType::Inner).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 3 | 4 | 9 | 10 | 3 | 70 | - | 3 | 4 | 9 | 20 | 2 | 80 | - | 3 | 4 | 9 | 30 | 1 | 90 | - | 2 | 3 | 8 | 10 | 3 | 70 | - | 2 | 3 | 8 | 20 | 2 | 80 | - | 2 | 3 | 8 | 30 | 1 | 90 | - | 1 | 2 | 7 | 20 | 2 | 80 | - | 1 | 2 | 7 | 30 | 1 | 90 | - +----+----+----+----+----+----+ - "#); + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 2 | 7 | 30 | 1 | 90 | + | 2 | 3 | 8 | 30 | 1 | 90 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 1 | 2 | 7 | 20 | 2 | 80 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 2 | 3 | 8 | 10 | 3 | 70 | + | 3 | 4 | 9 | 10 | 3 | 70 | + +----+----+----+----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_inner_empty_left() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // (empty) + // +----+----+----+ let left = build_table( ("a1", &Vec::::new()), ("b1", &Vec::::new()), ("c1", &Vec::::new()), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 1 | 1 | 1 | + // | 2 | 2 | 2 | + // +----+----+----+ let right = build_table( ("a2", &vec![1, 2]), ("b1", &vec![1, 2]), ("c2", &vec![1, 2]), ); + let on = ( Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, @@ -1548,21 +2015,34 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::LtEq, JoinType::Inner).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - +----+----+----+----+----+----+ - "#); + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + +----+----+----+----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_full_greater_than_equal_to() -> Result<()> { + // +----+----+-----+ + // | a1 | b1 | c1 | + // +----+----+-----+ + // | 1 | 1 | 100 | + // | 2 | 2 | 200 | + // +----+----+-----+ let left = build_table( ("a1", &vec![1, 2]), ("b1", &vec![1, 2]), ("c1", &vec![100, 200]), ); + + // +----+----+-----+ + // | a2 | b1 | c2 | + // +----+----+-----+ + // | 10 | 3 | 300 | + // | 20 | 2 | 400 | + // +----+----+-----+ let right = build_table( ("a2", &vec![10, 20]), ("b1", &vec![3, 2]), @@ -1578,25 +2058,40 @@ mod tests { join_collect(left, right, on, Operator::GtEq, JoinType::Full).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+-----+----+----+-----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+-----+----+----+-----+ - | 2 | 2 | 200 | 20 | 2 | 400 | - | 1 | 1 | 100 | | | | - | | | | 10 | 3 | 300 | - +----+----+-----+----+----+-----+ - "#); + +----+----+-----+----+----+-----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+-----+----+----+-----+ + | 2 | 2 | 200 | 20 | 2 | 400 | + | | | | 10 | 3 | 300 | + | 1 | 1 | 100 | | | | + +----+----+-----+----+----+-----+ + "#); Ok(()) } #[tokio::test] async fn join_left_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![1, 3, 4]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 3 | 70 | + // | 20 | 2 | 80 | + // | 30 | 1 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![3, 2, 1]), @@ -1612,27 +2107,42 @@ mod tests { join_collect(left, right, on, Operator::Gt, JoinType::Left).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 3 | 4 | 9 | 10 | 3 | 70 | - | 3 | 4 | 9 | 20 | 2 | 80 | - | 3 | 4 | 9 | 30 | 1 | 90 | - | 2 | 3 | 8 | 20 | 2 | 80 | - | 2 | 3 | 8 | 30 | 1 | 90 | - | 1 | 1 | 7 | | | | - +----+----+----+----+----+----+ - "#); + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 1 | 90 | + | 3 | 4 | 9 | 30 | 1 | 90 | + | 2 | 3 | 8 | 20 | 2 | 80 | + | 3 | 4 | 9 | 20 | 2 | 80 | + | 3 | 4 | 9 | 10 | 3 | 70 | + | 1 | 1 | 7 | | | | + +----+----+----+----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_right_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![1, 3, 4]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 5 | 70 | + // | 20 | 3 | 80 | + // | 30 | 2 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![5, 3, 2]), @@ -1648,25 +2158,40 @@ mod tests { join_collect(left, right, on, Operator::Gt, JoinType::Right).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 3 | 4 | 9 | 20 | 3 | 80 | - | 3 | 4 | 9 | 30 | 2 | 90 | - | 2 | 3 | 8 | 30 | 2 | 90 | - | | | | 10 | 5 | 70 | - +----+----+----+----+----+----+ - "#); + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 2 | 3 | 8 | 30 | 2 | 90 | + | 3 | 4 | 9 | 30 | 2 | 90 | + | 3 | 4 | 9 | 20 | 3 | 80 | + | | | | 10 | 5 | 70 | + +----+----+----+----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_right_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 4 | 7 | + // | 2 | 3 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), - ("b1", &vec![1, 3, 4]), + ("b1", &vec![4, 3, 1]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 5]), @@ -1682,26 +2207,41 @@ mod tests { join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+----+----+----+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +----+----+----+----+----+----+ - | 1 | 1 | 7 | 10 | 2 | 70 | - | 1 | 1 | 7 | 20 | 3 | 80 | - | 1 | 1 | 7 | 30 | 5 | 90 | - | 2 | 3 | 8 | 30 | 5 | 90 | - | 3 | 4 | 9 | 30 | 5 | 90 | - +----+----+----+----+----+----+ - "#); + +----+----+----+----+----+----+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +----+----+----+----+----+----+ + | 1 | 4 | 7 | 30 | 5 | 90 | + | 2 | 3 | 8 | 30 | 5 | 90 | + | 3 | 1 | 9 | 30 | 5 | 90 | + | 3 | 1 | 9 | 20 | 3 | 80 | + | 3 | 1 | 9 | 10 | 2 | 70 | + +----+----+----+----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_right_semi_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 4 | 7 | + // | 2 | 3 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), - ("b1", &vec![1, 3, 4]), // this has a repetition + ("b1", &vec![4, 3, 1]), // this has a repetition ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 5]), @@ -1717,24 +2257,39 @@ mod tests { join_collect(left, right, on, Operator::Lt, JoinType::RightSemi).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a2 | b1 | c2 | - +----+----+----+ - | 10 | 2 | 70 | - | 20 | 3 | 80 | - | 30 | 5 | 90 | - +----+----+----+ - "#); + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + | 10 | 2 | 70 | + | 20 | 3 | 80 | + | 30 | 5 | 90 | + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_left_semi_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 5 | 7 | + // | 2 | 4 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![5, 4, 1]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 5]), @@ -1750,23 +2305,38 @@ mod tests { join_collect(left, right, on, Operator::Lt, JoinType::LeftSemi).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 2 | 4 | 8 | - | 3 | 1 | 9 | - +----+----+----+ - "#); + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 2 | 4 | 8 | + | 3 | 1 | 9 | + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_left_semi_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 4 | 8 | + // | 3 | 5 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![1, 4, 5]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 5]), @@ -1782,23 +2352,38 @@ mod tests { join_collect(left, right, on, Operator::Gt, JoinType::LeftSemi).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 2 | 4 | 8 | - | 3 | 5 | 9 | - +----+----+----+ - "#); + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 2 | 4 | 8 | + | 3 | 5 | 9 | + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_left_anti_greater_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 4 | 8 | + // | 3 | 5 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![1, 4, 5]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 5]), @@ -1814,22 +2399,37 @@ mod tests { join_collect(left, right, on, Operator::Gt, JoinType::LeftAnti).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 1 | 1 | 7 | - +----+----+----+ - "#); + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 1 | 7 | + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_left_semi_greater_than_equal_to() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 1 | 7 | + // | 2 | 3 | 8 | + // | 3 | 4 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![1, 3, 4]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 1 | 70 | + // | 20 | 2 | 80 | + // | 30 | 3 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![1, 2, 3]), @@ -1845,24 +2445,39 @@ mod tests { join_collect(left, right, on, Operator::GtEq, JoinType::LeftSemi).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 1 | 1 | 7 | - | 2 | 3 | 8 | - | 3 | 4 | 9 | - +----+----+----+ - "#); + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 1 | 7 | + | 2 | 3 | 8 | + | 3 | 4 | 9 | + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_left_anti_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 5 | 7 | + // | 2 | 4 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![5, 4, 1]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 5]), @@ -1878,22 +2493,37 @@ mod tests { join_collect(left, right, on, Operator::Lt, JoinType::LeftAnti).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - | 1 | 5 | 7 | - +----+----+----+ - "#); + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + | 1 | 5 | 7 | + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_left_anti_less_than_equal_to() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 5 | 7 | + // | 2 | 4 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![5, 4, 1]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 5]), @@ -1909,21 +2539,36 @@ mod tests { join_collect(left, right, on, Operator::LtEq, JoinType::LeftAnti).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a1 | b1 | c1 | - +----+----+----+ - +----+----+----+ - "#); + +----+----+----+ + | a1 | b1 | c1 | + +----+----+----+ + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_right_semi_unsorted() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 3 | 7 | + // | 2 | 1 | 8 | + // | 3 | 2 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![3, 1, 2]), // unsorted ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 4 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 4]), @@ -1936,24 +2581,39 @@ mod tests { let (_, batches) = join_collect(left, right, on, Operator::Lt, JoinType::RightSemi).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a2 | b1 | c2 | - +----+----+----+ - | 10 | 2 | 70 | - | 20 | 3 | 80 | - | 30 | 4 | 90 | - +----+----+----+ - "#); + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + | 10 | 2 | 70 | + | 20 | 3 | 80 | + | 30 | 4 | 90 | + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_right_anti_less_than_equal_to() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 5 | 7 | + // | 2 | 4 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![5, 4, 1]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 2 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![2, 3, 5]), @@ -1969,21 +2629,36 @@ mod tests { join_collect(left, right, on, Operator::LtEq, JoinType::RightAnti).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a2 | b1 | c2 | - +----+----+----+ - +----+----+----+ - "#); + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_right_anti_less_than() -> Result<()> { + // +----+----+----+ + // | a1 | b1 | c1 | + // +----+----+----+ + // | 1 | 5 | 7 | + // | 2 | 4 | 8 | + // | 3 | 1 | 9 | + // +----+----+----+ let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![5, 4, 1]), ("c1", &vec![7, 8, 9]), ); + + // +----+----+----+ + // | a2 | b1 | c2 | + // +----+----+----+ + // | 10 | 1 | 70 | + // | 20 | 3 | 80 | + // | 30 | 5 | 90 | + // +----+----+----+ let right = build_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![1, 3, 5]), @@ -1999,22 +2674,37 @@ mod tests { join_collect(left, right, on, Operator::Lt, JoinType::RightAnti).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +----+----+----+ - | a2 | b1 | c2 | - +----+----+----+ - | 10 | 1 | 70 | - +----+----+----+ - "#); + +----+----+----+ + | a2 | b1 | c2 | + +----+----+----+ + | 10 | 1 | 70 | + +----+----+----+ + "#); Ok(()) } #[tokio::test] async fn join_date32_inner_less_than() -> Result<()> { + // +----+-------+----+ + // | a1 | b1 | c1 | + // +----+-------+----+ + // | 1 | 19107 | 7 | + // | 2 | 19107 | 8 | + // | 3 | 19105 | 9 | + // +----+-------+----+ let left = build_date_table( ("a1", &vec![1, 2, 3]), - ("b1", &vec![19105, 19107, 19107]), + ("b1", &vec![19107, 19107, 19105]), ("c1", &vec![7, 8, 9]), ); + + // +----+-------+----+ + // | a2 | b1 | c2 | + // +----+-------+----+ + // | 10 | 19105 | 70 | + // | 20 | 19103 | 80 | + // | 30 | 19107 | 90 | + // +----+-------+----+ let right = build_date_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![19105, 19103, 19107]), @@ -2030,22 +2720,37 @@ mod tests { join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +------------+------------+------------+------------+------------+------------+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +------------+------------+------------+------------+------------+------------+ - | 1970-01-02 | 2022-04-23 | 1970-01-08 | 1970-01-31 | 2022-04-25 | 1970-04-01 | - +------------+------------+------------+------------+------------+------------+ - "#); + +------------+------------+------------+------------+------------+------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +------------+------------+------------+------------+------------+------------+ + | 1970-01-04 | 2022-04-23 | 1970-01-10 | 1970-01-31 | 2022-04-25 | 1970-04-01 | + +------------+------------+------------+------------+------------+------------+ + "#); Ok(()) } #[tokio::test] async fn join_date64_inner_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650903441000 | 8 | + // | 3 | 1650703441000 | 9 | + // +----+---------------+----+ let left = build_date64_table( ("a1", &vec![1, 2, 3]), - ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), + ("b1", &vec![1650903441000, 1650903441000, 1650703441000]), ("c1", &vec![7, 8, 9]), ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 70 | + // | 20 | 1650503441000 | 80 | + // | 30 | 1650903441000 | 90 | + // +----+---------------+----+ let right = build_date64_table( ("a2", &vec![10, 20, 30]), ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), @@ -2061,12 +2766,145 @@ mod tests { join_collect(left, right, on, Operator::Lt, JoinType::Inner).await?; assert_snapshot!(batches_to_string(&batches), @r#" - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - | a1 | b1 | c1 | a2 | b1 | c2 | - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - | 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | - +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ - "#); + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.003 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + "#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_right_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650703441000 | 8 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1650903441000, 1650703441000]), + ("c1", &vec![7, 8]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 80 | + // | 20 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20]), + ("b1", &vec![1650703441000, 1650903441000]), + ("c2", &vec![80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::Right).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.002 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.020 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + | | | | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.080 | + +-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+ +"#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_left_semi_less_than() -> Result<()> { + // +----+---------------+----+ + // | a1 | b1 | c1 | + // +----+---------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650903441000 | 8 | + // | 3 | 1650703441000 | 9 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1650903441000, 1650903441000, 1650703441000]), + ("c1", &vec![7, 8, 9]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10]), + ("b1", &vec![1650903441000]), + ("c2", &vec![90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::LeftSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+ + | a1 | b1 | c1 | + +-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.003 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.009 | + +-------------------------+---------------------+-------------------------+ +"#); + Ok(()) + } + + #[tokio::test] + async fn join_date64_right_semi_less_than() -> Result<()> { + // +----+--------------+----+ + // | a1 | b1 | c1 | + // +----+--------------+----+ + // | 1 | 1650903441000 | 7 | + // | 2 | 1650703441000 | 8 | + // +----+---------------+----+ + let left = build_date64_table( + ("a1", &vec![1, 2]), + ("b1", &vec![1650903441000, 1650703441000]), + ("c1", &vec![7, 8]), + ); + + // +----+---------------+----+ + // | a2 | b1 | c2 | + // +----+---------------+----+ + // | 10 | 1650703441000 | 80 | + // | 20 | 1650903441000 | 90 | + // +----+---------------+----+ + let right = build_date64_table( + ("a2", &vec![10, 20]), + ("b1", &vec![1650703441000, 1650903441000]), + ("c2", &vec![80, 90]), + ); + + let on = ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + ); + + let (_, batches) = + join_collect(left, right, on, Operator::Lt, JoinType::RightSemi).await?; + + assert_snapshot!(batches_to_string(&batches), @r#" + +-------------------------+---------------------+-------------------------+ + | a2 | b1 | c2 | + +-------------------------+---------------------+-------------------------+ + | 1970-01-01T00:00:00.020 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 | + +-------------------------+---------------------+-------------------------+ +"#); Ok(()) } } From f490de66e660227566994776d202a7098a67a28c Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 19 Aug 2025 22:08:39 -0400 Subject: [PATCH 22/24] update --- .../src/joins/piecewise_merge_join.rs | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 8e10911f900a..8215f7bdd7cb 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -71,7 +71,7 @@ use crate::{ }; /// Batch emits this number of rows when processing -pub const DEFAULT_INCREMENTAL_BATCH_VALUE: usize = 1; +pub const DEFAULT_INCREMENTAL_BATCH_VALUE: usize = 8192; /// `PiecewiseMergeJoinExec` is a join execution plan that only evaluates single range filter. /// @@ -187,36 +187,35 @@ pub const DEFAULT_INCREMENTAL_BATCH_VALUE: usize = 1; /// min value: 200 /// ``` /// -/// For both types of joins, the buffered side must be sorted ascending for `Operator::Lt`(<) or -/// `Operator::LtEq`(<=) and descending for `Operator::Gt`(>) or `Operator::GtEq`(>=). +/// For both types of joins, the buffered side must be sorted ascending for `Operator::Lt` (<) or +/// `Operator::LtEq` (<=) and descending for `Operator::Gt` (>) or `Operator::GtEq` (>=). /// /// ## Assumptions / Notation -/// - [R], [S]: number of pages (blocks) of R and S -/// - |R|, |S|: number of tuples in R and S -/// - B: number of buffer pages +/// - \[R\], \[S\]: number of pages (blocks) of `R` and `S` +/// - |R|, |S|: number of tuples in `R` and `S` +/// - `B`: number of buffer pages /// /// # Performance (cost) /// Piecewise Merge Join is used over Nested Loop Join due to its superior performance. Here is a breakdown /// of the calculations: /// /// ## Piecewise Merge Join (PWMJ) -/// Intuition: Keep the buffered side (R) sorted and in memory (or scan it in sorted order), -/// sort the streamed side (S), then merge in order while advancing a pivot on R. +/// Intuition: Keep the buffered side (`R`) sorted and in memory (or scan it in sorted order), +/// sort the streamed side (`S`), then merge in order while advancing a pivot on `R`. /// /// Average I/O cost: -/// cost(PWMJ) = cost_to_sort(R) + cost_to_sort(S) + ([R] + [S]) -/// = sort(R) + sort(S) + [R] + [S] +/// `cost(PWMJ) = sort(R) + sort(S) + (\[R\] + \[S\])` /// -/// - If R (buffered) already sorted on the join key: cost(PWMJ) = sort(S) + [R] + [S] -/// - If S already sorted and R not: cost(PWMJ) = sort(R) + [R] + [S] -/// - If both already sorted: cost(PWMJ) = [R] + [S] +/// - If `R` (buffered) already sorted on the join key: `cost(PWMJ) = sort(S) + \[R\] + \[S\]` +/// - If `S` already sorted and `R` not: `cost(PWMJ) = sort(R) + \[R\] + \[S\]` +/// - If both already sorted: `cost(PWMJ) = \[R\] + \[S\]` /// /// ## Nested Loop Join -/// cost(NLJ) ≈ [R] + |R|·[S] +/// `cost(NLJ) ≈ \[R\] + |R|·\[S\]` /// /// Takeaway: -/// - When at least one side needs sorting, PWMJ ≈ sort(R) + sort(S) + [R] + [S] on average, -/// typically beating NLJ’s |R|·[S] (or its buffered variant) for nontrivial |R|, [S]. +/// - When at least one side needs sorting, PWMJ ≈ `sort(R) + sort(S) + \[R\] + \[S\]` on average, +/// typically beating NLJ’s `|R|·\[S\]` (or its buffered variant) for nontrivial `|R|`, `\[S\]`. /// /// # Further Reference Material /// DuckDB blog on Range Joins: [Range Joins in DuckDB](https://duckdb.org/2022/05/27/iejoin.html) @@ -387,7 +386,7 @@ impl PiecewiseMergeJoinExec { schema, &Self::maintains_input_order(join_type), Some(Self::probe_side(&join_type)), - &[join_on.clone()], + std::slice::from_ref(join_on), )?; let output_partitioning = @@ -1136,8 +1135,8 @@ impl PiecewiseMergeJoinStream { let buffered_value = ScalarValue::try_from_array(&buffered_values, buffered_idx)?; let ord = compare_rows( - &[threshold.clone()], - &[buffered_value.clone()], + std::slice::from_ref(&threshold), + std::slice::from_ref(&buffered_value), &[self.sort_option], )?; From 79a5aab2a58e530bcbfcbf121c19ba3944012908 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Tue, 19 Aug 2025 22:38:33 -0400 Subject: [PATCH 23/24] remove pub from `BatchProcessState` --- .../physical-plan/src/joins/piecewise_merge_join.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index 8215f7bdd7cb..c36c875d55a2 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -1269,7 +1269,7 @@ impl BatchProcessState { } } - pub fn reset(&mut self) { + fn reset(&mut self) { self.start_idx = 0; self.num_rows = 0; self.pivot = 0; @@ -1279,23 +1279,23 @@ impl BatchProcessState { self.buffered_indices = None; } - pub fn pivot(&self) -> usize { + fn pivot(&self) -> usize { self.pivot } - pub fn set_pivot(&mut self, pivot: usize) { + fn set_pivot(&mut self, pivot: usize) { self.pivot += pivot; } - pub fn set_start_idx(&mut self, start_idx: usize) { + fn set_start_idx(&mut self, start_idx: usize) { self.start_idx = start_idx; } - pub fn set_rows(&mut self, num_rows: usize) { + fn set_rows(&mut self, num_rows: usize) { self.num_rows = num_rows; } - pub fn set_process_rest(&mut self, process_rest: Option) { + fn set_process_rest(&mut self, process_rest: Option) { self.process_rest = process_rest; } } From 52979367a9672273a189f553f2c9993148ac8be0 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Thu, 4 Sep 2025 22:54:27 -0400 Subject: [PATCH 24/24] doc changes --- .../src/joins/piecewise_merge_join.rs | 89 +++++++++++++------ 1 file changed, 60 insertions(+), 29 deletions(-) diff --git a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs index c36c875d55a2..a7a95e237df0 100644 --- a/datafusion/physical-plan/src/joins/piecewise_merge_join.rs +++ b/datafusion/physical-plan/src/joins/piecewise_merge_join.rs @@ -30,7 +30,7 @@ use arrow::{ }; use arrow_schema::{ArrowError, Schema, SchemaRef, SortOptions}; use datafusion_common::{ - exec_err, internal_err, plan_err, utils::compare_rows, JoinSide, Result, ScalarValue, + exec_err, internal_err, utils::compare_rows, JoinSide, Result, ScalarValue, }; use datafusion_common::{not_impl_err, NullEquality}; use datafusion_execution::{ @@ -101,6 +101,24 @@ pub const DEFAULT_INCREMENTAL_BATCH_VALUE: usize = 8192; /// have it already sorted either ascending or descending based on the operator as this allows us to emit all /// the rows from a given point to the end as matches. Sorting the streamed side allows us to start the pointer /// from the previous row's match on the buffered side. +/// +/// For `Lt` (`<`) + `LtEq` (`<=`) operations both inputs are to be sorted in descending order and sorted in +/// ascending order for `Gt` (`>`) + `GtEq` (`>=`) than (`>`) operations. `SortExec` is used to enforce sorting +/// on the buffered side and streamed side is sorted in memory. +/// +/// The pseudocode for the algorithm looks like this: +/// +/// ```text +/// for stream_row in stream_batch: +/// for buffer_row in buffer_batch: +/// if compare(stream_row, probe_row): +/// output stream_row X buffer_batch[buffer_row:] +/// else: +/// continue +/// ``` +/// +/// The algorithm uses the streamed side to drive the loop. This is due to every row on the stream side iterating +/// the buffered side to find every first match. /// /// Here is an example: /// @@ -155,8 +173,25 @@ pub const DEFAULT_INCREMENTAL_BATCH_VALUE: usize = 8192; /// Existence joins are made magnitudes of times faster with a `PiecewiseMergeJoin` as we only need to find /// the min/max value of the streamed side to be able to emit all matches on the buffered side. By putting /// the side we need to mark onto the sorted buffer side, we can emit all these matches at once. +/// +/// For less than operations (`<`) both inputs are to be sorted in descending order and vice versa for greater +/// than (`>`) operations. `SortExec` is used to enforce sorting on the buffered side and streamed side does not +/// need to be sorted due to only needing to find the min/max. /// /// For Left Semi, Anti, and Mark joins we swap the inputs so that the marked side is on the buffered side. +/// +/// The pseudocode for the algorithm looks like this: +/// +/// ```text +/// // Using the example of a less than `<` operation +/// let max = max_batch(streamed_batch) +/// +/// for buffer_row in buffer_batch: +/// if buffer_row < max: +/// output buffer_batch[buffer_row:] +/// ``` +/// +/// Only need to find the min/max value and iterate through the buffered side once. /// /// Here is an example: /// We perform a `JoinType::LeftSemi` with these two batches and the operator being `Operator::Lt`(<). Because @@ -168,7 +203,7 @@ pub const DEFAULT_INCREMENTAL_BATCH_VALUE: usize = 8192; /// ```text /// SQL statement: /// SELECT * -/// FROM (VALUES (100), (200), (500)) AS streamed(a) +/// FROM (VALUES (500), (200), (300)) AS streamed(a) /// LEFT SEMI JOIN (VALUES (100), (200), (200), (300), (400)) AS buffered(b) /// ON streamed.a < buffered.b; /// @@ -190,32 +225,26 @@ pub const DEFAULT_INCREMENTAL_BATCH_VALUE: usize = 8192; /// For both types of joins, the buffered side must be sorted ascending for `Operator::Lt` (<) or /// `Operator::LtEq` (<=) and descending for `Operator::Gt` (>) or `Operator::GtEq` (>=). /// -/// ## Assumptions / Notation -/// - \[R\], \[S\]: number of pages (blocks) of `R` and `S` -/// - |R|, |S|: number of tuples in `R` and `S` -/// - `B`: number of buffer pages -/// -/// # Performance (cost) -/// Piecewise Merge Join is used over Nested Loop Join due to its superior performance. Here is a breakdown -/// of the calculations: +/// # Performance Explanation (cost) +/// Piecewise Merge Join is used over Nested Loop Join due to its superior performance. Here is the breakdown: /// /// ## Piecewise Merge Join (PWMJ) -/// Intuition: Keep the buffered side (`R`) sorted and in memory (or scan it in sorted order), -/// sort the streamed side (`S`), then merge in order while advancing a pivot on `R`. +/// # Classic Join: +/// Requires sorting the probe side and, for each probe row, scanning the buffered side until the first match +/// is found. +/// Complexity: `O(sort(S) + |S| * scan(R))`. /// -/// Average I/O cost: -/// `cost(PWMJ) = sort(R) + sort(S) + (\[R\] + \[S\])` +/// # Mark Join: +/// Sorts the probe side, then computes the min/max range of the probe keys and scans the buffered side only +/// within that range. +/// Complexity: `O(|S| + scan(R[range]))`. /// -/// - If `R` (buffered) already sorted on the join key: `cost(PWMJ) = sort(S) + \[R\] + \[S\]` -/// - If `S` already sorted and `R` not: `cost(PWMJ) = sort(R) + \[R\] + \[S\]` -/// - If both already sorted: `cost(PWMJ) = \[R\] + \[S\]` +/// ## Nested Loop Join +/// Compares every row from `S` with every row from `R`. +/// Complexity: `O(|S| * |R|)`. /// /// ## Nested Loop Join -/// `cost(NLJ) ≈ \[R\] + |R|·\[S\]` -/// -/// Takeaway: -/// - When at least one side needs sorting, PWMJ ≈ `sort(R) + sort(S) + \[R\] + \[S\]` on average, -/// typically beating NLJ’s `|R|·\[S\]` (or its buffered variant) for nontrivial `|R|`, `\[S\]`. +/// Always going to be probe (O(N) * O(N)). /// /// # Further Reference Material /// DuckDB blog on Range Joins: [Range Joins in DuckDB](https://duckdb.org/2022/05/27/iejoin.html) @@ -237,9 +266,10 @@ pub struct PiecewiseMergeJoinExec { buffered_fut: OnceAsync, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// The left SortExpr + /// The left sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations left_sort_exprs: LexOrdering, - /// The right SortExpr + /// The right sort order, descending for `<`, `<=` operations + ascending for `>`, `>=` operations + /// Unsorted for mark joins right_sort_exprs: LexOrdering, /// Sort options of join columns used in sorting the stream and buffered execution plans sort_options: SortOptions, @@ -282,7 +312,7 @@ impl PiecewiseMergeJoinExec { } } _ => { - return plan_err!( + return internal_err!( "Cannot contain non-range operator in PiecewiseMergeJoinExec" ) } @@ -295,12 +325,12 @@ impl PiecewiseMergeJoinExec { vec![PhysicalSortExpr::new(Arc::clone(&on.1), sort_options)]; let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else { - return plan_err!( + return internal_err!( "PiecewiseMergeJoinExec requires valid sort expressions for its left side" ); }; let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else { - return plan_err!( + return internal_err!( "PiecewiseMergeJoinExec requires valid sort expressions for its right side" ); }; @@ -400,14 +430,15 @@ impl PiecewiseMergeJoinExec { )) } + // TODO: Add input order fn maintains_input_order(join_type: JoinType) -> Vec { match join_type { // The existence side is expected to come in sorted JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { - vec![false, true] + vec![false, false] } JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => { - vec![true, false] + vec![false, false] } // Left, Right, Full, Inner Join is not guaranteed to maintain // input order as the streamed side will be sorted during