Skip to content

Commit f79d821

Browse files
authored
fix: Support projection pushdown for RecordBatchReader provider (fixes #186) (#197)
This fixes #186. The provider ignored the projection indices passed to TableProvider::scan(), so the physical plan schema ([a,b]) did not match the pushed-down logical projection ([b]). This PR implements projection pushdown for RecordBatchReaderProvider and adds a regression test.
1 parent 95c156d commit f79d821

File tree

2 files changed

+104
-11
lines changed

2 files changed

+104
-11
lines changed

python/sedonadb/tests/test_dataframe.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,24 @@ def test_dataframe_to_parquet(con):
342342
)
343343

344344

345+
def test_record_batch_reader_projection(con):
346+
def batches():
347+
for _ in range(3):
348+
yield pa.record_batch({"a": ["a", "b", "c"], "b": [1, 2, 3]})
349+
350+
reader = pa.RecordBatchReader.from_batches(next(batches()).schema, batches())
351+
df = con.create_data_frame(reader)
352+
df.to_view("temp_rbr_proj", overwrite=True)
353+
try:
354+
# Query the view with projection (only select column b)
355+
proj_df = con.sql("SELECT b FROM temp_rbr_proj")
356+
tbl = proj_df.to_arrow_table()
357+
assert tbl.column_names == ["b"]
358+
assert tbl.to_pydict()["b"] == [1, 2, 3] * 3
359+
finally:
360+
con.drop_view("temp_rbr_proj")
361+
362+
345363
def test_show(con, capsys):
346364
con.sql("SELECT 1 as one").show()
347365
expected = """

rust/sedona/src/record_batch_reader_provider.rs

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,16 @@ impl TableProvider for RecordBatchReaderProvider {
8484
async fn scan(
8585
&self,
8686
_state: &dyn Session,
87-
_projection: Option<&Vec<usize>>,
87+
projection: Option<&Vec<usize>>,
8888
_filters: &[Expr],
8989
limit: Option<usize>,
9090
) -> Result<Arc<dyn ExecutionPlan>> {
9191
let mut reader_guard = self.reader.lock();
9292
if let Some(reader) = reader_guard.take() {
93-
Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
93+
let projection = projection.cloned();
94+
Ok(Arc::new(RecordBatchReaderExec::try_new(
95+
reader, limit, projection,
96+
)?))
9497
} else {
9598
sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
9699
}
@@ -158,24 +161,39 @@ struct RecordBatchReaderExec {
158161
schema: SchemaRef,
159162
properties: PlanProperties,
160163
limit: Option<usize>,
164+
projection: Option<Vec<usize>>,
161165
}
162166

163167
impl RecordBatchReaderExec {
164-
fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
165-
let schema = reader.schema();
168+
fn try_new(
169+
reader: Box<dyn RecordBatchReader + Send>,
170+
limit: Option<usize>,
171+
projection: Option<Vec<usize>>,
172+
) -> Result<Self> {
173+
let full_schema = reader.schema();
174+
let schema: SchemaRef = if let Some(indices) = projection.as_ref() {
175+
SchemaRef::new(
176+
full_schema
177+
.project(indices)
178+
.map_err(DataFusionError::from)?,
179+
)
180+
} else {
181+
full_schema.clone()
182+
};
166183
let properties = PlanProperties::new(
167184
EquivalenceProperties::new(schema.clone()),
168185
Partitioning::UnknownPartitioning(1),
169186
EmissionType::Incremental,
170187
Boundedness::Bounded,
171188
);
172189

173-
Self {
190+
Ok(Self {
174191
reader: Mutex::new(Some(reader)),
175192
schema,
176193
properties,
177194
limit,
178-
}
195+
projection,
196+
})
179197
}
180198
}
181199

@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
186204
.field("schema", &self.schema)
187205
.field("properties", &self.properties)
188206
.field("limit", &self.limit)
207+
.field("projection", &self.projection)
189208
.finish()
190209
}
191210
}
@@ -240,17 +259,34 @@ impl ExecutionPlan for RecordBatchReaderExec {
240259
match self.limit {
241260
Some(limit) => {
242261
// Create a row-limited iterator that properly handles row counting
243-
let iter = RowLimitedIterator::new(reader, limit);
262+
let projection = self.projection.clone();
263+
let iter = RowLimitedIterator::new(reader, limit).map(move |res| match res {
264+
Ok(batch) => {
265+
if let Some(indices) = projection.as_ref() {
266+
batch.project(indices).map_err(|e| e.into())
267+
} else {
268+
Ok(batch)
269+
}
270+
}
271+
Err(e) => Err(e),
272+
});
244273
let stream = Box::pin(futures::stream::iter(iter));
245274
let record_batch_stream =
246275
RecordBatchStreamAdapter::new(self.schema.clone(), stream);
247276
Ok(Box::pin(record_batch_stream))
248277
}
249278
None => {
250279
// No limit, just convert the reader directly to a stream
251-
let iter = reader.map(|item| match item {
252-
Ok(batch) => Ok(batch),
253-
Err(e) => Err(DataFusionError::from(e)),
280+
let projection = self.projection.clone();
281+
let iter = reader.map(move |item| match item {
282+
Ok(batch) => {
283+
if let Some(indices) = projection.as_ref() {
284+
batch.project(indices).map_err(|e| e.into())
285+
} else {
286+
Ok(batch)
287+
}
288+
}
289+
Err(e) => Err(e.into()),
254290
});
255291
let stream = Box::pin(futures::stream::iter(iter));
256292
let record_batch_stream =
@@ -266,7 +302,7 @@ mod test {
266302

267303
use arrow_array::{RecordBatch, RecordBatchIterator};
268304
use arrow_schema::{DataType, Field, Schema};
269-
use datafusion::prelude::{DataFrame, SessionContext};
305+
use datafusion::prelude::{col, DataFrame, SessionContext};
270306
use rstest::rstest;
271307
use sedona_schema::datatypes::WKB_GEOMETRY;
272308
use sedona_testing::create::create_array_storage;
@@ -383,6 +419,45 @@ mod test {
383419
}
384420
}
385421

422+
#[tokio::test]
423+
async fn test_projection_pushdown() {
424+
let ctx = SessionContext::new();
425+
426+
// Create a two-column batch
427+
let schema = Schema::new(vec![
428+
Field::new("a", DataType::Int32, false),
429+
Field::new("b", DataType::Int32, false),
430+
]);
431+
let batch = RecordBatch::try_new(
432+
Arc::new(schema.clone()),
433+
vec![
434+
Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3])),
435+
Arc::new(arrow_array::Int32Array::from(vec![10, 20, 30])),
436+
],
437+
)
438+
.unwrap();
439+
440+
// Wrap in a RecordBatchReaderProvider
441+
let reader =
442+
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), Arc::new(schema));
443+
let provider = Arc::new(RecordBatchReaderProvider::new(Box::new(reader)));
444+
445+
// Read table then select only column b (this should push projection into scan)
446+
let df = ctx.read_table(provider).unwrap();
447+
let df_b = df.select(vec![col("b")]).unwrap();
448+
let results = df_b.collect().await.unwrap();
449+
assert_eq!(results.len(), 1);
450+
let out_batch = &results[0];
451+
assert_eq!(out_batch.num_columns(), 1);
452+
assert_eq!(out_batch.schema().field(0).name(), "b");
453+
let values = out_batch
454+
.column(0)
455+
.as_any()
456+
.downcast_ref::<arrow_array::Int32Array>()
457+
.unwrap();
458+
assert_eq!(values.values(), &[10, 20, 30]);
459+
}
460+
386461
fn read_test_table_with_limit(
387462
ctx: &SessionContext,
388463
batch_sizes: Vec<usize>,

0 commit comments

Comments (0)