Skip to content

Commit ad8263b

Browse files
authored
Fix support for writing to nested field partition (#2204)
Closes #2095

# Rationale for this change

Currently, we can only partition on top-level valid field types. This PR adds support for partitioning on primitive fields inside a struct type, using dot notation to resolve the partition source against the nested structure.

# Are these changes tested?

Yes — tests were added, and a write was verified against the problem described in the issue above:

``` > aws s3 ls s3://myBucket/demo1/nestedPartition/data/ PRE timestamp_hour=2024-01-15-10/ PRE timestamp_hour=2024-01-15-11/ PRE timestamp_hour=2024-04-15-11/ PRE timestamp_hour=2024-05-15-10/ ```

# Are there any user-facing changes?

No API changes, but data can now be appended to tables that are partitioned by a source column located inside a struct.
1 parent 2d7d089 commit ad8263b

File tree

2 files changed

+131
-4
lines changed

2 files changed

+131
-4
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2728,9 +2728,11 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
27282728

27292729
for partition, name in zip(spec.fields, partition_fields):
27302730
source_field = schema.find_field(partition.source_id)
2731-
arrow_table = arrow_table.append_column(
2732-
name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
2733-
)
2731+
full_field_name = schema.find_column_name(partition.source_id)
2732+
if full_field_name is None:
2733+
raise ValueError(f"Could not find column name for field ID: {partition.source_id}")
2734+
field_array = _get_field_from_arrow_table(arrow_table, full_field_name)
2735+
arrow_table = arrow_table.append_column(name, partition.transform.pyarrow_transform(source_field.field_type)(field_array))
27342736

27352737
unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
27362738

@@ -2765,3 +2767,32 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
27652767
)
27662768

27672769
return table_partitions
2770+
2771+
2772+
def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Array:
    """Resolve ``field_path`` against ``arrow_table`` and return the matching array.

    Two interpretations of the path are supported, tried in order:

    1. An exact top-level column name, which may itself contain literal dots
       (e.g. a column literally named ``"some.id"``).
    2. A dot-separated path into a struct column (e.g. ``"bar.baz"`` selects
       child ``baz`` of struct column ``bar``).

    Args:
        arrow_table: The Arrow table to resolve the path against.
        field_path: A column name or a dot-separated nested field path.

    Returns:
        The resolved field as a PyArrow Array.

    Raises:
        KeyError: If the field path cannot be resolved.
    """
    # Exact matches take precedence so that columns whose names contain
    # literal dots are never misread as nested paths.
    if field_path in arrow_table.column_names:
        return arrow_table[field_path]

    # Otherwise treat the path as struct navigation: the first segment names
    # the top-level column, the remaining segments descend into the struct.
    root, *children = field_path.split(".")
    return pc.struct_field(arrow_table[root], children)

tests/io/test_pyarrow.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484
from pyiceberg.table import FileScanTask, TableProperties
8585
from pyiceberg.table.metadata import TableMetadataV2
8686
from pyiceberg.table.name_mapping import create_mapping_from_schema
87-
from pyiceberg.transforms import IdentityTransform
87+
from pyiceberg.transforms import HourTransform, IdentityTransform
8888
from pyiceberg.typedef import UTF8, Properties, Record
8989
from pyiceberg.types import (
9090
BinaryType,
@@ -2350,6 +2350,102 @@ def test_partition_for_demo() -> None:
23502350
)
23512351

23522352

2353+
def test_partition_for_nested_field() -> None:
    """Hour-transform partitioning on a timestamp that lives inside a struct."""
    from datetime import datetime

    schema = Schema(
        NestedField(id=1, name="foo", field_type=StringType(), required=True),
        NestedField(
            id=2,
            name="bar",
            field_type=StructType(
                NestedField(id=3, name="baz", field_type=TimestampType(), required=False),
                NestedField(id=4, name="qux", field_type=IntegerType(), required=False),
            ),
            required=True,
        ),
    )
    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=HourTransform(), name="ts"))

    ts_a = datetime(2025, 7, 11, 9, 30, 0)
    ts_b = datetime(2025, 7, 11, 10, 30, 0)
    rows = [
        {"foo": "a", "bar": {"baz": ts_a, "qux": 1}},
        {"foo": "b", "bar": {"baz": ts_b, "qux": 2}},
    ]

    arrow_table = pa.Table.from_pylist(rows, schema=schema.as_arrow())
    result = _determine_partitions(spec, schema, arrow_table)

    # 2025-07-11T09 and T10 are hours 486729 and 486730 since the epoch.
    assert {p.partition_key.partition[0] for p in result} == {486729, 486730}
2384+
2385+
2386+
def test_partition_for_deep_nested_field() -> None:
    """Identity partitioning on a field nested two struct levels deep."""
    schema = Schema(
        NestedField(
            id=1,
            name="foo",
            field_type=StructType(
                NestedField(
                    id=2,
                    name="bar",
                    field_type=StructType(NestedField(id=3, name="baz", field_type=StringType(), required=False)),
                    required=True,
                )
            ),
            required=True,
        )
    )
    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="qux"))

    rows = [{"foo": {"bar": {"baz": value}}} for value in ("data-1", "data-2", "data-1")]

    arrow_table = pa.Table.from_pylist(rows, schema=schema.as_arrow())
    result = _determine_partitions(spec, schema, arrow_table)

    # Three rows collapse into two distinct partition values.
    assert len(result) == 2
    assert {p.partition_key.partition[0] for p in result} == {"data-1", "data-2"}
2417+
2418+
2419+
def test_inspect_partition_for_nested_field(catalog: InMemoryCatalog) -> None:
    """End-to-end: append to a table partitioned on a struct child, then inspect partitions."""
    schema = Schema(
        NestedField(id=1, name="foo", field_type=StringType(), required=True),
        NestedField(
            id=2,
            name="bar",
            field_type=StructType(
                NestedField(id=3, name="baz", field_type=StringType(), required=False),
                NestedField(id=4, name="qux", field_type=IntegerType(), required=False),
            ),
            required=True,
        ),
    )
    spec = PartitionSpec(PartitionField(source_id=3, field_id=1000, transform=IdentityTransform(), name="part"))

    catalog.create_namespace("default")
    table = catalog.create_table("default.test_partition_in_struct", schema=schema, partition_spec=spec)

    rows = [
        {"foo": "a", "bar": {"baz": "data-a", "qux": 1}},
        {"foo": "b", "bar": {"baz": "data-b", "qux": 2}},
    ]
    table.append(pa.Table.from_pylist(rows, schema=table.schema().as_arrow()))

    partitions = table.inspect.partitions()["partition"].to_pylist()

    assert len(partitions) == 2
    assert {entry["part"] for entry in partitions} == {"data-a", "data-b"}
2447+
2448+
23532449
def test_identity_partition_on_multi_columns() -> None:
23542450
test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
23552451
test_schema = Schema(

0 commit comments

Comments
 (0)