Skip to content

implement stageOnly Commit #2269

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,9 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
name_mapping=self.table_metadata.name_mapping(),
)

def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
def update_snapshot(
self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None, stage_only: bool = False
) -> UpdateSnapshot:
"""Create a new UpdateSnapshot to produce a new snapshot for the table.

Returns:
Expand All @@ -441,7 +443,9 @@ def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, bran
if branch is None:
branch = MAIN_BRANCH

return UpdateSnapshot(self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties)
return UpdateSnapshot(
self, io=self._table.io, branch=branch, snapshot_properties=snapshot_properties, stage_only=stage_only
)

def update_statistics(self) -> UpdateStatistics:
"""
Expand Down
58 changes: 39 additions & 19 deletions pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class _SnapshotProducer(UpdateTableMetadata[U], Generic[U]):
_deleted_data_files: Set[DataFile]
_compression: AvroCompressionCodec
_target_branch = MAIN_BRANCH
_stage_only = False

def __init__(
self,
Expand All @@ -118,6 +119,7 @@ def __init__(
commit_uuid: Optional[uuid.UUID] = None,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
branch: str = MAIN_BRANCH,
stage_only: bool = False,
Comment on lines 121 to +122
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, API-wise, this makes more sense: if `branch` is None, we don't set the ref; if it is not None, we set the ref.

Suggested change
branch: str = MAIN_BRANCH,
stage_only: bool = False,
branch: Optional[str] = MAIN_BRANCH

@kevinjqliu WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to me. Instead of setting `stage_only=True`, it would be `branch=None`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense — thanks for taking a look!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at it a bit more, I think that's the way forward.

def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = None) -> UpdateSnapshot:
"""Create a new UpdateSnapshot to produce a new snapshot for the table.
Returns:
A new UpdateSnapshot
"""
if branch is None:
branch = MAIN_BRANCH

We would change that into:

    def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT, branch: Optional[str] = MAIN_BRANCH) -> UpdateSnapshot:
        """Create a new UpdateSnapshot to produce a new snapshot for the table.

        Returns:
            A new UpdateSnapshot
        """

There are a couple more places where we need to change the default to MAIN_BRANCH. Let me know what you think!

) -> None:
super().__init__(transaction)
self.commit_uuid = commit_uuid or uuid.uuid4()
Expand All @@ -137,6 +139,7 @@ def __init__(
self._parent_snapshot_id = (
snapshot.snapshot_id if (snapshot := self._transaction.table_metadata.snapshot_by_name(self._target_branch)) else None
)
self._stage_only = stage_only

def _validate_target_branch(self, branch: str) -> str:
# Default is already set to MAIN_BRANCH. So branch name can't be None.
Expand Down Expand Up @@ -292,25 +295,33 @@ def _commit(self) -> UpdatesAndRequirements:
schema_id=self._transaction.table_metadata.current_schema_id,
)

return (
(
AddSnapshotUpdate(snapshot=snapshot),
SetSnapshotRefUpdate(
snapshot_id=self._snapshot_id,
parent_snapshot_id=self._parent_snapshot_id,
ref_name=self._target_branch,
type=SnapshotRefType.BRANCH,
add_snapshot_update = AddSnapshotUpdate(snapshot=snapshot)

if self._stage_only:
return (
(add_snapshot_update,),
(),
)
else:
return (
(
add_snapshot_update,
SetSnapshotRefUpdate(
snapshot_id=self._snapshot_id,
parent_snapshot_id=self._parent_snapshot_id,
ref_name=self._target_branch,
type=SnapshotRefType.BRANCH,
),
),
),
(
AssertRefSnapshotId(
snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
if self._target_branch in self._transaction.table_metadata.refs
else None,
ref=self._target_branch,
(
AssertRefSnapshotId(
snapshot_id=self._transaction.table_metadata.refs[self._target_branch].snapshot_id
if self._target_branch in self._transaction.table_metadata.refs
else None,
ref=self._target_branch,
),
),
),
)
)

@property
def snapshot_id(self) -> int:
Expand Down Expand Up @@ -360,8 +371,9 @@ def __init__(
branch: str,
commit_uuid: Optional[uuid.UUID] = None,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
stage_only: bool = False,
):
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
self._predicate = AlwaysFalse()
self._case_sensitive = True

Expand Down Expand Up @@ -530,10 +542,11 @@ def __init__(
branch: str,
commit_uuid: Optional[uuid.UUID] = None,
snapshot_properties: Dict[str, str] = EMPTY_DICT,
stage_only: bool = False,
) -> None:
from pyiceberg.table import TableProperties

super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch)
super().__init__(operation, transaction, io, commit_uuid, snapshot_properties, branch, stage_only)
self._target_size_bytes = property_as_int(
self._transaction.table_metadata.properties,
TableProperties.MANIFEST_TARGET_SIZE_BYTES,
Expand Down Expand Up @@ -649,19 +662,22 @@ class UpdateSnapshot:
_transaction: Transaction
_io: FileIO
_branch: str
_stage_only: bool
_snapshot_properties: Dict[str, str]

def __init__(
    self,
    transaction: Transaction,
    io: FileIO,
    branch: str,
    snapshot_properties: Dict[str, str] = EMPTY_DICT,
    stage_only: bool = False,
) -> None:
    """Bind a snapshot-update helper to an open transaction.

    Args:
        transaction: The transaction this snapshot update belongs to.
        io: FileIO used to write the new snapshot's manifests/metadata.
        branch: Name of the branch the new snapshot targets.
        snapshot_properties: Extra properties recorded on the snapshot summary.
        stage_only: When True the produced snapshot is added to the table
            metadata without advancing the branch ref (staged, not committed
            to the ref).

    Note: `stage_only` is deliberately placed *after* `snapshot_properties`
    so the pre-existing positional signature
    (transaction, io, branch, snapshot_properties) stays backward-compatible
    for positional callers; all in-repo callers pass it by keyword.
    """
    self._transaction = transaction
    self._io = io
    self._snapshot_properties = snapshot_properties
    self._branch = branch
    self._stage_only = stage_only

def fast_append(self) -> _FastAppendFiles:
return _FastAppendFiles(
Expand All @@ -670,6 +686,7 @@ def fast_append(self) -> _FastAppendFiles:
io=self._io,
branch=self._branch,
snapshot_properties=self._snapshot_properties,
stage_only=self._stage_only,
)

def merge_append(self) -> _MergeAppendFiles:
Expand All @@ -679,6 +696,7 @@ def merge_append(self) -> _MergeAppendFiles:
io=self._io,
branch=self._branch,
snapshot_properties=self._snapshot_properties,
stage_only=self._stage_only,
)

def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
Expand All @@ -691,6 +709,7 @@ def overwrite(self, commit_uuid: Optional[uuid.UUID] = None) -> _OverwriteFiles:
io=self._io,
branch=self._branch,
snapshot_properties=self._snapshot_properties,
stage_only=self._stage_only,
)

def delete(self) -> _DeleteFiles:
Expand All @@ -700,6 +719,7 @@ def delete(self) -> _DeleteFiles:
io=self._io,
branch=self._branch,
snapshot_properties=self._snapshot_properties,
stage_only=self._stage_only,
)


Expand Down
171 changes: 171 additions & 0 deletions tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2098,3 +2098,174 @@ def test_branch_py_write_spark_read(session_catalog: Catalog, spark: SparkSessio
)
assert main_df.count() == 3
assert branch_df.count() == 2


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_stage_only_delete(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    """A stage-only delete adds a snapshot but must not advance the main ref.

    Verifies that after `update_snapshot(stage_only=True).delete()`:
    the metadata gains a `delete` snapshot, while the current snapshot id
    and the scannable row count are unchanged.
    """
    identifier = f"default.test_stage_only_delete_files_v{format_version}"
    iceberg_spec = PartitionSpec(
        *[PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="integer_partition")]
    )
    tbl = _create_table(
        session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null], iceberg_spec
    )

    current_snapshot = tbl.metadata.current_snapshot_id
    assert current_snapshot is not None

    original_count = len(tbl.scan().to_arrow())
    assert original_count == 3

    # NOTE(review): removed the unused `files_to_delete` collection that was
    # copy-pasted from the overwrite test — delete_by_predicate needs no file list.
    with tbl.transaction() as txn:
        with txn.update_snapshot(stage_only=True).delete() as delete:
            delete.delete_by_predicate(EqualTo("int", 9))

    # A new delete snapshot is added to the table metadata ...
    snapshots = tbl.snapshots()
    assert len(snapshots) == 2

    rows = spark.sql(
        f"""
        SELECT operation, summary
        FROM {identifier}.snapshots
        ORDER BY committed_at ASC
    """
    ).collect()
    operations = [row.operation for row in rows]
    assert operations == ["append", "delete"]

    # ... but the main branch ref and the visible data are unchanged.
    assert current_snapshot == tbl.metadata.current_snapshot_id
    assert len(tbl.scan().to_arrow()) == original_count


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_stage_only_fast_append(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    """Fast-append with stage_only=True stages a snapshot without moving the main ref."""
    identifier = f"default.test_stage_only_fast_append_files_v{format_version}"
    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])

    snapshot_before = tbl.metadata.current_snapshot_id
    assert snapshot_before is not None

    row_count_before = len(tbl.scan().to_arrow())
    assert row_count_before == 3

    with tbl.transaction() as txn:
        with txn.update_snapshot(stage_only=True).fast_append() as append_op:
            staged_files = _dataframe_to_data_files(
                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
            )
            for staged_file in staged_files:
                append_op.append_data_file(data_file=staged_file)

    # The main ref still points at the original snapshot and no new rows are visible.
    assert tbl.metadata.current_snapshot_id == snapshot_before
    assert len(tbl.scan().to_arrow()) == row_count_before

    # The staged snapshot is nevertheless recorded in the table metadata.
    assert len(tbl.snapshots()) == 2

    history = spark.sql(
        f"""
        SELECT operation, summary
        FROM {identifier}.snapshots
        ORDER BY committed_at ASC
    """
    ).collect()
    assert [entry.operation for entry in history] == ["append", "append"]


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_stage_only_merge_append(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    """Merge-append with stage_only=True must not advance the main branch ref."""
    identifier = f"default.test_stage_only_merge_append_files_v{format_version}"
    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])

    baseline_snapshot_id = tbl.metadata.current_snapshot_id
    assert baseline_snapshot_id is not None

    baseline_rows = len(tbl.scan().to_arrow())
    assert baseline_rows == 3

    with tbl.transaction() as txn:
        with txn.update_snapshot(stage_only=True).merge_append() as appender:
            for new_file in _dataframe_to_data_files(
                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
            ):
                appender.append_data_file(data_file=new_file)

    # Main ref unchanged; the staged rows are not visible through a scan.
    assert baseline_snapshot_id == tbl.metadata.current_snapshot_id
    assert len(tbl.scan().to_arrow()) == baseline_rows

    # Metadata now records the staged snapshot alongside the original append.
    all_snapshots = tbl.snapshots()
    assert len(all_snapshots) == 2

    result = spark.sql(
        f"""
        SELECT operation, summary
        FROM {identifier}.snapshots
        ORDER BY committed_at ASC
    """
    ).collect()
    assert [r.operation for r in result] == ["append", "append"]


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_stage_only_overwrite_files(
    spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
) -> None:
    """A stage-only overwrite adds an `overwrite` snapshot without advancing the main ref.

    The overwrite both appends the same rows again and deletes one existing
    data file; because the commit is staged, neither change is visible on main.
    """
    identifier = f"default.test_stage_only_overwrite_files_v{format_version}"
    tbl = _create_table(session_catalog, identifier, {"format-version": str(format_version)}, [arrow_table_with_null])

    # Remember the pre-overwrite state of the main branch.
    current_snapshot = tbl.metadata.current_snapshot_id
    assert current_snapshot is not None

    original_count = len(tbl.scan().to_arrow())
    assert original_count == 3

    # Collect the current data files; one of them is deleted in the overwrite below.
    files_to_delete = []
    for file_task in tbl.scan().plan_files():
        files_to_delete.append(file_task.file)
    assert len(files_to_delete) > 0

    with tbl.transaction() as txn:
        with txn.update_snapshot(stage_only=True).overwrite() as overwrite:
            for data_file in _dataframe_to_data_files(
                table_metadata=txn.table_metadata, df=arrow_table_with_null, io=txn._table.io
            ):
                overwrite.append_data_file(data_file=data_file)
            overwrite.delete_data_file(files_to_delete[0])

    # Main ref and visible data are unchanged — the overwrite was only staged.
    assert current_snapshot == tbl.metadata.current_snapshot_id
    assert len(tbl.scan().to_arrow()) == original_count

    # The staged overwrite snapshot is still recorded in the metadata.
    snapshots = tbl.snapshots()
    assert len(snapshots) == 2

    rows = spark.sql(
        f"""
        SELECT operation, summary
        FROM {identifier}.snapshots
        ORDER BY committed_at ASC
    """
    ).collect()
    operations = [row.operation for row in rows]
    assert operations == ["append", "overwrite"]