
Commit 6df6ac0

feat: support projection pushdown for RecordBatchReader provider and add regression test (fixes #186)
1 parent abcd140
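For context, DataFusion communicates the columns a query actually needs through the projection argument of TableProvider::scan: a list of indices into the provider's schema. Before this commit the argument was bound as _projection and ignored, so every scan produced all columns. A minimal self-contained sketch of what honoring those indices means at the Arrow level (the schema and values here are illustrative, not from the commit):

use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

fn main() {
    // Schema [a, b, c]; a query such as `SELECT c, a` typically reaches
    // TableProvider::scan as the index list Some(&vec![2, 0]).
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
        Field::new("c", DataType::Int32, false),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![1])),
            Arc::new(Int32Array::from(vec![2])),
            Arc::new(Int32Array::from(vec![3])),
        ],
    )
    .unwrap();

    // RecordBatch::project drops/reorders columns by index, which is
    // exactly what a scan must do with the pushed-down projection.
    let projected = batch.project(&[2, 0]).unwrap();
    assert_eq!(projected.num_columns(), 2);
    assert_eq!(projected.schema().field(0).name(), "c");
}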

File tree

1 file changed

rust/sedona/src/record_batch_reader_provider.rs

Lines changed: 86 additions & 7 deletions
@@ -84,13 +84,16 @@ impl TableProvider for RecordBatchReaderProvider {
     async fn scan(
         &self,
         _state: &dyn Session,
-        _projection: Option<&Vec<usize>>,
+        projection: Option<&Vec<usize>>,
         _filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let mut reader_guard = self.reader.lock();
         if let Some(reader) = reader_guard.take() {
-            Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
+            let projection = projection.cloned();
+            Ok(Arc::new(RecordBatchReaderExec::new(
+                reader, limit, projection,
+            )))
         } else {
             sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
         }
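A small detail in this hunk: scan only borrows the projection as Option<&Vec<usize>>, while the returned ExecutionPlan outlives the call, so Option::cloned converts it into an owned Option<Vec<usize>> first. In isolation:

// Option::cloned turns Option<&Vec<usize>> into Option<Vec<usize>> by
// cloning the referenced vector, giving the plan owned indices.
let indices = vec![1, 0];
let borrowed: Option<&Vec<usize>> = Some(&indices);
let owned: Option<Vec<usize>> = borrowed.cloned();
assert_eq!(owned, Some(vec![1, 0]));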
@@ -158,11 +161,25 @@ struct RecordBatchReaderExec {
     schema: SchemaRef,
     properties: PlanProperties,
     limit: Option<usize>,
+    projection: Option<Vec<usize>>,
 }
 
 impl RecordBatchReaderExec {
-    fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
-        let schema = reader.schema();
+    fn new(
+        reader: Box<dyn RecordBatchReader + Send>,
+        limit: Option<usize>,
+        projection: Option<Vec<usize>>,
+    ) -> Self {
+        let full_schema = reader.schema();
+        let schema: SchemaRef = if let Some(indices) = projection.as_ref() {
+            let fields: Vec<_> = indices
+                .iter()
+                .map(|i| full_schema.field(*i).clone())
+                .collect();
+            Arc::new(arrow_schema::Schema::new(fields))
+        } else {
+            full_schema.clone()
+        };
         let properties = PlanProperties::new(
             EquivalenceProperties::new(schema.clone()),
             Partitioning::UnknownPartitioning(1),
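A side note on the schema computation above: arrow-schema also provides Schema::project, which performs the same index-based field selection but returns an Err on an out-of-range index rather than panicking inside field(*i). A sketch of an equivalent helper (the function itself is hypothetical, not part of the commit):

use std::sync::Arc;

use arrow_schema::{ArrowError, SchemaRef};

// Hypothetical helper: same result as the manual field-collection loop in
// the diff, but an out-of-range index surfaces as Err(ArrowError) instead
// of a panic.
fn projected_schema(full: &SchemaRef, indices: &[usize]) -> Result<SchemaRef, ArrowError> {
    Ok(Arc::new(full.project(indices)?))
}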
@@ -175,6 +192,7 @@ impl RecordBatchReaderExec {
             schema,
             properties,
             limit,
+            projection,
         }
     }
 }
@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
             .field("schema", &self.schema)
             .field("properties", &self.properties)
             .field("limit", &self.limit)
+            .field("projection", &self.projection)
             .finish()
     }
 }
@@ -240,16 +259,33 @@ impl ExecutionPlan for RecordBatchReaderExec {
         match self.limit {
             Some(limit) => {
                 // Create a row-limited iterator that properly handles row counting
-                let iter = RowLimitedIterator::new(reader, limit);
+                let projection = self.projection.clone();
+                let iter = RowLimitedIterator::new(reader, limit).map(move |res| match res {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            Ok(batch.project(indices).unwrap())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
+                    Err(e) => Err(e),
+                });
                 let stream = Box::pin(futures::stream::iter(iter));
                 let record_batch_stream =
                     RecordBatchStreamAdapter::new(self.schema.clone(), stream);
                 Ok(Box::pin(record_batch_stream))
             }
             None => {
                 // No limit, just convert the reader directly to a stream
-                let iter = reader.map(|item| match item {
-                    Ok(batch) => Ok(batch),
+                let projection = self.projection.clone();
+                let iter = reader.map(move |item| match item {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            Ok(batch.project(indices).unwrap())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
                     Err(e) => Err(DataFusionError::from(e)),
                 });
                 let stream = Box::pin(futures::stream::iter(iter));
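One caveat in both branches: RecordBatch::project returns a Result, so the .unwrap() here would panic at execution time if an index were ever out of range. DataFusion derives the indices from the table schema, so this shouldn't fire in practice, but a panic-free variant is cheap. A sketch under the same assumptions as the diff (the helper name is hypothetical; DataFusionError converts from ArrowError via its From impl, and in the umbrella crate the import path is datafusion::error::DataFusionError):

use arrow_array::RecordBatch;
use datafusion_common::DataFusionError;

// Hypothetical helper: applies an optional projection to one batch,
// converting a failed projection into a DataFusionError instead of
// panicking. Call sites would replace the unwrap-based match arms.
fn apply_projection(
    batch: RecordBatch,
    projection: Option<&Vec<usize>>,
) -> Result<RecordBatch, DataFusionError> {
    match projection {
        Some(indices) => batch.project(indices).map_err(DataFusionError::from),
        None => Ok(batch),
    }
}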
@@ -413,3 +449,46 @@ mod test {
         None
     }
 }
+
+#[tokio::test]
+async fn test_projection_pushdown() {
+    use arrow_array::{RecordBatch, RecordBatchIterator};
+    use arrow_schema::{DataType, Field, Schema};
+    use datafusion::prelude::col;
+    use datafusion::prelude::SessionContext;
+    let ctx = SessionContext::new();
+
+    // Create a two-column batch
+    let schema = Schema::new(vec![
+        Field::new("a", DataType::Int32, false),
+        Field::new("b", DataType::Int32, false),
+    ]);
+    let batch = RecordBatch::try_new(
+        Arc::new(schema.clone()),
+        vec![
+            Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3])),
+            Arc::new(arrow_array::Int32Array::from(vec![10, 20, 30])),
+        ],
+    )
+    .unwrap();
+
+    // Wrap in a RecordBatchReaderProvider
+    let reader =
+        RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), Arc::new(schema));
+    let provider = Arc::new(RecordBatchReaderProvider::new(Box::new(reader)));
+
+    // Read table then select only column b (this should push projection into scan)
+    let df = ctx.read_table(provider).unwrap();
+    let df_b = df.select(vec![col("b")]).unwrap();
+    let results = df_b.collect().await.unwrap();
+    assert_eq!(results.len(), 1);
+    let out_batch = &results[0];
+    assert_eq!(out_batch.num_columns(), 1);
+    assert_eq!(out_batch.schema().field(0).name(), "b");
+    let values = out_batch
+        .column(0)
+        .as_any()
+        .downcast_ref::<arrow_array::Int32Array>()
+        .unwrap();
+    assert_eq!(values.values(), &[10, 20, 30]);
+}
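The new test drives pushdown through the DataFrame API; the same scan path is reached from SQL once the provider is registered as a table. A fragment sketching that route, assuming the same `provider` value and an async test context as above (the table name `t` is illustrative):

// Registering the provider under a name lets SQL hit the same pushdown path.
let ctx = SessionContext::new();
ctx.register_table("t", provider).unwrap();

// DataFusion plans `SELECT b FROM t` with projection [1], which the scan
// now forwards to RecordBatchReaderExec instead of reading both columns.
let results = ctx.sql("SELECT b FROM t").await.unwrap().collect().await.unwrap();
assert_eq!(results[0].num_columns(), 1);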
