@@ -84,13 +84,16 @@ impl TableProvider for RecordBatchReaderProvider {
84
84
async fn scan (
85
85
& self ,
86
86
_state : & dyn Session ,
87
- _projection : Option < & Vec < usize > > ,
87
+ projection : Option < & Vec < usize > > ,
88
88
_filters : & [ Expr ] ,
89
89
limit : Option < usize > ,
90
90
) -> Result < Arc < dyn ExecutionPlan > > {
91
91
let mut reader_guard = self . reader . lock ( ) ;
92
92
if let Some ( reader) = reader_guard. take ( ) {
93
- Ok ( Arc :: new ( RecordBatchReaderExec :: new ( reader, limit) ) )
93
+ let projection = projection. cloned ( ) ;
94
+ Ok ( Arc :: new ( RecordBatchReaderExec :: new (
95
+ reader, limit, projection,
96
+ ) ) )
94
97
} else {
95
98
sedona_internal_err ! ( "Can't scan RecordBatchReader provider more than once" )
96
99
}
@@ -158,11 +161,25 @@ struct RecordBatchReaderExec {
158
161
schema : SchemaRef ,
159
162
properties : PlanProperties ,
160
163
limit : Option < usize > ,
164
+ projection : Option < Vec < usize > > ,
161
165
}
162
166
163
167
impl RecordBatchReaderExec {
164
- fn new ( reader : Box < dyn RecordBatchReader + Send > , limit : Option < usize > ) -> Self {
165
- let schema = reader. schema ( ) ;
168
+ fn new (
169
+ reader : Box < dyn RecordBatchReader + Send > ,
170
+ limit : Option < usize > ,
171
+ projection : Option < Vec < usize > > ,
172
+ ) -> Self {
173
+ let full_schema = reader. schema ( ) ;
174
+ let schema: SchemaRef = if let Some ( indices) = projection. as_ref ( ) {
175
+ let fields: Vec < _ > = indices
176
+ . iter ( )
177
+ . map ( |i| full_schema. field ( * i) . clone ( ) )
178
+ . collect ( ) ;
179
+ Arc :: new ( arrow_schema:: Schema :: new ( fields) )
180
+ } else {
181
+ full_schema. clone ( )
182
+ } ;
166
183
let properties = PlanProperties :: new (
167
184
EquivalenceProperties :: new ( schema. clone ( ) ) ,
168
185
Partitioning :: UnknownPartitioning ( 1 ) ,
@@ -175,6 +192,7 @@ impl RecordBatchReaderExec {
175
192
schema,
176
193
properties,
177
194
limit,
195
+ projection,
178
196
}
179
197
}
180
198
}
@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
186
204
. field ( "schema" , & self . schema )
187
205
. field ( "properties" , & self . properties )
188
206
. field ( "limit" , & self . limit )
207
+ . field ( "projection" , & self . projection )
189
208
. finish ( )
190
209
}
191
210
}
@@ -240,16 +259,33 @@ impl ExecutionPlan for RecordBatchReaderExec {
240
259
match self . limit {
241
260
Some ( limit) => {
242
261
// Create a row-limited iterator that properly handles row counting
243
- let iter = RowLimitedIterator :: new ( reader, limit) ;
262
+ let projection = self . projection . clone ( ) ;
263
+ let iter = RowLimitedIterator :: new ( reader, limit) . map ( move |res| match res {
264
+ Ok ( batch) => {
265
+ if let Some ( indices) = projection. as_ref ( ) {
266
+ Ok ( batch. project ( indices) . unwrap ( ) )
267
+ } else {
268
+ Ok ( batch)
269
+ }
270
+ }
271
+ Err ( e) => Err ( e) ,
272
+ } ) ;
244
273
let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
245
274
let record_batch_stream =
246
275
RecordBatchStreamAdapter :: new ( self . schema . clone ( ) , stream) ;
247
276
Ok ( Box :: pin ( record_batch_stream) )
248
277
}
249
278
None => {
250
279
// No limit, just convert the reader directly to a stream
251
- let iter = reader. map ( |item| match item {
252
- Ok ( batch) => Ok ( batch) ,
280
+ let projection = self . projection . clone ( ) ;
281
+ let iter = reader. map ( move |item| match item {
282
+ Ok ( batch) => {
283
+ if let Some ( indices) = projection. as_ref ( ) {
284
+ Ok ( batch. project ( indices) . unwrap ( ) )
285
+ } else {
286
+ Ok ( batch)
287
+ }
288
+ }
253
289
Err ( e) => Err ( DataFusionError :: from ( e) ) ,
254
290
} ) ;
255
291
let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
@@ -383,6 +419,49 @@ mod test {
383
419
}
384
420
}
385
421
422
+ #[ tokio:: test]
423
+ async fn test_projection_pushdown ( ) {
424
+ use arrow_array:: { RecordBatch , RecordBatchIterator } ;
425
+ use arrow_schema:: { DataType , Field , Schema } ;
426
+ use datafusion:: prelude:: col;
427
+ use datafusion:: prelude:: SessionContext ;
428
+ let ctx = SessionContext :: new ( ) ;
429
+
430
+ // Create a two-column batch
431
+ let schema = Schema :: new ( vec ! [
432
+ Field :: new( "a" , DataType :: Int32 , false ) ,
433
+ Field :: new( "b" , DataType :: Int32 , false ) ,
434
+ ] ) ;
435
+ let batch = RecordBatch :: try_new (
436
+ Arc :: new ( schema. clone ( ) ) ,
437
+ vec ! [
438
+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 1 , 2 , 3 ] ) ) ,
439
+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 10 , 20 , 30 ] ) ) ,
440
+ ] ,
441
+ )
442
+ . unwrap ( ) ;
443
+
444
+ // Wrap in a RecordBatchReaderProvider
445
+ let reader =
446
+ RecordBatchIterator :: new ( vec ! [ batch. clone( ) ] . into_iter ( ) . map ( Ok ) , Arc :: new ( schema) ) ;
447
+ let provider = Arc :: new ( RecordBatchReaderProvider :: new ( Box :: new ( reader) ) ) ;
448
+
449
+ // Read table then select only column b (this should push projection into scan)
450
+ let df = ctx. read_table ( provider) . unwrap ( ) ;
451
+ let df_b = df. select ( vec ! [ col( "b" ) ] ) . unwrap ( ) ;
452
+ let results = df_b. collect ( ) . await . unwrap ( ) ;
453
+ assert_eq ! ( results. len( ) , 1 ) ;
454
+ let out_batch = & results[ 0 ] ;
455
+ assert_eq ! ( out_batch. num_columns( ) , 1 ) ;
456
+ assert_eq ! ( out_batch. schema( ) . field( 0 ) . name( ) , "b" ) ;
457
+ let values = out_batch
458
+ . column ( 0 )
459
+ . as_any ( )
460
+ . downcast_ref :: < arrow_array:: Int32Array > ( )
461
+ . unwrap ( ) ;
462
+ assert_eq ! ( values. values( ) , & [ 10 , 20 , 30 ] ) ;
463
+ }
464
+
386
465
fn read_test_table_with_limit (
387
466
ctx : & SessionContext ,
388
467
batch_sizes : Vec < usize > ,
0 commit comments