@@ -84,13 +84,16 @@ impl TableProvider for RecordBatchReaderProvider {
84
84
async fn scan (
85
85
& self ,
86
86
_state : & dyn Session ,
87
- _projection : Option < & Vec < usize > > ,
87
+ projection : Option < & Vec < usize > > ,
88
88
_filters : & [ Expr ] ,
89
89
limit : Option < usize > ,
90
90
) -> Result < Arc < dyn ExecutionPlan > > {
91
91
let mut reader_guard = self . reader . lock ( ) ;
92
92
if let Some ( reader) = reader_guard. take ( ) {
93
- Ok ( Arc :: new ( RecordBatchReaderExec :: new ( reader, limit) ) )
93
+ let projection = projection. cloned ( ) ;
94
+ Ok ( Arc :: new ( RecordBatchReaderExec :: new (
95
+ reader, limit, projection,
96
+ ) ) )
94
97
} else {
95
98
sedona_internal_err ! ( "Can't scan RecordBatchReader provider more than once" )
96
99
}
@@ -158,11 +161,25 @@ struct RecordBatchReaderExec {
158
161
schema : SchemaRef ,
159
162
properties : PlanProperties ,
160
163
limit : Option < usize > ,
164
+ projection : Option < Vec < usize > > ,
161
165
}
162
166
163
167
impl RecordBatchReaderExec {
164
- fn new ( reader : Box < dyn RecordBatchReader + Send > , limit : Option < usize > ) -> Self {
165
- let schema = reader. schema ( ) ;
168
+ fn new (
169
+ reader : Box < dyn RecordBatchReader + Send > ,
170
+ limit : Option < usize > ,
171
+ projection : Option < Vec < usize > > ,
172
+ ) -> Self {
173
+ let full_schema = reader. schema ( ) ;
174
+ let schema: SchemaRef = if let Some ( indices) = projection. as_ref ( ) {
175
+ let fields: Vec < _ > = indices
176
+ . iter ( )
177
+ . map ( |i| full_schema. field ( * i) . clone ( ) )
178
+ . collect ( ) ;
179
+ Arc :: new ( arrow_schema:: Schema :: new ( fields) )
180
+ } else {
181
+ full_schema. clone ( )
182
+ } ;
166
183
let properties = PlanProperties :: new (
167
184
EquivalenceProperties :: new ( schema. clone ( ) ) ,
168
185
Partitioning :: UnknownPartitioning ( 1 ) ,
@@ -175,6 +192,7 @@ impl RecordBatchReaderExec {
175
192
schema,
176
193
properties,
177
194
limit,
195
+ projection,
178
196
}
179
197
}
180
198
}
@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
186
204
. field ( "schema" , & self . schema )
187
205
. field ( "properties" , & self . properties )
188
206
. field ( "limit" , & self . limit )
207
+ . field ( "projection" , & self . projection )
189
208
. finish ( )
190
209
}
191
210
}
@@ -240,16 +259,33 @@ impl ExecutionPlan for RecordBatchReaderExec {
240
259
match self . limit {
241
260
Some ( limit) => {
242
261
// Create a row-limited iterator that properly handles row counting
243
- let iter = RowLimitedIterator :: new ( reader, limit) ;
262
+ let projection = self . projection . clone ( ) ;
263
+ let iter = RowLimitedIterator :: new ( reader, limit) . map ( move |res| match res {
264
+ Ok ( batch) => {
265
+ if let Some ( indices) = projection. as_ref ( ) {
266
+ Ok ( batch. project ( indices) . unwrap ( ) )
267
+ } else {
268
+ Ok ( batch)
269
+ }
270
+ }
271
+ Err ( e) => Err ( e) ,
272
+ } ) ;
244
273
let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
245
274
let record_batch_stream =
246
275
RecordBatchStreamAdapter :: new ( self . schema . clone ( ) , stream) ;
247
276
Ok ( Box :: pin ( record_batch_stream) )
248
277
}
249
278
None => {
250
279
// No limit, just convert the reader directly to a stream
251
- let iter = reader. map ( |item| match item {
252
- Ok ( batch) => Ok ( batch) ,
280
+ let projection = self . projection . clone ( ) ;
281
+ let iter = reader. map ( move |item| match item {
282
+ Ok ( batch) => {
283
+ if let Some ( indices) = projection. as_ref ( ) {
284
+ Ok ( batch. project ( indices) . unwrap ( ) )
285
+ } else {
286
+ Ok ( batch)
287
+ }
288
+ }
253
289
Err ( e) => Err ( DataFusionError :: from ( e) ) ,
254
290
} ) ;
255
291
let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
@@ -413,3 +449,46 @@ mod test {
413
449
None
414
450
}
415
451
}
452
+
453
+ #[ tokio:: test]
454
+ async fn test_projection_pushdown ( ) {
455
+ use arrow_array:: { RecordBatch , RecordBatchIterator } ;
456
+ use arrow_schema:: { DataType , Field , Schema } ;
457
+ use datafusion:: prelude:: col;
458
+ use datafusion:: prelude:: SessionContext ;
459
+ let ctx = SessionContext :: new ( ) ;
460
+
461
+ // Create a two-column batch
462
+ let schema = Schema :: new ( vec ! [
463
+ Field :: new( "a" , DataType :: Int32 , false ) ,
464
+ Field :: new( "b" , DataType :: Int32 , false ) ,
465
+ ] ) ;
466
+ let batch = RecordBatch :: try_new (
467
+ Arc :: new ( schema. clone ( ) ) ,
468
+ vec ! [
469
+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 1 , 2 , 3 ] ) ) ,
470
+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 10 , 20 , 30 ] ) ) ,
471
+ ] ,
472
+ )
473
+ . unwrap ( ) ;
474
+
475
+ // Wrap in a RecordBatchReaderProvider
476
+ let reader =
477
+ RecordBatchIterator :: new ( vec ! [ batch. clone( ) ] . into_iter ( ) . map ( Ok ) , Arc :: new ( schema) ) ;
478
+ let provider = Arc :: new ( RecordBatchReaderProvider :: new ( Box :: new ( reader) ) ) ;
479
+
480
+ // Read table then select only column b (this should push projection into scan)
481
+ let df = ctx. read_table ( provider) . unwrap ( ) ;
482
+ let df_b = df. select ( vec ! [ col( "b" ) ] ) . unwrap ( ) ;
483
+ let results = df_b. collect ( ) . await . unwrap ( ) ;
484
+ assert_eq ! ( results. len( ) , 1 ) ;
485
+ let out_batch = & results[ 0 ] ;
486
+ assert_eq ! ( out_batch. num_columns( ) , 1 ) ;
487
+ assert_eq ! ( out_batch. schema( ) . field( 0 ) . name( ) , "b" ) ;
488
+ let values = out_batch
489
+ . column ( 0 )
490
+ . as_any ( )
491
+ . downcast_ref :: < arrow_array:: Int32Array > ( )
492
+ . unwrap ( ) ;
493
+ assert_eq ! ( values. values( ) , & [ 10 , 20 , 30 ] ) ;
494
+ }
0 commit comments