
Commit 0b3b35a

feat: support projection pushdown for RecordBatchReader provider and add regression test (fixes #186)

1 parent: abcd140
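Context for the change: DataFusion's planner passes the requested column indices into `TableProvider::scan` as `projection: Option<&Vec<usize>>`. Before this commit the RecordBatchReader provider ignored that argument and always returned every column; now it threads the indices through to the exec node, which applies them per batch with `RecordBatch::project`. A minimal, self-contained sketch of that arrow-rs primitive (hypothetical data mirroring the regression test, not part of the commit):

```rust
use std::sync::Arc;
use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Two-column batch, as in the regression test below.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(Int32Array::from(vec![10, 20, 30])),
        ],
    )?;
    // For `SELECT b`, the planner hands the scan the index list [1];
    // RecordBatch::project keeps exactly those columns, in that order.
    let projected = batch.project(&[1])?;
    assert_eq!(projected.num_columns(), 1);
    assert_eq!(projected.schema().field(0).name(), "b");
    Ok(())
}
```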

1 file changed (+86 −7 lines)

rust/sedona/src/record_batch_reader_provider.rs
```diff
@@ -84,13 +84,16 @@ impl TableProvider for RecordBatchReaderProvider {
     async fn scan(
         &self,
         _state: &dyn Session,
-        _projection: Option<&Vec<usize>>,
+        projection: Option<&Vec<usize>>,
         _filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let mut reader_guard = self.reader.lock();
         if let Some(reader) = reader_guard.take() {
-            Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
+            let projection = projection.cloned();
+            Ok(Arc::new(RecordBatchReaderExec::new(
+                reader, limit, projection,
+            )))
         } else {
             sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
         }
```
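`scan` only borrows the projection from the planner, while the exec node has to own it; `Option::cloned` bridges the two. A trivial sketch of the conversion (indices are hypothetical):

```rust
fn main() {
    // Option::cloned turns the planner's borrowed Option<&Vec<usize>> into
    // the owned Option<Vec<usize>> stored on RecordBatchReaderExec.
    let projection: Option<&Vec<usize>> = Some(&vec![1]);
    let owned: Option<Vec<usize>> = projection.cloned();
    assert_eq!(owned, Some(vec![1]));
}
```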
```diff
@@ -158,11 +161,25 @@ struct RecordBatchReaderExec {
     schema: SchemaRef,
     properties: PlanProperties,
     limit: Option<usize>,
+    projection: Option<Vec<usize>>,
 }
 
 impl RecordBatchReaderExec {
-    fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
-        let schema = reader.schema();
+    fn new(
+        reader: Box<dyn RecordBatchReader + Send>,
+        limit: Option<usize>,
+        projection: Option<Vec<usize>>,
+    ) -> Self {
+        let full_schema = reader.schema();
+        let schema: SchemaRef = if let Some(indices) = projection.as_ref() {
+            let fields: Vec<_> = indices
+                .iter()
+                .map(|i| full_schema.field(*i).clone())
+                .collect();
+            Arc::new(arrow_schema::Schema::new(fields))
+        } else {
+            full_schema.clone()
+        };
         let properties = PlanProperties::new(
             EquivalenceProperties::new(schema.clone()),
             Partitioning::UnknownPartitioning(1),
```
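The exec node must also advertise the projected schema, since both `EquivalenceProperties` and the `RecordBatchStreamAdapter` below are built from it; a mismatch between the declared schema and the projected batches would surface at runtime. The commit collects the projected fields by hand so that `new()` stays infallible; an equivalent sketch using arrow-schema's built-in `Schema::project` (function name is hypothetical, not the commit's code):

```rust
use std::sync::Arc;
use arrow_schema::SchemaRef;

// Equivalent schema computation via Schema::project: select the projected
// fields, or keep the full schema when no projection was pushed down.
fn projected_schema(full_schema: &SchemaRef, projection: Option<&Vec<usize>>) -> SchemaRef {
    match projection {
        Some(indices) => Arc::new(
            full_schema
                .project(indices)
                .expect("projection indices are validated by the planner"),
        ),
        None => full_schema.clone(),
    }
}
```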
```diff
@@ -175,6 +192,7 @@ impl RecordBatchReaderExec {
             schema,
             properties,
             limit,
+            projection,
         }
     }
 }
```
```diff
@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
             .field("schema", &self.schema)
             .field("properties", &self.properties)
             .field("limit", &self.limit)
+            .field("projection", &self.projection)
             .finish()
     }
 }
```
```diff
@@ -240,16 +259,33 @@ impl ExecutionPlan for RecordBatchReaderExec {
         match self.limit {
             Some(limit) => {
                 // Create a row-limited iterator that properly handles row counting
-                let iter = RowLimitedIterator::new(reader, limit);
+                let projection = self.projection.clone();
+                let iter = RowLimitedIterator::new(reader, limit).map(move |res| match res {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            Ok(batch.project(indices).unwrap())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
+                    Err(e) => Err(e),
+                });
                 let stream = Box::pin(futures::stream::iter(iter));
                 let record_batch_stream =
                     RecordBatchStreamAdapter::new(self.schema.clone(), stream);
                 Ok(Box::pin(record_batch_stream))
             }
             None => {
                 // No limit, just convert the reader directly to a stream
-                let iter = reader.map(|item| match item {
-                    Ok(batch) => Ok(batch),
+                let projection = self.projection.clone();
+                let iter = reader.map(move |item| match item {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            Ok(batch.project(indices).unwrap())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
                     Err(e) => Err(DataFusionError::from(e)),
                 });
                 let stream = Box::pin(futures::stream::iter(iter));
```
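In both branches the per-batch work is identical: project the batch with the stored indices, or pass it through untouched. The `unwrap` on `RecordBatch::project` relies on the planner only ever supplying in-range indices; a defensive variant that propagates the arrow error instead is sketched below (the helper name is hypothetical, not part of the commit):

```rust
use arrow_array::RecordBatch;
use datafusion::error::DataFusionError;

// Hypothetical helper: apply an optional projection to one batch, converting
// any ArrowError into a DataFusionError instead of panicking.
fn project_batch(
    batch: RecordBatch,
    projection: Option<&Vec<usize>>,
) -> Result<RecordBatch, DataFusionError> {
    match projection {
        Some(indices) => batch.project(indices).map_err(DataFusionError::from),
        None => Ok(batch),
    }
}
```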
```diff
@@ -383,6 +419,49 @@ mod test {
         }
     }
 
+    #[tokio::test]
+    async fn test_projection_pushdown() {
+        use arrow_array::{RecordBatch, RecordBatchIterator};
+        use arrow_schema::{DataType, Field, Schema};
+        use datafusion::prelude::col;
+        use datafusion::prelude::SessionContext;
+        let ctx = SessionContext::new();
+
+        // Create a two-column batch
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![
+                Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3])),
+                Arc::new(arrow_array::Int32Array::from(vec![10, 20, 30])),
+            ],
+        )
+        .unwrap();
+
+        // Wrap in a RecordBatchReaderProvider
+        let reader =
+            RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), Arc::new(schema));
+        let provider = Arc::new(RecordBatchReaderProvider::new(Box::new(reader)));
+
+        // Read table then select only column b (this should push projection into scan)
+        let df = ctx.read_table(provider).unwrap();
+        let df_b = df.select(vec![col("b")]).unwrap();
+        let results = df_b.collect().await.unwrap();
+        assert_eq!(results.len(), 1);
+        let out_batch = &results[0];
+        assert_eq!(out_batch.num_columns(), 1);
+        assert_eq!(out_batch.schema().field(0).name(), "b");
+        let values = out_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<arrow_array::Int32Array>()
+            .unwrap();
+        assert_eq!(values.values(), &[10, 20, 30]);
+    }
+
     fn read_test_table_with_limit(
         ctx: &SessionContext,
         batch_sizes: Vec<usize>,
```
