diff --git a/python/sedonadb/tests/test_dataframe.py b/python/sedonadb/tests/test_dataframe.py
index b6096357..60ac9953 100644
--- a/python/sedonadb/tests/test_dataframe.py
+++ b/python/sedonadb/tests/test_dataframe.py
@@ -342,6 +342,24 @@ def test_dataframe_to_parquet(con):
     )
 
 
+def test_record_batch_reader_projection(con):
+    def batches():
+        for _ in range(3):
+            yield pa.record_batch({"a": ["a", "b", "c"], "b": [1, 2, 3]})
+
+    reader = pa.RecordBatchReader.from_batches(next(batches()).schema, batches())
+    df = con.create_data_frame(reader)
+    df.to_view("temp_rbr_proj", overwrite=True)
+    try:
+        # Query the view with projection (only select column b)
+        proj_df = con.sql("SELECT b FROM temp_rbr_proj")
+        tbl = proj_df.to_arrow_table()
+        assert tbl.column_names == ["b"]
+        assert tbl.to_pydict()["b"] == [1, 2, 3] * 3
+    finally:
+        con.drop_view("temp_rbr_proj")
+
+
 def test_show(con, capsys):
     con.sql("SELECT 1 as one").show()
     expected = """
diff --git a/rust/sedona/src/record_batch_reader_provider.rs b/rust/sedona/src/record_batch_reader_provider.rs
index 1832e93e..e197f89d 100644
--- a/rust/sedona/src/record_batch_reader_provider.rs
+++ b/rust/sedona/src/record_batch_reader_provider.rs
@@ -84,13 +84,16 @@
     async fn scan(
         &self,
         _state: &dyn Session,
-        _projection: Option<&Vec<usize>>,
+        projection: Option<&Vec<usize>>,
         _filters: &[Expr],
         limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let mut reader_guard = self.reader.lock();
         if let Some(reader) = reader_guard.take() {
-            Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
+            let projection = projection.cloned();
+            Ok(Arc::new(RecordBatchReaderExec::try_new(
+                reader, limit, projection,
+            )?))
         } else {
             sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
         }
@@ -158,11 +161,25 @@ struct RecordBatchReaderExec {
     schema: SchemaRef,
     properties: PlanProperties,
     limit: Option<usize>,
+    projection: Option<Vec<usize>>,
 }
 
 impl RecordBatchReaderExec {
-    fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
-        let schema = reader.schema();
+    fn try_new(
+        reader: Box<dyn RecordBatchReader + Send>,
+        limit: Option<usize>,
+        projection: Option<Vec<usize>>,
+    ) -> Result<Self> {
+        let full_schema = reader.schema();
+        let schema: SchemaRef = if let Some(indices) = projection.as_ref() {
+            SchemaRef::new(
+                full_schema
+                    .project(indices)
+                    .map_err(DataFusionError::from)?,
+            )
+        } else {
+            full_schema.clone()
+        };
         let properties = PlanProperties::new(
             EquivalenceProperties::new(schema.clone()),
             Partitioning::UnknownPartitioning(1),
@@ -170,12 +187,13 @@
             Boundedness::Bounded,
         );
 
-        Self {
+        Ok(Self {
             reader: Mutex::new(Some(reader)),
             schema,
             properties,
             limit,
-        }
+            projection,
+        })
     }
 }
 
@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
             .field("schema", &self.schema)
             .field("properties", &self.properties)
             .field("limit", &self.limit)
+            .field("projection", &self.projection)
             .finish()
     }
 }
@@ -240,7 +259,17 @@
         match self.limit {
             Some(limit) => {
                 // Create a row-limited iterator that properly handles row counting
-                let iter = RowLimitedIterator::new(reader, limit);
+                let projection = self.projection.clone();
+                let iter = RowLimitedIterator::new(reader, limit).map(move |res| match res {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            batch.project(indices).map_err(|e| e.into())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
+                    Err(e) => Err(e),
+                });
                 let stream = Box::pin(futures::stream::iter(iter));
                 let record_batch_stream =
                     RecordBatchStreamAdapter::new(self.schema.clone(), stream);
@@ -248,9 +277,16 @@ impl ExecutionPlan for RecordBatchReaderExec {
             }
             None => {
                 // No limit, just convert the reader directly to a stream
-                let iter = reader.map(|item| match item {
-                    Ok(batch) => Ok(batch),
-                    Err(e) => Err(DataFusionError::from(e)),
+                let projection = self.projection.clone();
+                let iter = reader.map(move |item| match item {
+                    Ok(batch) => {
+                        if let Some(indices) = projection.as_ref() {
+                            batch.project(indices).map_err(|e| e.into())
+                        } else {
+                            Ok(batch)
+                        }
+                    }
+                    Err(e) => Err(e.into()),
                 });
                 let stream = Box::pin(futures::stream::iter(iter));
                 let record_batch_stream =
@@ -266,7 +302,7 @@ mod test {
     use arrow_array::{RecordBatch, RecordBatchIterator};
     use arrow_schema::{DataType, Field, Schema};
 
-    use datafusion::prelude::{DataFrame, SessionContext};
+    use datafusion::prelude::{col, DataFrame, SessionContext};
     use rstest::rstest;
     use sedona_schema::datatypes::WKB_GEOMETRY;
     use sedona_testing::create::create_array_storage;
@@ -383,6 +419,45 @@ mod test {
         }
     }
 
+    #[tokio::test]
+    async fn test_projection_pushdown() {
+        let ctx = SessionContext::new();
+
+        // Create a two-column batch
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]);
+        let batch = RecordBatch::try_new(
+            Arc::new(schema.clone()),
+            vec![
+                Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3])),
+                Arc::new(arrow_array::Int32Array::from(vec![10, 20, 30])),
+            ],
+        )
+        .unwrap();
+
+        // Wrap in a RecordBatchReaderProvider
+        let reader =
+            RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), Arc::new(schema));
+        let provider = Arc::new(RecordBatchReaderProvider::new(Box::new(reader)));
+
+        // Read table then select only column b (this should push projection into scan)
+        let df = ctx.read_table(provider).unwrap();
+        let df_b = df.select(vec![col("b")]).unwrap();
+        let results = df_b.collect().await.unwrap();
+        assert_eq!(results.len(), 1);
+        let out_batch = &results[0];
+        assert_eq!(out_batch.num_columns(), 1);
+        assert_eq!(out_batch.schema().field(0).name(), "b");
+        let values = out_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<arrow_array::Int32Array>()
+            .unwrap();
+        assert_eq!(values.values(), &[10, 20, 30]);
+    }
+
     fn read_test_table_with_limit(
         ctx: &SessionContext,
         batch_sizes: Vec<usize>,
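
Reviewer note, not part of the patch: the per-batch projection above rests on two arrow-rs primitives that are easy to conflate, Schema::project (called once in try_new to derive the output schema) and RecordBatch::project (applied to every batch in the stream). The standalone sketch below shows their behavior in isolation; it assumes only the arrow-array and arrow-schema crates, and its column names and values are invented for illustration.

    use std::sync::Arc;

    use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray};
    use arrow_schema::{ArrowError, DataType, Field, Schema};

    fn main() -> Result<(), ArrowError> {
        // Two-column input, the same shape the tests above exercise.
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Int32, false),
        ]));
        let columns: Vec<ArrayRef> = vec![
            Arc::new(StringArray::from(vec!["x", "y", "z"])),
            Arc::new(Int32Array::from(vec![1, 2, 3])),
        ];
        let batch = RecordBatch::try_new(schema.clone(), columns)?;

        // Indices as DataFusion pushes them down: keep only column 1 ("b").
        let indices = vec![1usize];

        // Schema::project narrows the schema (done once in try_new);
        // RecordBatch::project narrows the data (done per batch in the stream).
        let projected_schema = schema.project(&indices)?;
        let projected_batch = batch.project(&indices)?;

        assert_eq!(projected_schema.fields().len(), 1);
        assert_eq!(projected_batch.num_columns(), 1);
        assert_eq!(projected_batch.schema().field(0).name(), "b");
        Ok(())
    }

Keeping the two in sync matters: RecordBatchStreamAdapter::new(self.schema.clone(), stream) advertises the projected schema, so every batch the stream yields must already be narrowed to match it.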