@@ -84,13 +84,14 @@ impl TableProvider for RecordBatchReaderProvider {
84
84
async fn scan (
85
85
& self ,
86
86
_state : & dyn Session ,
87
- _projection : Option < & Vec < usize > > ,
87
+ projection : Option < & Vec < usize > > ,
88
88
_filters : & [ Expr ] ,
89
89
limit : Option < usize > ,
90
90
) -> Result < Arc < dyn ExecutionPlan > > {
91
91
let mut reader_guard = self . reader . lock ( ) ;
92
92
if let Some ( reader) = reader_guard. take ( ) {
93
- Ok ( Arc :: new ( RecordBatchReaderExec :: new ( reader, limit) ) )
93
+ let projection = projection. cloned ( ) ;
94
+ Ok ( Arc :: new ( RecordBatchReaderExec :: new ( reader, limit, projection) ) )
94
95
} else {
95
96
sedona_internal_err ! ( "Can't scan RecordBatchReader provider more than once" )
96
97
}
@@ -158,11 +159,18 @@ struct RecordBatchReaderExec {
158
159
schema : SchemaRef ,
159
160
properties : PlanProperties ,
160
161
limit : Option < usize > ,
162
// Column indices to keep from each input batch; None means emit all columns.
// Kept so execute() can apply the same projection the schema was built with.
+ projection : Option < Vec < usize > > ,
161
163
}
162
164
163
165
impl RecordBatchReaderExec {
164
- fn new ( reader : Box < dyn RecordBatchReader + Send > , limit : Option < usize > ) -> Self {
165
- let schema = reader. schema ( ) ;
166
// Builds the exec node; when a projection is supplied, the node's reported
// schema is narrowed to the projected columns so it matches the projected
// batches that execute() will emit.
+ fn new ( reader : Box < dyn RecordBatchReader + Send > , limit : Option < usize > , projection : Option < Vec < usize > > ) -> Self {
167
+ let full_schema = reader. schema ( ) ;
168
+ let schema: SchemaRef = if let Some ( indices) = projection. as_ref ( ) {
169
// NOTE(review): `field(*i)` panics on an out-of-range index — this assumes
// the planner only ever passes valid projection indices; TODO confirm.
+ let fields: Vec < _ > = indices. iter ( ) . map ( |i| full_schema. field ( * i) . clone ( ) ) . collect ( ) ;
170
// NOTE(review): Schema::new(fields) silently drops schema-level metadata;
// `full_schema.project(indices)` would preserve it — consider switching.
+ Arc :: new ( arrow_schema:: Schema :: new ( fields) )
171
+ } else {
172
+ full_schema. clone ( )
173
+ } ;
166
174
// Plan properties are derived from the (possibly projected) schema.
let properties = PlanProperties :: new (
167
175
EquivalenceProperties :: new ( schema. clone ( ) ) ,
168
176
Partitioning :: UnknownPartitioning ( 1 ) ,
@@ -175,6 +183,7 @@ impl RecordBatchReaderExec {
175
183
schema,
176
184
properties,
177
185
limit,
186
// Stored so execute() can apply the identical projection per batch.
+ projection,
178
187
}
179
188
}
180
189
}
@@ -186,6 +195,7 @@ impl Debug for RecordBatchReaderExec {
186
195
. field ( "schema" , & self . schema )
187
196
. field ( "properties" , & self . properties )
188
197
. field ( "limit" , & self . limit )
198
// Surface the projection in debug output so pushdown state is visible.
+ . field ( "projection" , & self . projection )
189
199
. finish ( )
190
200
}
191
201
}
@@ -240,16 +250,33 @@ impl ExecutionPlan for RecordBatchReaderExec {
240
250
// Dispatch on the optional row limit; both arms also apply the optional
// column projection to every batch before it is streamed out.
match self . limit {
241
251
Some ( limit) => {
242
252
// Create a row-limited iterator that properly handles row counting
243
- let iter = RowLimitedIterator :: new ( reader, limit) ;
253
+ let projection = self . projection . clone ( ) ;
254
+ let iter = RowLimitedIterator :: new ( reader, limit) . map ( move |res| match res {
255
+ Ok ( batch) => {
256
+ if let Some ( indices) = projection. as_ref ( ) {
257
// NOTE(review): unwrap() panics if project() fails (e.g. a bad index);
// prefer `.map_err(DataFusionError::from)?`-style propagation instead.
+ Ok ( batch. project ( indices) . unwrap ( ) )
258
+ } else {
259
+ Ok ( batch)
260
+ }
261
+ }
262
+ Err ( e) => Err ( e) ,
263
+ } ) ;
244
264
let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
245
265
let record_batch_stream =
246
266
// self.schema is already the projected schema, so it matches the batches.
RecordBatchStreamAdapter :: new ( self . schema . clone ( ) , stream) ;
247
267
Ok ( Box :: pin ( record_batch_stream) )
248
268
}
249
269
None => {
250
270
// No limit, just convert the reader directly to a stream
251
- let iter = reader. map ( |item| match item {
252
- Ok ( batch) => Ok ( batch) ,
271
+ let projection = self . projection . clone ( ) ;
272
+ let iter = reader. map ( move |item| match item {
273
+ Ok ( batch) => {
274
+ if let Some ( indices) = projection. as_ref ( ) {
275
// NOTE(review): same panic-on-failure unwrap() as the limited arm, and the
// projection closure is duplicated — extract a shared Result-returning helper.
+ Ok ( batch. project ( indices) . unwrap ( ) )
276
+ } else {
277
+ Ok ( batch)
278
+ }
279
+ }
253
280
Err ( e) => Err ( DataFusionError :: from ( e) ) ,
254
281
} ) ;
255
282
let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
@@ -413,3 +440,56 @@ mod test {
413
440
None
414
441
}
415
442
}
443
+
444
+
445
+
446
+
447
+
448
+
449
+
450
+
451
+
452
+
453
+
454
+
455
+ #[ tokio:: test]
456
+ async fn test_projection_pushdown ( ) {
457
+ use arrow_array:: { RecordBatch , RecordBatchIterator } ;
458
+ use arrow_schema:: { DataType , Field , Schema } ;
459
+ use datafusion:: prelude:: SessionContext ;
460
+ use datafusion:: prelude:: col;
461
+ let ctx = SessionContext :: new ( ) ;
462
+
463
+ // Create a two-column batch
464
+ let schema = Schema :: new ( vec ! [
465
+ Field :: new( "a" , DataType :: Int32 , false ) ,
466
+ Field :: new( "b" , DataType :: Int32 , false ) ,
467
+ ] ) ;
468
+ let batch = RecordBatch :: try_new (
469
+ Arc :: new ( schema. clone ( ) ) ,
470
+ vec ! [
471
+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 1 , 2 , 3 ] ) ) ,
472
+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 10 , 20 , 30 ] ) ) ,
473
+ ] ,
474
+ )
475
+ . unwrap ( ) ;
476
+
477
+ // Wrap in a RecordBatchReaderProvider
478
+ let reader = RecordBatchIterator :: new ( vec ! [ batch. clone( ) ] . into_iter ( ) . map ( Ok ) , Arc :: new ( schema) ) ;
479
+ let provider = Arc :: new ( RecordBatchReaderProvider :: new ( Box :: new ( reader) ) ) ;
480
+
481
+ // Read table then select only column b (this should push projection into scan)
482
+ let df = ctx. read_table ( provider) . unwrap ( ) ;
483
+ let df_b = df. select ( vec ! [ col( "b" ) ] ) . unwrap ( ) ;
484
+ let results = df_b. collect ( ) . await . unwrap ( ) ;
485
+ assert_eq ! ( results. len( ) , 1 ) ;
486
+ let out_batch = & results[ 0 ] ;
487
+ assert_eq ! ( out_batch. num_columns( ) , 1 ) ;
488
+ assert_eq ! ( out_batch. schema( ) . field( 0 ) . name( ) , "b" ) ;
489
+ let values = out_batch
490
+ . column ( 0 )
491
+ . as_any ( )
492
+ . downcast_ref :: < arrow_array:: Int32Array > ( )
493
+ . unwrap ( ) ;
494
+ assert_eq ! ( values. values( ) , & [ 10 , 20 , 30 ] ) ;
495
+ }
0 commit comments