Commit acab5f6

Implement limit push down for IcebergTableProvider (#19)

Original PR: #19
Upstream PR: apache#1673

Parent: ac7f053

File tree: 7 files changed (+172, -6 lines)
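
Summary: this commit threads an optional row limit from DataFusion's
TableProvider::scan through IcebergTableScan and iceberg's TableScanBuilder,
down to each FileScanTask, where it is handed to the Parquet reader. A minimal
usage sketch of the new builder API (a hedged sketch, assuming `table` is an
iceberg::table::Table obtained from a catalog; the limit value is illustrative):

use futures::TryStreamExt;

async fn scan_with_limit(table: &iceberg::table::Table) -> iceberg::Result<()> {
    let scan = table
        .scan()
        .with_limit(Some(10)) // new in this commit
        .build()?;
    // Collect the Arrow record batches produced by the limited scan.
    let _batches: Vec<_> = scan.to_arrow().await?.try_collect().await?;
    // Per the new tests below, the limit caps rows per data file, so the
    // total across batches can exceed 10 when the scan spans multiple files.
    Ok(())
}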

crates/iceberg/src/arrow/delete_filter.rs (2 additions, 0 deletions)

@@ -339,6 +339,7 @@ pub(crate) mod tests {
             project_field_ids: vec![],
             predicate: None,
             deletes: vec![pos_del_1, pos_del_2.clone()],
+            limit: None,
         },
         FileScanTask {
             start: 0,
@@ -350,6 +351,7 @@ pub(crate) mod tests {
             project_field_ids: vec![],
             predicate: None,
             deletes: vec![pos_del_3],
+            limit: None,
         },
     ];

crates/iceberg/src/arrow/reader.rs (9 additions, 3 deletions)

@@ -176,8 +176,9 @@ impl ArrowReader {
         row_group_filtering_enabled: bool,
         row_selection_enabled: bool,
     ) -> Result<ArrowRecordBatchStream> {
-        let should_load_page_index =
-            (row_selection_enabled && task.predicate.is_some()) || !task.deletes.is_empty();
+        let should_load_page_index = (row_selection_enabled && task.predicate.is_some())
+            || !task.deletes.is_empty()
+            || task.limit.is_some();

         let delete_filter_rx = delete_file_loader.load_deletes(&task.deletes, task.schema.clone());

@@ -310,6 +311,10 @@ impl ArrowReader {
             record_batch_stream_builder.with_row_groups(selected_row_group_indices);
         }

+        if let Some(limit) = task.limit {
+            record_batch_stream_builder = record_batch_stream_builder.with_limit(limit);
+        }
+
         // Build the batch stream and send all the RecordBatches that it generates
         // to the requester.
         let record_batch_stream =
@@ -341,7 +346,7 @@ impl ArrowReader {
         // Create the record batch stream builder, which wraps the parquet file reader
         let record_batch_stream_builder = ParquetRecordBatchStreamBuilder::new_with_options(
             parquet_file_reader,
-            ArrowReaderOptions::new(),
+            ArrowReaderOptions::new().with_page_index(should_load_page_index),
         )
         .await?;
         Ok(record_batch_stream_builder)
@@ -1745,6 +1750,7 @@ message schema {
             project_field_ids: vec![1],
             predicate: Some(predicate.bind(schema, true).unwrap()),
             deletes: vec![],
+            limit: None,
         })]
         .into_iter(),
     )) as FileScanTaskStream;
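
Aside: with_limit here is the parquet crate's ArrowReaderBuilder::with_limit,
which stops decoding once the requested number of rows has been emitted; note
that the page index is now also loaded whenever a limit is present, per the
should_load_page_index change above. A standalone sketch of the same call
against a plain Parquet file (the path and limit are illustrative):

use futures::TryStreamExt;
use parquet::arrow::ParquetRecordBatchStreamBuilder;

async fn read_first_n(path: &str, n: usize) -> Result<usize, Box<dyn std::error::Error>> {
    let file = tokio::fs::File::open(path).await?;
    let stream = ParquetRecordBatchStreamBuilder::new(file)
        .await?
        .with_limit(n) // decode at most n rows, as in the reader change above
        .build()?;
    let batches: Vec<_> = stream.try_collect().await?;
    // Return how many rows were actually decoded (at most n).
    Ok(batches.iter().map(|b| b.num_rows()).sum())
}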

crates/iceberg/src/scan/context.rs (8 additions, 0 deletions)

@@ -42,6 +42,7 @@ pub(crate) struct ManifestFileContext {

    field_ids: Arc<Vec<i32>>,
    bound_predicates: Option<Arc<BoundPredicates>>,
+    limit: Option<usize>,
    object_cache: Arc<ObjectCache>,
    snapshot_schema: SchemaRef,
    expression_evaluator_cache: Arc<ExpressionEvaluatorCache>,
@@ -59,6 +60,7 @@ pub(crate) struct ManifestEntryContext {
    pub partition_spec_id: i32,
    pub snapshot_schema: SchemaRef,
    pub delete_file_index: DeleteFileIndex,
+    pub limit: Option<usize>,
 }

 impl ManifestFileContext {
@@ -74,6 +76,7 @@ impl ManifestFileContext {
             mut sender,
             expression_evaluator_cache,
             delete_file_index,
+            limit,
             ..
         } = self;

@@ -89,6 +92,7 @@ impl ManifestFileContext {
             bound_predicates: bound_predicates.clone(),
             snapshot_schema: snapshot_schema.clone(),
             delete_file_index: delete_file_index.clone(),
+            limit,
         };

         sender
@@ -128,6 +132,8 @@ impl ManifestEntryContext {
                 .map(|x| x.as_ref().snapshot_bound_predicate.clone()),

             deletes,
+
+            limit: self.limit,
         })
     }
 }
@@ -142,6 +148,7 @@ pub(crate) struct PlanContext {
    pub snapshot_schema: SchemaRef,
    pub case_sensitive: bool,
    pub predicate: Option<Arc<Predicate>>,
+    pub limit: Option<usize>,
    pub snapshot_bound_predicate: Option<Arc<BoundPredicate>>,
    pub object_cache: Arc<ObjectCache>,
    pub field_ids: Arc<Vec<i32>>,
@@ -255,6 +262,7 @@ impl PlanContext {
                 manifest_file: manifest_file.clone(),
                 bound_predicates,
                 sender,
+                limit: self.limit,
                 object_cache: self.object_cache.clone(),
                 snapshot_schema: self.snapshot_schema.clone(),
                 field_ids: self.field_ids.clone(),
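
The plumbing above copies the limit along the planning chain, one hop per
struct: PlanContext (set once per scan) -> ManifestFileContext (one per
manifest) -> ManifestEntryContext (one per manifest entry) -> FileScanTask.
A schematic condensed from the hunks, with all unrelated fields elided:

// Field names are from the diff; the real structs carry many more fields.
struct PlanContext { limit: Option<usize> }          // set from TableScanBuilder
struct ManifestFileContext { limit: Option<usize> }  // copied in PlanContext
struct ManifestEntryContext { limit: Option<usize> } // copied per manifest
struct FileScanTask { limit: Option<usize> }         // copied per entry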

crates/iceberg/src/scan/mod.rs (136 additions, 0 deletions)

@@ -59,6 +59,8 @@ pub struct TableScanBuilder<'a> {
    concurrency_limit_manifest_files: usize,
    row_group_filtering_enabled: bool,
    row_selection_enabled: bool,
+
+    limit: Option<usize>,
 }

 impl<'a> TableScanBuilder<'a> {
@@ -77,9 +79,16 @@ impl<'a> TableScanBuilder<'a> {
             concurrency_limit_manifest_files: num_cpus,
             row_group_filtering_enabled: true,
             row_selection_enabled: false,
+            limit: None,
         }
     }

+    /// Sets the maximum number of records to return
+    pub fn with_limit(mut self, limit: Option<usize>) -> Self {
+        self.limit = limit;
+        self
+    }
+
     /// Sets the desired size of batches in the response
     /// to something other than the default
     pub fn with_batch_size(mut self, batch_size: Option<usize>) -> Self {
@@ -281,6 +290,7 @@ impl<'a> TableScanBuilder<'a> {
             snapshot_schema: schema,
             case_sensitive: self.case_sensitive,
             predicate: self.filter.map(Arc::new),
+            limit: self.limit,
             snapshot_bound_predicate: snapshot_bound_predicate.map(Arc::new),
             object_cache: self.table.object_cache(),
             field_ids: Arc::new(field_ids),
@@ -1406,6 +1416,130 @@ pub mod tests {
         assert_eq!(int64_arr.value(0), 2);
     }

+    #[tokio::test]
+    async fn test_limit() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        let mut builder = fixture.table.scan();
+        builder = builder.with_limit(Some(1));
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].num_rows(), 1);
+        assert_eq!(batches[1].num_rows(), 1);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 2);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 2);
+    }
+
+    #[tokio::test]
+    async fn test_limit_with_predicate() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y > 3
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").greater_than(Datum::long(3));
+        builder = builder.with_filter(predicate).with_limit(Some(1));
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].num_rows(), 1);
+        assert_eq!(batches[1].num_rows(), 1);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 4);
+    }
+
+    #[tokio::test]
+    async fn test_limit_with_predicate_and_row_selection() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y > 3
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").greater_than(Datum::long(3));
+        builder = builder
+            .with_filter(predicate)
+            .with_limit(Some(1))
+            .with_row_selection_enabled(true);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].num_rows(), 1);
+        assert_eq!(batches[1].num_rows(), 1);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 4);
+    }
+
+    #[tokio::test]
+    async fn test_limit_higher_than_total_rows() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y > 3
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").greater_than(Datum::long(3));
+        builder = builder
+            .with_filter(predicate)
+            .with_limit(Some(100_000_000))
+            .with_row_selection_enabled(true);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches.len(), 2);
+        assert_eq!(batches[0].num_rows(), 312);
+        assert_eq!(batches[1].num_rows(), 312);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 4);
+    }
+
     #[tokio::test]
     async fn test_filter_on_arrow_gt_eq() {
         let mut fixture = TableTestFixture::new();
@@ -1780,6 +1914,7 @@ pub mod tests {
             record_count: Some(100),
             data_file_format: DataFileFormat::Parquet,
             deletes: vec![],
+            limit: None,
         };
         test_fn(task);

@@ -1794,6 +1929,7 @@ pub mod tests {
             record_count: None,
             data_file_format: DataFileFormat::Avro,
             deletes: vec![],
+            limit: None,
         };
         test_fn(task);
     }
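
Note the expectations in test_limit: with a limit of 1 the stream still yields
two batches of one row each, because the limit rides on each FileScanTask (one
per data file) rather than being enforced globally; the engine above the scan
(here, DataFusion's own limit operator) remains responsible for the global
cut. A hedged sketch of enforcing a global cap on top of the per-file
pushdown, assuming `scan` is a built TableScan:

use futures::StreamExt;

async fn take_globally(
    scan: iceberg::scan::TableScan,
    n: usize,
) -> iceberg::Result<Vec<arrow_array::RecordBatch>> {
    let mut remaining = n;
    let mut out = Vec::new();
    let mut stream = scan.to_arrow().await?;
    while let Some(batch) = stream.next().await {
        if remaining == 0 {
            break;
        }
        let batch = batch?;
        // Keep only as many rows as the global budget allows.
        let take = remaining.min(batch.num_rows());
        out.push(batch.slice(0, take)); // RecordBatch::slice is zero-copy
        remaining -= take;
    }
    Ok(out)
}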

crates/iceberg/src/scan/task.rs (3 additions, 0 deletions)

@@ -54,6 +54,9 @@ pub struct FileScanTask {

    /// The list of delete files that may need to be applied to this data file
    pub deletes: Vec<FileScanTaskDeleteFile>,
+
+    /// Maximum number of records to return, None means no limit
+    pub limit: Option<usize>,
 }

 impl FileScanTask {

crates/integrations/datafusion/src/physical_plan/scan.rs (12 additions, 2 deletions)

@@ -51,6 +51,8 @@ pub struct IcebergTableScan {
    projection: Option<Vec<String>>,
    /// Filters to apply to the table scan
    predicates: Option<Predicate>,
+    /// Maximum number of records to return, None means no limit
+    limit: Option<usize>,
 }

 impl IcebergTableScan {
@@ -61,6 +63,7 @@ impl IcebergTableScan {
         schema: ArrowSchemaRef,
         projection: Option<&Vec<usize>>,
         filters: &[Expr],
+        limit: Option<usize>,
     ) -> Self {
         let output_schema = match projection {
             None => schema.clone(),
@@ -76,6 +79,7 @@ impl IcebergTableScan {
             plan_properties,
             projection,
             predicates,
+            limit,
         }
     }

@@ -143,6 +147,7 @@ impl ExecutionPlan for IcebergTableScan {
             self.snapshot_id,
             self.projection.clone(),
             self.predicates.clone(),
+            self.limit,
         );
         let stream = futures::stream::once(fut).try_flatten();

@@ -161,13 +166,14 @@ impl DisplayAs for IcebergTableScan {
     ) -> std::fmt::Result {
         write!(
             f,
-            "IcebergTableScan projection:[{}] predicate:[{}]",
+            "IcebergTableScan projection:[{}] predicate:[{}] limit:[{}]",
             self.projection
                 .clone()
                 .map_or(String::new(), |v| v.join(",")),
             self.predicates
                 .clone()
-                .map_or(String::from(""), |p| format!("{}", p))
+                .map_or(String::from(""), |p| format!("{}", p)),
+            self.limit.map_or(String::from(""), |p| format!("{}", p)),
         )
     }
 }
@@ -182,6 +188,7 @@ async fn get_batch_stream(
     snapshot_id: Option<i64>,
     column_names: Option<Vec<String>>,
     predicates: Option<Predicate>,
+    limit: Option<usize>,
 ) -> DFResult<Pin<Box<dyn Stream<Item = DFResult<RecordBatch>> + Send>>> {
     let scan_builder = match snapshot_id {
         Some(snapshot_id) => table.scan().snapshot_id(snapshot_id),
@@ -195,6 +202,9 @@ async fn get_batch_stream(
     if let Some(pred) = predicates {
         scan_builder = scan_builder.with_filter(pred);
     }
+
+    scan_builder = scan_builder.with_limit(limit);
+
     let table_scan = scan_builder.build().map_err(to_datafusion_error)?;

     let stream = table_scan
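
With the DisplayAs change, the limit is now visible in DataFusion's EXPLAIN
output. An illustrative plan line (hypothetical values, not captured from a
real run):

IcebergTableScan projection:[x,y] predicate:[y > 3] limit:[1]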

crates/integrations/datafusion/src/table/mod.rs (2 additions, 1 deletion)

@@ -153,7 +153,7 @@ impl TableProvider for IcebergTableProvider {
         _state: &dyn Session,
         projection: Option<&Vec<usize>>,
         filters: &[Expr],
-        _limit: Option<usize>,
+        limit: Option<usize>,
     ) -> DFResult<Arc<dyn ExecutionPlan>> {
         // Get the latest table metadata from the catalog if it exists
         let table = if let Some(catalog) = &self.catalog {
@@ -172,6 +172,7 @@ impl TableProvider for IcebergTableProvider {
             self.schema.clone(),
             projection,
             filters,
+            limit,
         )))
     }

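End to end, the previously ignored _limit argument is now honored, so a
SQL-level LIMIT can reach the Parquet decoder. A hedged sketch of driving the
pushdown through DataFusion (session setup and table name are illustrative;
constructing the IcebergTableProvider depends on your catalog):

use std::sync::Arc;
use datafusion::prelude::SessionContext;

async fn query_with_limit(
    provider: Arc<dyn datafusion::datasource::TableProvider>,
) -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    ctx.register_table("t", provider)?;
    // DataFusion may pass the fetch count into TableProvider::scan, which
    // this commit forwards to IcebergTableScan and on to TableScanBuilder.
    let df = ctx.sql("SELECT x, y FROM t LIMIT 1").await?;
    let _batches = df.collect().await?;
    Ok(())
}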