Commit 8bf7123

Fix Partial Sort Get Slice Point Between Batches (#16881)
* Update partial_sort.rs
* Update partial_sort.rs
* Update partial_sort.rs
* add sql test
1 parent fd08e72 commit 8bf7123
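The commit message is terse, but the diffs and the new test below indicate the shape of the bug: `PartialSortExec` sorts input that already arrives ordered on a prefix of the requested ordering and emits each prefix group as soon as it is complete. The fix adjusts how the slice point of a completed group is computed when the group boundary falls between input batches, and the new SQL-level test pins that behavior down with batches that are homogeneous in the prefix columns (a, b).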

4 files changed, +289 -46 lines changed

datafusion/catalog/src/streaming.rs

Lines changed: 27 additions & 7 deletions

@@ -20,15 +20,17 @@
 use std::any::Any;
 use std::sync::Arc;
 
-use arrow::datatypes::SchemaRef;
-use async_trait::async_trait;
-
 use crate::Session;
 use crate::TableProvider;
-use datafusion_common::{plan_err, Result};
-use datafusion_expr::{Expr, TableType};
+
+use arrow::datatypes::SchemaRef;
+use datafusion_common::{plan_err, DFSchema, Result};
+use datafusion_expr::{Expr, SortExpr, TableType};
+use datafusion_physical_expr::{create_physical_sort_exprs, LexOrdering};
 use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec};
 use datafusion_physical_plan::ExecutionPlan;
+
+use async_trait::async_trait;
 use log::debug;
 
 /// A [`TableProvider`] that streams a set of [`PartitionStream`]
@@ -37,6 +39,7 @@ pub struct StreamingTable {
     schema: SchemaRef,
     partitions: Vec<Arc<dyn PartitionStream>>,
     infinite: bool,
+    sort_order: Vec<SortExpr>,
 }
 
 impl StreamingTable {
@@ -60,13 +63,21 @@ impl StreamingTable {
             schema,
             partitions,
             infinite: false,
+            sort_order: vec![],
         })
     }
+
     /// Sets streaming table can be infinite.
     pub fn with_infinite_table(mut self, infinite: bool) -> Self {
         self.infinite = infinite;
         self
     }
+
+    /// Sets the existing ordering of streaming table.
+    pub fn with_sort_order(mut self, sort_order: Vec<SortExpr>) -> Self {
+        self.sort_order = sort_order;
+        self
+    }
 }
 
 #[async_trait]
@@ -85,16 +96,25 @@ impl TableProvider for StreamingTable {
 
     async fn scan(
         &self,
-        _state: &dyn Session,
+        state: &dyn Session,
         projection: Option<&Vec<usize>>,
         _filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
+        let physical_sort = if !self.sort_order.is_empty() {
+            let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
+            let eqp = state.execution_props();
+
+            create_physical_sort_exprs(&self.sort_order, &df_schema, eqp)?
+        } else {
+            vec![]
+        };
+
         Ok(Arc::new(StreamingTableExec::try_new(
             Arc::clone(&self.schema),
             self.partitions.clone(),
             projection,
-            None,
+            LexOrdering::new(physical_sort),
             self.infinite,
             limit,
         )?))
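With the new builder, callers can declare a pre-existing ordering that `scan` now converts to physical sort expressions and forwards to `StreamingTableExec` (previously hard-coded to `None`). A minimal usage sketch; the table name, function, and partition stream are illustrative, not part of this commit:

use std::sync::Arc;

use datafusion::prelude::*;
use datafusion_catalog::streaming::StreamingTable;
use datafusion_common::Result;
use datafusion_physical_plan::streaming::PartitionStream;

fn register_ordered_stream(
    ctx: &SessionContext,
    schema: arrow::datatypes::SchemaRef,
    partition: Arc<dyn PartitionStream>,
) -> Result<()> {
    // Declare that the stream already arrives sorted on (a ASC, b ASC);
    // scan() will surface this as the exec's output ordering.
    let table = StreamingTable::try_new(schema, vec![partition])?
        .with_sort_order(vec![col("a").sort(true, false), col("b").sort(true, false)]);
    ctx.register_table("ordered_stream", Arc::new(table))?;
    Ok(())
}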

datafusion/core/tests/memory_limit/mod.rs

Lines changed: 3 additions & 3 deletions

@@ -1084,9 +1084,9 @@ fn batches_byte_size(batches: &[RecordBatch]) -> usize {
 }
 
 #[derive(Debug)]
-struct DummyStreamPartition {
-    schema: SchemaRef,
-    batches: Vec<RecordBatch>,
+pub(crate) struct DummyStreamPartition {
+    pub(crate) schema: SchemaRef,
+    pub(crate) batches: Vec<RecordBatch>,
 }
 
 impl PartitionStream for DummyStreamPartition {
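The stub and its fields become `pub(crate)` so the new `enforce_sorting` test can construct it directly; its `PartitionStream` impl (the context line above) is untouched. For orientation, such a stub is typically implemented along these lines; a sketch that assumes it simply replays the stored batches via `MemoryStream`, since the real body sits outside this diff:

use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_physical_plan::memory::MemoryStream;
use datafusion_physical_plan::streaming::PartitionStream;

#[derive(Debug)]
pub(crate) struct DummyStreamPartition {
    pub(crate) schema: SchemaRef,
    pub(crate) batches: Vec<RecordBatch>,
}

impl PartitionStream for DummyStreamPartition {
    fn schema(&self) -> &SchemaRef {
        &self.schema
    }

    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
        // Replay the stored batches as a one-shot in-memory stream.
        Box::pin(
            MemoryStream::try_new(self.batches.clone(), Arc::clone(&self.schema), None)
                .unwrap(),
        )
    }
}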

datafusion/core/tests/physical_optimizer/enforce_sorting.rs

Lines changed: 133 additions & 3 deletions

@@ -17,6 +17,7 @@
 
 use std::sync::Arc;
 
+use crate::memory_limit::DummyStreamPartition;
 use crate::physical_optimizer::test_utils::{
     aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition,
     check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema,
@@ -32,11 +33,11 @@ use arrow::compute::SortOptions;
 use arrow::datatypes::{DataType, SchemaRef};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{TreeNode, TransformedResult};
-use datafusion_common::{Result, ScalarValue};
+use datafusion_common::{Result, ScalarValue, TableReference};
 use datafusion_datasource::file_scan_config::FileScanConfigBuilder;
 use datafusion_datasource::source::DataSourceExec;
 use datafusion_expr_common::operator::Operator;
-use datafusion_expr::{JoinType, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition};
+use datafusion_expr::{JoinType, SortExpr, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition};
 use datafusion_execution::object_store::ObjectStoreUrl;
 use datafusion_functions_aggregate::average::avg_udaf;
 use datafusion_functions_aggregate::count::count_udaf;
@@ -61,7 +62,14 @@ use datafusion_physical_optimizer::enforce_sorting::sort_pushdown::{SortPushDown
 use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution;
 use datafusion_physical_optimizer::output_requirements::OutputRequirementExec;
 use datafusion_physical_optimizer::PhysicalOptimizerRule;
-
+use datafusion::prelude::*;
+use arrow::array::{Int32Array, RecordBatch};
+use arrow::datatypes::{Field};
+use arrow_schema::Schema;
+use datafusion_execution::TaskContext;
+use datafusion_catalog::streaming::StreamingTable;
+
+use futures::StreamExt;
 use rstest::rstest;
 
 /// Create a sorted Csv exec
@@ -879,6 +887,7 @@ async fn test_soft_hard_requirements_multiple_soft_requirements() -> Result<()>
     assert_optimized!(expected_input, expected_optimized, physical_plan, true);
     Ok(())
 }
+
 #[tokio::test]
 async fn test_soft_hard_requirements_multiple_sorts() -> Result<()> {
     let schema = create_test_schema()?;
@@ -3842,3 +3851,124 @@ fn test_parallelize_sort_preserves_fetch() -> Result<()> {
     );
     Ok(())
 }
+
+#[tokio::test]
+async fn test_partial_sort_with_homogeneous_batches() -> Result<()> {
+    // Create schema for the table
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Int32, false),
+        Field::new("b", DataType::Int32, false),
+        Field::new("c", DataType::Int32, false),
+    ]));
+
+    // Create homogeneous batches - each batch has the same values for columns a and b
+    let batch1 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 1, 1])),
+            Arc::new(Int32Array::from(vec![1, 1, 1])),
+            Arc::new(Int32Array::from(vec![3, 2, 1])),
+        ],
+    )?;
+    let batch2 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![2, 2, 2])),
+            Arc::new(Int32Array::from(vec![2, 2, 2])),
+            Arc::new(Int32Array::from(vec![4, 6, 5])),
+        ],
+    )?;
+    let batch3 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![3, 3, 3])),
+            Arc::new(Int32Array::from(vec![3, 3, 3])),
+            Arc::new(Int32Array::from(vec![9, 7, 8])),
+        ],
+    )?;
+
+    // Create session with batch size of 3 to match our homogeneous batch pattern
+    let session_config = SessionConfig::new()
+        .with_batch_size(3)
+        .with_target_partitions(1);
+    let ctx = SessionContext::new_with_config(session_config);
+
+    let sort_order = vec![
+        SortExpr::new(
+            Expr::Column(datafusion_common::Column::new(
+                Option::<TableReference>::None,
+                "a",
+            )),
+            true,
+            false,
+        ),
+        SortExpr::new(
+            Expr::Column(datafusion_common::Column::new(
+                Option::<TableReference>::None,
+                "b",
+            )),
+            true,
+            false,
+        ),
+    ];
+    let batches = Arc::new(DummyStreamPartition {
+        schema: schema.clone(),
+        batches: vec![batch1, batch2, batch3],
+    }) as _;
+    let provider = StreamingTable::try_new(schema.clone(), vec![batches])?
+        .with_sort_order(sort_order)
+        .with_infinite_table(true);
+    ctx.register_table("test_table", Arc::new(provider))?;
+
+    let sql = "SELECT * FROM test_table ORDER BY a ASC, c ASC";
+    let df = ctx.sql(sql).await?;
+
+    let physical_plan = df.create_physical_plan().await?;
+
+    // Verify that PartialSortExec is used
+    let plan_str = displayable(physical_plan.as_ref()).indent(true).to_string();
+    assert!(
+        plan_str.contains("PartialSortExec"),
+        "Expected PartialSortExec in plan:\n{plan_str}",
+    );
+
+    let task_ctx = Arc::new(TaskContext::default());
+    let mut stream = physical_plan.execute(0, task_ctx.clone())?;
+
+    let mut collected_batches = Vec::new();
+    while let Some(batch) = stream.next().await {
+        let batch = batch?;
+        if batch.num_rows() > 0 {
+            collected_batches.push(batch);
+        }
+    }
+
+    // Assert we got 3 separate batches (not concatenated into fewer)
+    assert_eq!(
+        collected_batches.len(),
+        3,
+        "Expected 3 separate batches, got {}",
+        collected_batches.len()
+    );
+
+    // Verify each batch has been sorted within itself
+    let expected_values = [vec![1, 2, 3], vec![4, 5, 6], vec![7, 8, 9]];
+
+    for (i, batch) in collected_batches.iter().enumerate() {
+        let c_array = batch
+            .column(2)
+            .as_any()
+            .downcast_ref::<Int32Array>()
+            .unwrap();
+        let actual = c_array.values().iter().copied().collect::<Vec<i32>>();
+        assert_eq!(actual, expected_values[i], "Batch {i} not sorted correctly",);
+    }
+
+    assert_eq!(
+        task_ctx.runtime_env().memory_pool.reserved(),
+        0,
+        "Memory should be released after execution"
+    );
+
+    Ok(())
+}
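The fourth changed file, `partial_sort.rs`, is truncated from this view, so the fix itself is not shown. Conceptually, a partial sort may only sort and emit rows whose sort-prefix group is already complete; rows sharing the prefix of the last buffered row can continue into the next batch and must be carried over. A rough sketch of that slice-point idea over plain vectors; an illustration of the invariant the test exercises, not the actual DataFusion code:

/// For prefix columns of equal length `n`, return the index `i` such that
/// rows [0, i) belong to finished prefix groups (safe to sort and emit),
/// while rows [i, n) share the last row's prefix and must wait for more
/// input before their group is known to be complete.
fn slice_point(prefix_columns: &[Vec<i32>]) -> usize {
    let n = prefix_columns[0].len();
    let mut i = n;
    // Walk backwards while every prefix column still equals its final value.
    while i > 0 && prefix_columns.iter().all(|col| col[i - 1] == col[n - 1]) {
        i -= 1;
    }
    i
}

fn main() {
    // Prefix (1, 1, 2, 2, 3): the trailing `3` group may continue into the
    // next batch, so only the first four rows are emittable now.
    assert_eq!(slice_point(&[vec![1, 1, 2, 2, 3]]), 4);
    // A batch homogeneous in the prefix (as in the test above) emits nothing
    // until a later batch starts a new group.
    assert_eq!(slice_point(&[vec![7, 7, 7]]), 0);
}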
