17
17
18
18
use std:: sync:: Arc ;
19
19
20
+ use crate :: memory_limit:: DummyStreamPartition ;
20
21
use crate :: physical_optimizer:: test_utils:: {
21
22
aggregate_exec, bounded_window_exec, bounded_window_exec_with_partition,
22
23
check_integrity, coalesce_batches_exec, coalesce_partitions_exec, create_test_schema,
@@ -32,11 +33,11 @@ use arrow::compute::SortOptions;
32
33
use arrow:: datatypes:: { DataType , SchemaRef } ;
33
34
use datafusion_common:: config:: ConfigOptions ;
34
35
use datafusion_common:: tree_node:: { TreeNode , TransformedResult } ;
35
- use datafusion_common:: { Result , ScalarValue } ;
36
+ use datafusion_common:: { Result , ScalarValue , TableReference } ;
36
37
use datafusion_datasource:: file_scan_config:: FileScanConfigBuilder ;
37
38
use datafusion_datasource:: source:: DataSourceExec ;
38
39
use datafusion_expr_common:: operator:: Operator ;
39
- use datafusion_expr:: { JoinType , WindowFrame , WindowFrameBound , WindowFrameUnits , WindowFunctionDefinition } ;
40
+ use datafusion_expr:: { JoinType , SortExpr , WindowFrame , WindowFrameBound , WindowFrameUnits , WindowFunctionDefinition } ;
40
41
use datafusion_execution:: object_store:: ObjectStoreUrl ;
41
42
use datafusion_functions_aggregate:: average:: avg_udaf;
42
43
use datafusion_functions_aggregate:: count:: count_udaf;
@@ -61,7 +62,14 @@ use datafusion_physical_optimizer::enforce_sorting::sort_pushdown::{SortPushDown
61
62
use datafusion_physical_optimizer:: enforce_distribution:: EnforceDistribution ;
62
63
use datafusion_physical_optimizer:: output_requirements:: OutputRequirementExec ;
63
64
use datafusion_physical_optimizer:: PhysicalOptimizerRule ;
64
-
65
+ use datafusion:: prelude:: * ;
66
+ use arrow:: array:: { Int32Array , RecordBatch } ;
67
+ use arrow:: datatypes:: { Field } ;
68
+ use arrow_schema:: Schema ;
69
+ use datafusion_execution:: TaskContext ;
70
+ use datafusion_catalog:: streaming:: StreamingTable ;
71
+
72
+ use futures:: StreamExt ;
65
73
use rstest:: rstest;
66
74
67
75
/// Create a sorted Csv exec
@@ -879,6 +887,7 @@ async fn test_soft_hard_requirements_multiple_soft_requirements() -> Result<()>
879
887
assert_optimized ! ( expected_input, expected_optimized, physical_plan, true ) ;
880
888
Ok ( ( ) )
881
889
}
890
+
882
891
#[ tokio:: test]
883
892
async fn test_soft_hard_requirements_multiple_sorts ( ) -> Result < ( ) > {
884
893
let schema = create_test_schema ( ) ?;
@@ -3842,3 +3851,124 @@ fn test_parallelize_sort_preserves_fetch() -> Result<()> {
3842
3851
) ;
3843
3852
Ok ( ( ) )
3844
3853
}
3854
+
3855
+ #[ tokio:: test]
3856
+ async fn test_partial_sort_with_homogeneous_batches ( ) -> Result < ( ) > {
3857
+ // Create schema for the table
3858
+ let schema = Arc :: new ( Schema :: new ( vec ! [
3859
+ Field :: new( "a" , DataType :: Int32 , false ) ,
3860
+ Field :: new( "b" , DataType :: Int32 , false ) ,
3861
+ Field :: new( "c" , DataType :: Int32 , false ) ,
3862
+ ] ) ) ;
3863
+
3864
+ // Create homogeneous batches - each batch has the same values for columns a and b
3865
+ let batch1 = RecordBatch :: try_new (
3866
+ schema. clone ( ) ,
3867
+ vec ! [
3868
+ Arc :: new( Int32Array :: from( vec![ 1 , 1 , 1 ] ) ) ,
3869
+ Arc :: new( Int32Array :: from( vec![ 1 , 1 , 1 ] ) ) ,
3870
+ Arc :: new( Int32Array :: from( vec![ 3 , 2 , 1 ] ) ) ,
3871
+ ] ,
3872
+ ) ?;
3873
+ let batch2 = RecordBatch :: try_new (
3874
+ schema. clone ( ) ,
3875
+ vec ! [
3876
+ Arc :: new( Int32Array :: from( vec![ 2 , 2 , 2 ] ) ) ,
3877
+ Arc :: new( Int32Array :: from( vec![ 2 , 2 , 2 ] ) ) ,
3878
+ Arc :: new( Int32Array :: from( vec![ 4 , 6 , 5 ] ) ) ,
3879
+ ] ,
3880
+ ) ?;
3881
+ let batch3 = RecordBatch :: try_new (
3882
+ schema. clone ( ) ,
3883
+ vec ! [
3884
+ Arc :: new( Int32Array :: from( vec![ 3 , 3 , 3 ] ) ) ,
3885
+ Arc :: new( Int32Array :: from( vec![ 3 , 3 , 3 ] ) ) ,
3886
+ Arc :: new( Int32Array :: from( vec![ 9 , 7 , 8 ] ) ) ,
3887
+ ] ,
3888
+ ) ?;
3889
+
3890
+ // Create session with batch size of 3 to match our homogeneous batch pattern
3891
+ let session_config = SessionConfig :: new ( )
3892
+ . with_batch_size ( 3 )
3893
+ . with_target_partitions ( 1 ) ;
3894
+ let ctx = SessionContext :: new_with_config ( session_config) ;
3895
+
3896
+ let sort_order = vec ! [
3897
+ SortExpr :: new(
3898
+ Expr :: Column ( datafusion_common:: Column :: new(
3899
+ Option :: <TableReference >:: None ,
3900
+ "a" ,
3901
+ ) ) ,
3902
+ true ,
3903
+ false ,
3904
+ ) ,
3905
+ SortExpr :: new(
3906
+ Expr :: Column ( datafusion_common:: Column :: new(
3907
+ Option :: <TableReference >:: None ,
3908
+ "b" ,
3909
+ ) ) ,
3910
+ true ,
3911
+ false ,
3912
+ ) ,
3913
+ ] ;
3914
+ let batches = Arc :: new ( DummyStreamPartition {
3915
+ schema : schema. clone ( ) ,
3916
+ batches : vec ! [ batch1, batch2, batch3] ,
3917
+ } ) as _ ;
3918
+ let provider = StreamingTable :: try_new ( schema. clone ( ) , vec ! [ batches] ) ?
3919
+ . with_sort_order ( sort_order)
3920
+ . with_infinite_table ( true ) ;
3921
+ ctx. register_table ( "test_table" , Arc :: new ( provider) ) ?;
3922
+
3923
+ let sql = "SELECT * FROM test_table ORDER BY a ASC, c ASC" ;
3924
+ let df = ctx. sql ( sql) . await ?;
3925
+
3926
+ let physical_plan = df. create_physical_plan ( ) . await ?;
3927
+
3928
+ // Verify that PartialSortExec is used
3929
+ let plan_str = displayable ( physical_plan. as_ref ( ) ) . indent ( true ) . to_string ( ) ;
3930
+ assert ! (
3931
+ plan_str. contains( "PartialSortExec" ) ,
3932
+ "Expected PartialSortExec in plan:\n {plan_str}" ,
3933
+ ) ;
3934
+
3935
+ let task_ctx = Arc :: new ( TaskContext :: default ( ) ) ;
3936
+ let mut stream = physical_plan. execute ( 0 , task_ctx. clone ( ) ) ?;
3937
+
3938
+ let mut collected_batches = Vec :: new ( ) ;
3939
+ while let Some ( batch) = stream. next ( ) . await {
3940
+ let batch = batch?;
3941
+ if batch. num_rows ( ) > 0 {
3942
+ collected_batches. push ( batch) ;
3943
+ }
3944
+ }
3945
+
3946
+ // Assert we got 3 separate batches (not concatenated into fewer)
3947
+ assert_eq ! (
3948
+ collected_batches. len( ) ,
3949
+ 3 ,
3950
+ "Expected 3 separate batches, got {}" ,
3951
+ collected_batches. len( )
3952
+ ) ;
3953
+
3954
+ // Verify each batch has been sorted within itself
3955
+ let expected_values = [ vec ! [ 1 , 2 , 3 ] , vec ! [ 4 , 5 , 6 ] , vec ! [ 7 , 8 , 9 ] ] ;
3956
+
3957
+ for ( i, batch) in collected_batches. iter ( ) . enumerate ( ) {
3958
+ let c_array = batch
3959
+ . column ( 2 )
3960
+ . as_any ( )
3961
+ . downcast_ref :: < Int32Array > ( )
3962
+ . unwrap ( ) ;
3963
+ let actual = c_array. values ( ) . iter ( ) . copied ( ) . collect :: < Vec < i32 > > ( ) ;
3964
+ assert_eq ! ( actual, expected_values[ i] , "Batch {i} not sorted correctly" , ) ;
3965
+ }
3966
+
3967
+ assert_eq ! (
3968
+ task_ctx. runtime_env( ) . memory_pool. reserved( ) ,
3969
+ 0 ,
3970
+ "Memory should be released after execution"
3971
+ ) ;
3972
+
3973
+ Ok ( ( ) )
3974
+ }
0 commit comments