18 changes: 18 additions & 0 deletions python/sedonadb/tests/test_dataframe.py
@@ -342,6 +342,24 @@ def test_dataframe_to_parquet(con):
)


def test_record_batch_reader_projection(con):
def batches():
for _ in range(3):
yield pa.record_batch({"a": ["a", "b", "c"], "b": [1, 2, 3]})

# Pull the schema from a throwaway generator, then stream fresh batches
reader = pa.RecordBatchReader.from_batches(next(batches()).schema, batches())
df = con.create_data_frame(reader)
df.to_view("temp_rbr_proj", overwrite=True)
try:
# Query the view with a projection (select only column b)
proj_df = con.sql("SELECT b FROM temp_rbr_proj")
tbl = proj_df.to_arrow_table()
assert tbl.column_names == ["b"]
assert tbl.to_pydict()["b"] == [1, 2, 3] * 3
finally:
con.drop_view("temp_rbr_proj")
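
The Rust exec node below threads a LIMIT through the same code path, so a natural companion check (a sketch, not part of this PR; it assumes the same `con` fixture and `pa` import as above) would combine both:

def test_record_batch_reader_projection_with_limit(con):
    batch = pa.record_batch({"a": ["a", "b", "c"], "b": [1, 2, 3]})
    reader = pa.RecordBatchReader.from_batches(batch.schema, iter([batch]))
    df = con.create_data_frame(reader)
    df.to_view("temp_rbr_proj_limit", overwrite=True)
    try:
        # Whether or not the LIMIT is pushed into the scan, the projected
        # result should contain only column b, truncated to two rows.
        tbl = con.sql("SELECT b FROM temp_rbr_proj_limit LIMIT 2").to_arrow_table()
        assert tbl.column_names == ["b"]
        assert tbl.to_pydict()["b"] == [1, 2]
    finally:
        con.drop_view("temp_rbr_proj_limit")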


def test_show(con, capsys):
con.sql("SELECT 1 as one").show()
expected = """
97 changes: 86 additions & 11 deletions rust/sedona/src/record_batch_reader_provider.rs
@@ -84,13 +84,16 @@ impl TableProvider for RecordBatchReaderProvider {
async fn scan(
&self,
_state: &dyn Session,
- _projection: Option<&Vec<usize>>,
+ projection: Option<&Vec<usize>>,
_filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let mut reader_guard = self.reader.lock();
if let Some(reader) = reader_guard.take() {
- Ok(Arc::new(RecordBatchReaderExec::new(reader, limit)))
+ let projection = projection.cloned();
+ Ok(Arc::new(RecordBatchReaderExec::try_new(
+ reader, limit, projection,
+ )?))
} else {
sedona_internal_err!("Can't scan RecordBatchReader provider more than once")
}
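
A note on the new `scan` signature: DataFusion hands the provider the projection as column indices into the full schema, and `projection.cloned()` turns the borrowed `Option<&Vec<usize>>` into an owned copy the exec node can keep. A minimal illustration of what those indices mean (a sketch depending only on arrow-schema, not on SedonaDB):

use arrow_schema::{DataType, Field, Schema};

fn main() {
    let schema = Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]);
    // `SELECT b` reaches the provider as the index list [1];
    // Schema::project keeps only the addressed fields, in order.
    let projected = schema.project(&[1]).unwrap();
    assert_eq!(projected.fields().len(), 1);
    assert_eq!(projected.field(0).name(), "b");
}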
@@ -158,24 +161,39 @@ struct RecordBatchReaderExec {
schema: SchemaRef,
properties: PlanProperties,
limit: Option<usize>,
projection: Option<Vec<usize>>,
}

impl RecordBatchReaderExec {
- fn new(reader: Box<dyn RecordBatchReader + Send>, limit: Option<usize>) -> Self {
- let schema = reader.schema();
+ fn try_new(
+ reader: Box<dyn RecordBatchReader + Send>,
+ limit: Option<usize>,
+ projection: Option<Vec<usize>>,
+ ) -> Result<Self> {
+ let full_schema = reader.schema();
+ let schema: SchemaRef = if let Some(indices) = projection.as_ref() {
+ SchemaRef::new(
+ full_schema
+ .project(indices)
+ .map_err(DataFusionError::from)?,
+ )
+ } else {
+ full_schema.clone()
+ };
let properties = PlanProperties::new(
EquivalenceProperties::new(schema.clone()),
Partitioning::UnknownPartitioning(1),
EmissionType::Incremental,
Boundedness::Bounded,
);

- Self {
+ Ok(Self {
reader: Mutex::new(Some(reader)),
schema,
properties,
limit,
- }
+ projection,
+ })
}
}
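
Projecting the schema here, at construction time, matters because `PlanProperties` advertises the node's output schema to the planner, while `execute` later pairs that same schema with per-batch `RecordBatch::project` calls via `RecordBatchStreamAdapter`; the two must agree. A toy check of that invariant (a sketch using only the arrow crates, outside this codebase):

use std::sync::Arc;
use arrow_array::{Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1])),
            Arc::new(Int32Array::from(vec![10])),
        ],
    )
    .unwrap();
    // Projecting the schema and projecting a batch with the same indices
    // must yield the same output schema; column buffers are shared, not copied.
    let indices = [1usize];
    let projected_schema = schema.project(&indices).unwrap();
    let projected_batch = batch.project(&indices).unwrap();
    assert_eq!(projected_batch.schema().as_ref(), &projected_schema);
}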

@@ -186,6 +204,7 @@ impl Debug for RecordBatchReaderExec {
.field("schema", &self.schema)
.field("properties", &self.properties)
.field("limit", &self.limit)
.field("projection", &self.projection)
.finish()
}
}
@@ -240,17 +259,34 @@ impl ExecutionPlan for RecordBatchReaderExec {
match self.limit {
Some(limit) => {
// Create a row-limited iterator that properly handles row counting
- let iter = RowLimitedIterator::new(reader, limit);
+ let projection = self.projection.clone();
+ let iter = RowLimitedIterator::new(reader, limit).map(move |res| match res {
+ Ok(batch) => {
+ if let Some(indices) = projection.as_ref() {
+ batch.project(indices).map_err(|e| e.into())
+ } else {
+ Ok(batch)
+ }
+ }
+ Err(e) => Err(e),
+ });
let stream = Box::pin(futures::stream::iter(iter));
let record_batch_stream =
RecordBatchStreamAdapter::new(self.schema.clone(), stream);
Ok(Box::pin(record_batch_stream))
}
None => {
// No limit, just convert the reader directly to a stream
- let iter = reader.map(|item| match item {
- Ok(batch) => Ok(batch),
- Err(e) => Err(DataFusionError::from(e)),
+ let projection = self.projection.clone();
+ let iter = reader.map(move |item| match item {
+ Ok(batch) => {
+ if let Some(indices) = projection.as_ref() {
+ batch.project(indices).map_err(|e| e.into())
+ } else {
+ Ok(batch)
+ }
+ }
+ Err(e) => Err(e.into()),
});
let stream = Box::pin(futures::stream::iter(iter));
let record_batch_stream =
@@ -266,7 +302,7 @@ mod test {

use arrow_array::{RecordBatch, RecordBatchIterator};
use arrow_schema::{DataType, Field, Schema};
- use datafusion::prelude::{DataFrame, SessionContext};
+ use datafusion::prelude::{col, DataFrame, SessionContext};
use rstest::rstest;
use sedona_schema::datatypes::WKB_GEOMETRY;
use sedona_testing::create::create_array_storage;
@@ -383,6 +419,45 @@ mod test {
}
}

#[tokio::test]
async fn test_projection_pushdown() {
let ctx = SessionContext::new();

// Create a two-column batch
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
]);
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![
Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3])),
Arc::new(arrow_array::Int32Array::from(vec![10, 20, 30])),
],
)
.unwrap();

// Wrap in a RecordBatchReaderProvider
let reader =
RecordBatchIterator::new(vec![batch.clone()].into_iter().map(Ok), Arc::new(schema));
let provider = Arc::new(RecordBatchReaderProvider::new(Box::new(reader)));

// Read table then select only column b (this should push projection into scan)
let df = ctx.read_table(provider).unwrap();
let df_b = df.select(vec![col("b")]).unwrap();
let results = df_b.collect().await.unwrap();
assert_eq!(results.len(), 1);
let out_batch = &results[0];
assert_eq!(out_batch.num_columns(), 1);
assert_eq!(out_batch.schema().field(0).name(), "b");
let values = out_batch
.column(0)
.as_any()
.downcast_ref::<arrow_array::Int32Array>()
.unwrap();
assert_eq!(values.values(), &[10, 20, 30]);
}
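
One more behavior to keep in mind when extending this module: the provider consumes its reader on the first scan (see the `sedona_internal_err!` above), so every test needs a fresh provider. A guard-rail test along these lines (a sketch, not in this PR, reusing this module's imports) would pin that down:

#[tokio::test]
async fn test_scan_twice_errors() {
    let ctx = SessionContext::new();
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();
    let reader =
        RecordBatchIterator::new(vec![batch].into_iter().map(Ok), Arc::new(schema));
    let provider = Arc::new(RecordBatchReaderProvider::new(Box::new(reader)));

    // First scan drains the one-shot reader...
    let df = ctx.read_table(provider.clone()).unwrap();
    df.collect().await.unwrap();

    // ...so a second scan of the same provider must fail.
    let df2 = ctx.read_table(provider).unwrap();
    assert!(df2.collect().await.is_err());
}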

fn read_test_table_with_limit(
ctx: &SessionContext,
batch_sizes: Vec<usize>,