
Commit 28524f0

feat: support projection pushdown for RecordBatchReader provider and add regression test (fixes #186)
1 parent abcd140 commit 28524f0

File tree

1 file changed (+87 −7 lines)


rust/sedona/src/record_batch_reader_provider.rs

Lines changed: 87 additions & 7 deletions
@@ -84,13 +84,14 @@ impl TableProvider for RecordBatchReaderProvider {
     async fn scan(
         &self,
         _state: &dyn Session,
-        _projection: Option<&Vec<usize>>,
+        projection: Option<&Vec<usize>>,
         _filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let mut reader_guard = self.reader.lock();
         if let Some(reader) = reader_guard.take() {
-            Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
+            let projection = projection.cloned();
+            Ok(Arc::new(RecordBatchReaderExec::new(reader, limit, projection)))
         } else {
             sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
         }
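
Note on the hunk above: the provider previously discarded the planner's projection; it now clones it out of the borrowed argument so the exec node can own it. A standalone sketch of that `Option::cloned` conversion (illustration only, not part of the commit):

    fn main() {
        // DataFusion hands scan() an Option<&Vec<usize>> of column indices;
        // Option::cloned clones the referenced Vec, yielding Option<Vec<usize>>.
        let indices: Vec<usize> = vec![1];
        let borrowed: Option<&Vec<usize>> = Some(&indices);
        let owned: Option<Vec<usize>> = borrowed.cloned();
        assert_eq!(owned, Some(vec![1]));
    }
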
@@ -158,11 +159,18 @@ struct RecordBatchReaderExec {
     schema: SchemaRef,
     properties: PlanProperties,
     limit: Option<usize>,
+    projection: Option<Vec<usize>>,
 }
 
 impl RecordBatchReaderExec {
-    fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
-        let schema = reader.schema();
+    fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>, projection: Option<Vec<usize>>) -> Self {
+        let full_schema = reader.schema();
+        let schema: SchemaRef = if let Some(indices) = projection.as_ref() {
+            let fields: Vec<_> = indices.iter().map(|i| full_schema.field(*i).clone()).collect();
+            Arc::new(arrow_schema::Schema::new(fields))
+        } else {
+            full_schema.clone()
+        };
         let properties = PlanProperties::new(
             EquivalenceProperties::new(schema.clone()),
             Partitioning::UnknownPartitioning(1),
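
The constructor now derives the plan's reported schema from the projection, so `EquivalenceProperties` and the stream adapter both see the projected shape. The manual field-collection in the hunk is equivalent to arrow's built-in `Schema::project`; a standalone sketch of that alternative (assumes the same two-column schema as the regression test below; not code from the commit):

    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema, SchemaRef};

    fn main() {
        let full_schema = Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]);
        // Schema::project returns a new Schema containing only the indexed fields.
        let projected: SchemaRef = Arc::new(full_schema.project(&[1]).unwrap());
        assert_eq!(projected.fields().len(), 1);
        assert_eq!(projected.field(0).name(), "b");
    }
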
@@ -175,6 +183,7 @@ impl RecordBatchReaderExec {
             schema,
             properties,
             limit,
+            projection,
         }
     }
 }
@@ -186,6 +195,7 @@ impl Debug for RecordBatchReaderExec {
             .field("schema", &self.schema)
             .field("properties", &self.properties)
             .field("limit", &self.limit)
+            .field("projection", &self.projection)
             .finish()
     }
 }
@@ -240,16 +250,33 @@ impl ExecutionPlan for RecordBatchReaderExec {
         match self.limit {
             Some(limit) => {
                 // Create a row-limited iterator that properly handles row counting
-                let iter = RowLimitedIterator::new(reader, limit);
+                let projection = self.projection.clone();
+                let iter = RowLimitedIterator::new(reader, limit).map(move |res| match res {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            Ok(batch.project(indices).unwrap())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
+                    Err(e) => Err(e),
+                });
                 let stream = Box::pin(futures::stream::iter(iter));
                 let record_batch_stream =
                     RecordBatchStreamAdapter::new(self.schema.clone(), stream);
                 Ok(Box::pin(record_batch_stream))
             }
             None => {
                 // No limit, just convert the reader directly to a stream
-                let iter = reader.map(|item| match item {
-                    Ok(batch) => Ok(batch),
+                let projection = self.projection.clone();
+                let iter = reader.map(move |item| match item {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            Ok(batch.project(indices).unwrap())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
                     Err(e) => Err(DataFusionError::from(e)),
                 });
                 let stream = Box::pin(futures::stream::iter(iter));
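
Both the limited and unlimited paths now apply the projection to every batch via `RecordBatch::project`. A self-contained sketch of what that call does (illustration only; all names are local to the example):

    use std::sync::Arc;

    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};

    fn main() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let batch = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(Int32Array::from(vec![10, 20, 30])),
            ],
        )
        .unwrap();
        // project keeps only the columns at the given indices, in the given order.
        let projected = batch.project(&[1]).unwrap();
        assert_eq!(projected.num_columns(), 1);
        assert_eq!(projected.schema().field(0).name(), "b");
    }

The `unwrap()` in the hunk is safe in practice because the planner only pushes down indices that are valid for the reader's schema.
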
@@ -413,3 +440,56 @@ mod test {
         None
     }
 }
+
+#[tokio::test]
+async fn test_projection_pushdown() {
+    use arrow_array::{RecordBatch, RecordBatchIterator};
+    use arrow_schema::{DataType, Field, Schema};
+    use datafusion::prelude::SessionContext;
+    use datafusion::prelude::col;
+    let ctx = SessionContext::new();
+
+    // Create a two-column batch
+    let schema = Schema::new(vec![
+        Field::new("a", DataType::Int32, false),
+        Field::new("b", DataType::Int32, false),
+    ]);
+    let batch = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![
+            Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3])),
+            Arc::new(arrow_array::Int32Array::from(vec![10, 20, 30])),
+        ],
+    )
+    .unwrap();
+
+    // Wrap in a RecordBatchReaderProvider
+    let reader = RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), Arc::new(schema));
+    let provider = Arc::new(RecordBatchReaderProvider::new(Box::new(reader)));
+
+    // Read table then select only column b (this should push projection into scan)
+    let df = ctx.read_table(provider).unwrap();
+    let df_b = df.select(vec![col("b")]).unwrap();
+    let results = df_b.collect().await.unwrap();
+    assert_eq!(results.len(), 1);
+    let out_batch = &results[0];
+    assert_eq!(out_batch.num_columns(), 1);
+    assert_eq!(out_batch.schema().field(0).name(), "b");
+    let values = out_batch
+        .column(0)
+        .as_any()
+        .downcast_ref::<arrow_array::Int32Array>()
+        .unwrap();
+    assert_eq!(values.values(), &[10, 20, 30]);
+}

0 commit comments
