Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions datafusion/common/src/join_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,51 @@ impl JoinType {
| JoinType::RightAnti
)
}

/// Determines whether each input of the join is preserved for WHERE clause filters.
///
/// A join input is "preserved" if every row from that input appears in at least one
/// output row. This property determines whether filters referencing only columns
/// from that input can be safely pushed below the join.
///
/// For example:
/// - In an inner join, both sides are preserved, because each row of the output
/// maps directly to a row from each side.
/// - In a left join, the left side is preserved (we can push predicates) but
/// the right is not, because there may be rows in the output that don't
/// directly map to a row in the right input (due to nulls filling where there
/// is no match on the right).
/// - In semi joins, only the preserved side's columns appear in the output,
/// so filters can only reference and be pushed to that side.
///
/// # Returns
/// A tuple of `(left_preserved, right_preserved)` booleans.
///
/// # Examples
///
/// ```
/// use datafusion_common::JoinType;
///
/// assert_eq!(JoinType::Inner.lr_is_preserved(), (true, true));
/// assert_eq!(JoinType::Left.lr_is_preserved(), (true, false));
/// assert_eq!(JoinType::LeftSemi.lr_is_preserved(), (true, false));
/// ```
pub fn lr_is_preserved(self) -> (bool, bool) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we should use a slightly more clear name, maybe like join_side_is_preserved() or something along those lines

match self {
JoinType::Inner => (true, true),
JoinType::Left => (true, false),
JoinType::Right => (false, true),
JoinType::Full => (false, false),
// No columns from the right side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
// No columns from the left side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
(false, true)
}
}
}
}

impl Display for JoinType {
Expand Down
174 changes: 174 additions & 0 deletions datafusion/core/tests/physical_optimizer/filter_pushdown/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use datafusion::{
scalar::ScalarValue,
};
use datafusion_common::config::ConfigOptions;
use datafusion_common::JoinType;
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_expr::ScalarUDF;
use datafusion_functions::math::random::RandomFunc;
Expand All @@ -51,6 +52,7 @@ use datafusion_physical_plan::{
coalesce_partitions::CoalescePartitionsExec,
collect,
filter::FilterExec,
joins::{HashJoinExec, PartitionMode},
repartition::RepartitionExec,
sorts::sort::SortExec,
ExecutionPlan,
Expand Down Expand Up @@ -427,6 +429,69 @@ async fn test_static_filter_pushdown_through_hash_join() {
);
}

#[test]
fn test_filter_pushdown_left_semi_join() {
// Create schemas for left and right sides
let left_side_schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Utf8, false),
Field::new("c", DataType::Float64, false),
]));
let right_side_schema = Arc::new(Schema::new(vec![
Field::new("d", DataType::Utf8, false),
Field::new("e", DataType::Utf8, false),
Field::new("f", DataType::Float64, false),
]));

let left_scan = TestScanBuilder::new(Arc::clone(&left_side_schema))
.with_support(true)
.build();
let right_scan = TestScanBuilder::new(Arc::clone(&right_side_schema))
.with_support(true)
.build();

let on = vec![(
col("a", &left_side_schema).unwrap(),
col("d", &right_side_schema).unwrap(),
)];
let join = Arc::new(
HashJoinExec::try_new(
left_scan,
right_scan,
on,
None,
&JoinType::LeftSemi,
None,
PartitionMode::Partitioned,
datafusion_common::NullEquality::NullEqualsNothing,
)
.unwrap(),
);

let join_schema = join.schema();
let filter = col_lit_predicate("a", "aa", &join_schema);
let plan =
Arc::new(FilterExec::try_new(filter, join).unwrap()) as Arc<dyn ExecutionPlan>;

// Test that filters ARE pushed down for left semi join when they reference only left side
insta::assert_snapshot!(
OptimizationTest::new(plan, FilterPushdown::new(), true),
@r"
OptimizationTest:
input:
- FilterExec: a@0 = aa
- HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(a@0, d@0)]
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true
output:
Ok:
- HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(a@0, d@0)]
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[a, b, c], file_type=test, pushdown_supported=true, predicate=a@0 = aa
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[d, e, f], file_type=test, pushdown_supported=true
"
);
}

#[test]
fn test_filter_collapse() {
// filter should be pushed down into the parquet scan with two filters
Expand Down Expand Up @@ -1542,3 +1607,112 @@ fn col_lit_predicate(
Arc::new(Literal::new(scalar_value)),
))
}

#[tokio::test]
async fn test_left_semi_join_dynamic_filter_pushdown() {
use datafusion_common::JoinType;
use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode};

// Create build side (left side) with limited values
let build_batches = vec![record_batch!(
("id", Int32, [1, 2]),
("name", Utf8, ["Alice", "Bob"]),
("score", Float64, [95.0, 87.0])
)
.unwrap()];
let build_side_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("name", DataType::Utf8, false),
Field::new("score", DataType::Float64, false),
]));
let build_scan = TestScanBuilder::new(Arc::clone(&build_side_schema))
.with_support(true)
.with_batches(build_batches)
.build();

// Create probe side (right side) with more values
let probe_batches = vec![record_batch!(
("id", Int32, [1, 2, 3, 4]),
(
"department",
Utf8,
["Engineering", "Sales", "Marketing", "HR"]
),
("budget", Float64, [100000.0, 80000.0, 60000.0, 50000.0])
)
.unwrap()];
let probe_side_schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("department", DataType::Utf8, false),
Field::new("budget", DataType::Float64, false),
]));
let probe_scan = TestScanBuilder::new(Arc::clone(&probe_side_schema))
.with_support(true)
.with_batches(probe_batches)
.build();

// Create HashJoinExec with LeftSemi join type
let on = vec![(
col("id", &build_side_schema).unwrap(),
col("id", &probe_side_schema).unwrap(),
)];
let plan = Arc::new(
HashJoinExec::try_new(
build_scan,
probe_scan,
on,
None,
&JoinType::LeftSemi,
None,
PartitionMode::Partitioned,
datafusion_common::NullEquality::NullEqualsNothing,
)
.unwrap(),
) as Arc<dyn ExecutionPlan>;

// Verify that dynamic filter pushdown creates the expected plan structure
insta::assert_snapshot!(
OptimizationTest::new(Arc::clone(&plan), FilterPushdown::new_post_optimization(), true),
@r"
OptimizationTest:
input:
- HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(id@0, id@0)]
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, name, score], file_type=test, pushdown_supported=true
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, department, budget], file_type=test, pushdown_supported=true
output:
Ok:
- HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(id@0, id@0)]
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, name, score], file_type=test, pushdown_supported=true
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, department, budget], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ true ]
",
);

// Apply the optimization and execute to see the actual filter bounds
let mut config = ConfigOptions::default();
config.execution.parquet.pushdown_filters = true;
config.optimizer.enable_dynamic_filter_pushdown = true;
let plan = FilterPushdown::new_post_optimization()
.optimize(plan, &config)
.unwrap();
let config = SessionConfig::new().with_batch_size(10);
let session_ctx = SessionContext::new_with_config(config);
session_ctx.register_object_store(
ObjectStoreUrl::parse("test://").unwrap().as_ref(),
Arc::new(InMemory::new()),
);
let state = session_ctx.state();
let task_ctx = state.task_ctx();
let mut stream = plan.execute(0, Arc::clone(&task_ctx)).unwrap();
// Execute one batch to populate the dynamic filter
stream.next().await.unwrap().unwrap();

// Verify that the dynamic filter shows the expected bounds for left semi join
insta::assert_snapshot!(
format!("{}", format_plan_for_test(&plan)),
@r"
- HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(id@0, id@0)], filter=[id@0 >= 1 AND id@0 <= 2]
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, name, score], file_type=test, pushdown_supported=true
- DataSourceExec: file_groups={1 group: [[test.parquet]]}, projection=[id, department, budget], file_type=test, pushdown_supported=true, predicate=DynamicFilterPhysicalExpr [ id@0 >= 1 AND id@0 <= 2 ]
"
);
}
43 changes: 2 additions & 41 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,45 +135,6 @@ use crate::{OptimizerConfig, OptimizerRule};
#[derive(Default, Debug)]
pub struct PushDownFilter {}

/// For a given JOIN type, determine whether each input of the join is preserved
/// for post-join (`WHERE` clause) filters.
///
/// It is only correct to push filters below a join for preserved inputs.
///
/// # Return Value
/// A tuple of booleans - (left_preserved, right_preserved).
///
/// # "Preserved" input definition
///
/// We say a join side is preserved if the join returns all or a subset of the rows from
/// the relevant side, such that each row of the output table directly maps to a row of
/// the preserved input table. If a table is not preserved, it can provide extra null rows.
/// That is, there may be rows in the output table that don't directly map to a row in the
/// input table.
///
/// For example:
/// - In an inner join, both sides are preserved, because each row of the output
/// maps directly to a row from each side.
///
/// - In a left join, the left side is preserved (we can push predicates) but
/// the right is not, because there may be rows in the output that don't
/// directly map to a row in the right input (due to nulls filling where there
/// is no match on the right).
pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
match join_type {
JoinType::Inner => (true, true),
JoinType::Left => (true, false),
JoinType::Right => (false, true),
JoinType::Full => (false, false),
// No columns from the right side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false),
// No columns from the left side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => (false, true),
}
}

/// For a given JOIN type, determine whether each input of the join is preserved
/// for the join condition (`ON` clause filters).
///
Expand All @@ -182,7 +143,7 @@ pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
/// # Return Value
/// A tuple of booleans - (left_preserved, right_preserved).
///
/// See [`lr_is_preserved`] for a definition of "preserved".
/// See [`JoinType::lr_is_preserved`] for a definition of "preserved".
pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
match join_type {
JoinType::Inner => (true, true),
Expand Down Expand Up @@ -426,7 +387,7 @@ fn push_down_all_join(
) -> Result<Transformed<LogicalPlan>> {
let is_inner_join = join.join_type == JoinType::Inner;
// Get pushable predicates from current optimizer state
let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);
let (left_preserved, right_preserved) = join.join_type.lr_is_preserved();

// The predicates can be divided to three categories:
// 1) can push through join to its children(left or right)
Expand Down
Loading
Loading