2 changes: 2 additions & 0 deletions crates/iceberg/src/arrow/delete_filter.rs
@@ -339,6 +339,7 @@ pub(crate) mod tests {
project_field_ids: vec![],
predicate: None,
deletes: vec![pos_del_1, pos_del_2.clone()],
limit: None,
},
FileScanTask {
start: 0,
@@ -350,6 +351,7 @@
project_field_ids: vec![],
predicate: None,
deletes: vec![pos_del_3],
limit: None,
},
];

12 changes: 9 additions & 3 deletions crates/iceberg/src/arrow/reader.rs
@@ -176,8 +176,9 @@ impl ArrowReader {
row_group_filtering_enabled: bool,
row_selection_enabled: bool,
) -> Result<ArrowRecordBatchStream> {
let should_load_page_index =
(row_selection_enabled && task.predicate.is_some()) || !task.deletes.is_empty();
let should_load_page_index = (row_selection_enabled && task.predicate.is_some())
|| !task.deletes.is_empty()
|| task.limit.is_some();

let delete_filter_rx = delete_file_loader.load_deletes(&task.deletes, task.schema.clone());

@@ -310,6 +311,10 @@ impl ArrowReader {
record_batch_stream_builder.with_row_groups(selected_row_group_indices);
}

if let Some(limit) = task.limit {
Author (replying to a review comment):
Thanks! I extended the should_load_page_index logic, and ArrowReaderOptions is now initialized with with_page_index(should_load_page_index).
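
Not part of this PR: a minimal standalone sketch of the wiring described in that reply, using the parquet crate's async reader directly. The helper name read_with_limit, the file-based reader, and the boolean flags are assumptions for illustration; the condition and the with_page_index / with_limit calls mirror the diff above.

use futures::TryStreamExt;
use parquet::arrow::arrow_reader::ArrowReaderOptions;
use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;

async fn read_with_limit(
    path: &str,
    limit: Option<usize>,
    row_selection_enabled: bool,
    has_predicate: bool,
    has_deletes: bool,
) -> Result<usize, Box<dyn std::error::Error>> {
    // Mirrors the condition above: load the page index only when row
    // selection, positional deletes, or a limit can make use of it.
    let should_load_page_index =
        (row_selection_enabled && has_predicate) || has_deletes || limit.is_some();

    let file = tokio::fs::File::open(path).await?;
    let mut builder = ParquetRecordBatchStreamBuilder::new_with_options(
        file,
        ArrowReaderOptions::new().with_page_index(should_load_page_index),
    )
    .await?;

    // Push the limit down so the parquet reader stops decoding once enough
    // rows have been produced for this file.
    if let Some(limit) = limit {
        builder = builder.with_limit(limit);
    }

    let batches: Vec<_> = builder.build()?.try_collect().await?;
    Ok(batches.iter().map(|b| b.num_rows()).sum())
}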

record_batch_stream_builder = record_batch_stream_builder.with_limit(limit);
}

// Build the batch stream and send all the RecordBatches that it generates
// to the requester.
let record_batch_stream =
@@ -341,7 +346,7 @@ impl ArrowReader {
// Create the record batch stream builder, which wraps the parquet file reader
let record_batch_stream_builder = ParquetRecordBatchStreamBuilder::new_with_options(
parquet_file_reader,
ArrowReaderOptions::new(),
ArrowReaderOptions::new().with_page_index(should_load_page_index),
)
.await?;
Ok(record_batch_stream_builder)
@@ -1745,6 +1750,7 @@ message schema {
project_field_ids: vec![1],
predicate: Some(predicate.bind(schema, true).unwrap()),
deletes: vec![],
limit: None,
})]
.into_iter(),
)) as FileScanTaskStream;
8 changes: 8 additions & 0 deletions crates/iceberg/src/scan/context.rs
@@ -42,6 +42,7 @@ pub(crate) struct ManifestFileContext {

field_ids: Arc<Vec<i32>>,
bound_predicates: Option<Arc<BoundPredicates>>,
limit: Option<usize>,
object_cache: Arc<ObjectCache>,
snapshot_schema: SchemaRef,
expression_evaluator_cache: Arc<ExpressionEvaluatorCache>,
@@ -59,6 +60,7 @@ pub(crate) struct ManifestEntryContext {
pub partition_spec_id: i32,
pub snapshot_schema: SchemaRef,
pub delete_file_index: DeleteFileIndex,
pub limit: Option<usize>,
}

impl ManifestFileContext {
@@ -74,6 +76,7 @@ impl ManifestFileContext {
mut sender,
expression_evaluator_cache,
delete_file_index,
limit,
..
} = self;

@@ -89,6 +92,7 @@ impl ManifestFileContext {
bound_predicates: bound_predicates.clone(),
snapshot_schema: snapshot_schema.clone(),
delete_file_index: delete_file_index.clone(),
limit,
};

sender
@@ -128,6 +132,8 @@ impl ManifestEntryContext {
.map(|x| x.as_ref().snapshot_bound_predicate.clone()),

deletes,

limit: self.limit,
})
}
}
@@ -142,6 +148,7 @@ pub(crate) struct PlanContext {
pub snapshot_schema: SchemaRef,
pub case_sensitive: bool,
pub predicate: Option<Arc<Predicate>>,
pub limit: Option<usize>,
pub snapshot_bound_predicate: Option<Arc<BoundPredicate>>,
pub object_cache: Arc<ObjectCache>,
pub field_ids: Arc<Vec<i32>>,
@@ -255,6 +262,7 @@ impl PlanContext {
manifest_file: manifest_file.clone(),
bound_predicates,
sender,
limit: self.limit,
object_cache: self.object_cache.clone(),
snapshot_schema: self.snapshot_schema.clone(),
field_ids: self.field_ids.clone(),
136 changes: 136 additions & 0 deletions crates/iceberg/src/scan/mod.rs
@@ -59,6 +59,8 @@ pub struct TableScanBuilder<'a> {
concurrency_limit_manifest_files: usize,
row_group_filtering_enabled: bool,
row_selection_enabled: bool,

limit: Option<usize>,
}

impl<'a> TableScanBuilder<'a> {
@@ -77,9 +79,16 @@ impl<'a> TableScanBuilder<'a> {
concurrency_limit_manifest_files: num_cpus,
row_group_filtering_enabled: true,
row_selection_enabled: false,
limit: None,
}
}

/// Sets the maximum number of records to return
pub fn with_limit(mut self, limit: Option<usize>) -> Self {
self.limit = limit;
self
}

/// Sets the desired size of batches in the response
/// to something other than the default
pub fn with_batch_size(mut self, batch_size: Option<usize>) -> Self {
@@ -281,6 +290,7 @@ impl<'a> TableScanBuilder<'a> {
snapshot_schema: schema,
case_sensitive: self.case_sensitive,
predicate: self.filter.map(Arc::new),
limit: self.limit,
snapshot_bound_predicate: snapshot_bound_predicate.map(Arc::new),
object_cache: self.table.object_cache(),
field_ids: Arc::new(field_ids),
@@ -1406,6 +1416,130 @@ pub mod tests {
assert_eq!(int64_arr.value(0), 2);
}

#[tokio::test]
async fn test_limit() {
let mut fixture = TableTestFixture::new();
fixture.setup_manifest_files().await;

let mut builder = fixture.table.scan();
builder = builder.with_limit(Some(1));
let table_scan = builder.build().unwrap();

let batch_stream = table_scan.to_arrow().await.unwrap();

let batches: Vec<_> = batch_stream.try_collect().await.unwrap();

assert_eq!(batches.len(), 2);
assert_eq!(batches[0].num_rows(), 1);
assert_eq!(batches[1].num_rows(), 1);

let col = batches[0].column_by_name("x").unwrap();
let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 1);

let col = batches[0].column_by_name("y").unwrap();
let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 2);

}

#[tokio::test]
async fn test_limit_with_predicate() {
let mut fixture = TableTestFixture::new();
fixture.setup_manifest_files().await;

// Filter: y > 3
let mut builder = fixture.table.scan();
let predicate = Reference::new("y").greater_than(Datum::long(3));
builder = builder.with_filter(predicate).with_limit(Some(1));
let table_scan = builder.build().unwrap();

let batch_stream = table_scan.to_arrow().await.unwrap();

let batches: Vec<_> = batch_stream.try_collect().await.unwrap();

assert_eq!(batches.len(), 2);
assert_eq!(batches[0].num_rows(), 1);
assert_eq!(batches[1].num_rows(), 1);

let col = batches[0].column_by_name("x").unwrap();
let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 1);

let col = batches[0].column_by_name("y").unwrap();
let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 4);
}

#[tokio::test]
async fn test_limit_with_predicate_and_row_selection() {
let mut fixture = TableTestFixture::new();
fixture.setup_manifest_files().await;

// Filter: y > 3
let mut builder = fixture.table.scan();
let predicate = Reference::new("y").greater_than(Datum::long(3));
builder = builder
.with_filter(predicate)
.with_limit(Some(1))
.with_row_selection_enabled(true);
let table_scan = builder.build().unwrap();

let batch_stream = table_scan.to_arrow().await.unwrap();

let batches: Vec<_> = batch_stream.try_collect().await.unwrap();

assert_eq!(batches.len(), 2);
assert_eq!(batches[0].num_rows(), 1);
assert_eq!(batches[1].num_rows(), 1);

let col = batches[0].column_by_name("x").unwrap();
let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 1);

let col = batches[0].column_by_name("y").unwrap();
let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 4);
}

#[tokio::test]
async fn test_limit_higher_than_total_rows() {
let mut fixture = TableTestFixture::new();
fixture.setup_manifest_files().await;

// Filter: y > 3
let mut builder = fixture.table.scan();
let predicate = Reference::new("y").greater_than(Datum::long(3));
builder = builder
.with_filter(predicate)
.with_limit(Some(100_000_000))
.with_row_selection_enabled(true);
let table_scan = builder.build().unwrap();

let batch_stream = table_scan.to_arrow().await.unwrap();

let batches: Vec<_> = batch_stream.try_collect().await.unwrap();

assert_eq!(batches.len(), 2);
assert_eq!(batches[0].num_rows(), 312);
assert_eq!(batches[1].num_rows(), 312);

let col = batches[0].column_by_name("x").unwrap();
let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 1);

let col = batches[0].column_by_name("y").unwrap();
let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
assert_eq!(int64_arr.value(0), 4);
}

#[tokio::test]
async fn test_filter_on_arrow_gt_eq() {
let mut fixture = TableTestFixture::new();
@@ -1780,6 +1914,7 @@ pub mod tests {
record_count: Some(100),
data_file_format: DataFileFormat::Parquet,
deletes: vec![],
limit: None,
};
test_fn(task);

Expand All @@ -1794,6 +1929,7 @@ pub mod tests {
record_count: None,
data_file_format: DataFileFormat::Avro,
deletes: vec![],
limit: None,
};
test_fn(task);
}
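
Not part of this diff: a short usage sketch of the new builder API exercised by the tests above. The module path iceberg::table::Table and the helper name first_rows are assumptions; scan(), with_limit, build, and to_arrow come from this file.

use futures::TryStreamExt;
use iceberg::table::Table;

async fn first_rows(table: &Table) -> iceberg::Result<()> {
    let scan = table
        .scan()
        .with_limit(Some(10)) // request at most 10 records per file scan task
        .build()?;

    let batches: Vec<_> = scan.to_arrow().await?.try_collect().await?;

    // Note from the tests above: the limit is applied per file scan task, so
    // the total row count across files can still exceed the requested limit.
    let total: usize = batches.iter().map(|b| b.num_rows()).sum();
    println!("read {total} rows");
    Ok(())
}
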
3 changes: 3 additions & 0 deletions crates/iceberg/src/scan/task.rs
@@ -54,6 +54,9 @@ pub struct FileScanTask {

/// The list of delete files that may need to be applied to this data file
pub deletes: Vec<FileScanTaskDeleteFile>,

/// Maximum number of records to return, None means no limit
pub limit: Option<usize>,
}

impl FileScanTask {
14 changes: 12 additions & 2 deletions crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -51,6 +51,8 @@ pub struct IcebergTableScan {
projection: Option<Vec<String>>,
/// Filters to apply to the table scan
predicates: Option<Predicate>,
/// Maximum number of records to return, None means no limit
limit: Option<usize>,
}

impl IcebergTableScan {
@@ -61,6 +63,7 @@ impl IcebergTableScan {
schema: ArrowSchemaRef,
projection: Option<&Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> Self {
let output_schema = match projection {
None => schema.clone(),
@@ -76,6 +79,7 @@
plan_properties,
projection,
predicates,
limit,
}
}

@@ -143,6 +147,7 @@ impl ExecutionPlan for IcebergTableScan {
self.snapshot_id,
self.projection.clone(),
self.predicates.clone(),
self.limit,
);
let stream = futures::stream::once(fut).try_flatten();

@@ -161,13 +166,14 @@ impl DisplayAs for IcebergTableScan {
) -> std::fmt::Result {
write!(
f,
"IcebergTableScan projection:[{}] predicate:[{}]",
"IcebergTableScan projection:[{}] predicate:[{}] limit:[{}]",
self.projection
.clone()
.map_or(String::new(), |v| v.join(",")),
self.predicates
.clone()
.map_or(String::from(""), |p| format!("{}", p))
.map_or(String::from(""), |p| format!("{}", p)),
self.limit.map_or(String::from(""), |p| format!("{}", p)),
)
}
}
@@ -182,6 +188,7 @@ async fn get_batch_stream(
snapshot_id: Option<i64>,
column_names: Option<Vec<String>>,
predicates: Option<Predicate>,
limit: Option<usize>,
) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> {
let scan_builder = match snapshot_id {
Some(snapshot_id) => table.scan().snapshot_id(snapshot_id),
@@ -195,6 +202,9 @@
if let Some(pred) = predicates {
scan_builder = scan_builder.with_filter(pred);
}

scan_builder = scan_builder.with_limit(limit);

let table_scan = scan_builder.build().map_err(to_datafusion_error)?;

let stream = table_scan
3 changes: 2 additions & 1 deletion crates/integrations/datafusion/src/table/mod.rs
@@ -149,14 +149,15 @@ impl TableProvider for IcebergTableProvider {
_state: &dyn Session,
projection: Option<&Vec<usize>>,
filters: &[Expr],
_limit: Option<usize>,
limit: Option<usize>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(IcebergTableScan::new(
self.table.clone(),
self.snapshot_id,
self.schema.clone(),
projection,
filters,
limit,
)))
}

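Not part of this diff: a sketch of the caller-side effect. With limit forwarded as above, DataFusion can push a SQL LIMIT into the Iceberg scan rather than applying it only after all rows are decoded. The function name and the pre-built provider are assumptions; the DataFusion calls are standard SessionContext APIs.

use std::sync::Arc;

use datafusion::datasource::TableProvider;
use datafusion::error::Result as DFResult;
use datafusion::prelude::SessionContext;

async fn query_with_limit(provider: Arc<dyn TableProvider>) -> DFResult<()> {
    let ctx = SessionContext::new();
    ctx.register_table("t", provider)?;

    // The LIMIT reaches TableProvider::scan as its `limit` argument, which the
    // change above forwards into IcebergTableScan and TableScanBuilder::with_limit.
    let batches = ctx.sql("SELECT * FROM t LIMIT 5").await?.collect().await?;
    println!("{} batches", batches.len());
    Ok(())
}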