@@ -84,13 +84,16 @@ impl TableProvider for RecordBatchReaderProvider {
84
84
async fn scan (
85
85
& self ,
86
86
_state : & dyn Session ,
87
- _projection : Option < & Vec < usize > > ,
87
+ projection : Option < & Vec < usize > > ,
88
88
_filters : & [ Expr ] ,
89
89
limit : Option < usize > ,
90
90
) -> Result < Arc < dyn ExecutionPlan > > {
91
91
let mut reader_guard = self . reader . lock ( ) ;
92
92
if let Some ( reader) = reader_guard. take ( ) {
93
- Ok ( Arc :: new ( RecordBatchReaderExec :: new ( reader, limit) ) )
93
+ let projection = projection. cloned ( ) ;
94
+ Ok ( Arc :: new ( RecordBatchReaderExec :: try_new (
95
+ reader, limit, projection,
96
+ ) ?) )
94
97
} else {
95
98
sedona_internal_err ! ( "Can't scan RecordBatchReader provider more than once" )
96
99
}
@@ -158,24 +161,39 @@ struct RecordBatchReaderExec {
158
161
schema : SchemaRef ,
159
162
properties : PlanProperties ,
160
163
limit : Option < usize > ,
164
+ projection : Option < Vec < usize > > ,
161
165
}
162
166
163
167
impl RecordBatchReaderExec {
164
- fn new ( reader : Box < dyn RecordBatchReader + Send > , limit : Option < usize > ) -> Self {
165
- let schema = reader. schema ( ) ;
168
+ fn try_new (
169
+ reader : Box < dyn RecordBatchReader + Send > ,
170
+ limit : Option < usize > ,
171
+ projection : Option < Vec < usize > > ,
172
+ ) -> Result < Self > {
173
+ let full_schema = reader. schema ( ) ;
174
+ let schema: SchemaRef = if let Some ( indices) = projection. as_ref ( ) {
175
+ SchemaRef :: new (
176
+ full_schema
177
+ . project ( indices)
178
+ . map_err ( DataFusionError :: from) ?,
179
+ )
180
+ } else {
181
+ full_schema. clone ( )
182
+ } ;
166
183
let properties = PlanProperties :: new (
167
184
EquivalenceProperties :: new ( schema. clone ( ) ) ,
168
185
Partitioning :: UnknownPartitioning ( 1 ) ,
169
186
EmissionType :: Incremental ,
170
187
Boundedness :: Bounded ,
171
188
) ;
172
189
173
- Self {
190
+ Ok ( Self {
174
191
reader : Mutex :: new ( Some ( reader) ) ,
175
192
schema,
176
193
properties,
177
194
limit,
178
- }
195
+ projection,
196
+ } )
179
197
}
180
198
}
181
199
@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
186
204
. field ( "schema" , & self . schema )
187
205
. field ( "properties" , & self . properties )
188
206
. field ( "limit" , & self . limit )
207
+ . field ( "projection" , & self . projection )
189
208
. finish ( )
190
209
}
191
210
}
@@ -240,17 +259,34 @@ impl ExecutionPlan for RecordBatchReaderExec {
240
259
match self . limit {
241
260
Some ( limit) => {
242
261
// Create a row-limited iterator that properly handles row counting
243
- let iter = RowLimitedIterator :: new ( reader, limit) ;
262
+ let projection = self . projection . clone ( ) ;
263
+ let iter = RowLimitedIterator :: new ( reader, limit) . map ( move |res| match res {
264
+ Ok ( batch) => {
265
+ if let Some ( indices) = projection. as_ref ( ) {
266
+ batch. project ( indices) . map_err ( |e| e. into ( ) )
267
+ } else {
268
+ Ok ( batch)
269
+ }
270
+ }
271
+ Err ( e) => Err ( e) ,
272
+ } ) ;
244
273
let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
245
274
let record_batch_stream =
246
275
RecordBatchStreamAdapter :: new ( self . schema . clone ( ) , stream) ;
247
276
Ok ( Box :: pin ( record_batch_stream) )
248
277
}
249
278
None => {
250
279
// No limit, just convert the reader directly to a stream
251
- let iter = reader. map ( |item| match item {
252
- Ok ( batch) => Ok ( batch) ,
253
- Err ( e) => Err ( DataFusionError :: from ( e) ) ,
280
+ let projection = self . projection . clone ( ) ;
281
+ let iter = reader. map ( move |item| match item {
282
+ Ok ( batch) => {
283
+ if let Some ( indices) = projection. as_ref ( ) {
284
+ batch. project ( indices) . map_err ( |e| e. into ( ) )
285
+ } else {
286
+ Ok ( batch)
287
+ }
288
+ }
289
+ Err ( e) => Err ( e. into ( ) ) ,
254
290
} ) ;
255
291
let stream = Box :: pin ( futures:: stream:: iter ( iter) ) ;
256
292
let record_batch_stream =
@@ -266,7 +302,7 @@ mod test {
266
302
267
303
use arrow_array:: { RecordBatch , RecordBatchIterator } ;
268
304
use arrow_schema:: { DataType , Field , Schema } ;
269
- use datafusion:: prelude:: { DataFrame , SessionContext } ;
305
+ use datafusion:: prelude:: { col , DataFrame , SessionContext } ;
270
306
use rstest:: rstest;
271
307
use sedona_schema:: datatypes:: WKB_GEOMETRY ;
272
308
use sedona_testing:: create:: create_array_storage;
@@ -383,6 +419,45 @@ mod test {
383
419
}
384
420
}
385
421
422
+ #[ tokio:: test]
423
+ async fn test_projection_pushdown ( ) {
424
+ let ctx = SessionContext :: new ( ) ;
425
+
426
+ // Create a two-column batch
427
+ let schema = Schema :: new ( vec ! [
428
+ Field :: new( "a" , DataType :: Int32 , false ) ,
429
+ Field :: new( "b" , DataType :: Int32 , false ) ,
430
+ ] ) ;
431
+ let batch = RecordBatch :: try_new (
432
+ Arc :: new ( schema. clone ( ) ) ,
433
+ vec ! [
434
+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 1 , 2 , 3 ] ) ) ,
435
+ Arc :: new( arrow_array:: Int32Array :: from( vec![ 10 , 20 , 30 ] ) ) ,
436
+ ] ,
437
+ )
438
+ . unwrap ( ) ;
439
+
440
+ // Wrap in a RecordBatchReaderProvider
441
+ let reader =
442
+ RecordBatchIterator :: new ( vec ! [ batch. clone( ) ] . into_iter ( ) . map ( Ok ) , Arc :: new ( schema) ) ;
443
+ let provider = Arc :: new ( RecordBatchReaderProvider :: new ( Box :: new ( reader) ) ) ;
444
+
445
+ // Read table then select only column b (this should push projection into scan)
446
+ let df = ctx. read_table ( provider) . unwrap ( ) ;
447
+ let df_b = df. select ( vec ! [ col( "b" ) ] ) . unwrap ( ) ;
448
+ let results = df_b. collect ( ) . await . unwrap ( ) ;
449
+ assert_eq ! ( results. len( ) , 1 ) ;
450
+ let out_batch = & results[ 0 ] ;
451
+ assert_eq ! ( out_batch. num_columns( ) , 1 ) ;
452
+ assert_eq ! ( out_batch. schema( ) . field( 0 ) . name( ) , "b" ) ;
453
+ let values = out_batch
454
+ . column ( 0 )
455
+ . as_any ( )
456
+ . downcast_ref :: < arrow_array:: Int32Array > ( )
457
+ . unwrap ( ) ;
458
+ assert_eq ! ( values. values( ) , & [ 10 , 20 , 30 ] ) ;
459
+ }
460
+
386
461
fn read_test_table_with_limit (
387
462
ctx : & SessionContext ,
388
463
batch_sizes : Vec < usize > ,
0 commit comments