Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.11.30] - 2025-11-26

### Fixed

- Roll back earlier change altering metadata format, which was observed to cause
breakages.

## [0.11.29] - 2025-11-25

### Fixed

- Fix `step_from_checkpoint_name` to allow the passed in checkpoint name to
include an arbitrary `step_prefix` with any character(s) such as underscores.
- Fix CheckpointManager initial directory creation to use `file_options.path_permission_mode`.
- Fix using `jax.eval_shape` with `StandardRestore`.

### Changed

- Validate checkpoints before writing merged OCDBT database using in-memory
state, avoiding additional I/O to re-read metadata.
- Add `support_format` to `utils.to_shape_dtype_struct()`.
- Moved `register_pathways_handlers` to `ocp.pathways.register_type_handlers`.
- Replace usage of `get_json_tpec_read` and delegate functionality to new
function `build_array_read_spec` which constructs and returns an
`ArrayReadSpec`.

## [0.11.28] - 2025-11-06

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,8 +505,13 @@ def handler(self) -> StandardCheckpointHandler:
return StandardCheckpointHandler()

def test_with_random_keys(self):
# TODO(b/393160483) investigate Pathways remote Python support for
# random.keys.
if utils.is_pathways_backend():
self.skipTest('Pathways does not support random keys checkpoint.')
self.skipTest(
'Disabled on Pathways because random keys are not supported by'
' remote Python.'
)

def create_random_keys(seed):
duplicated_sharding = jax.sharding.NamedSharding(
Expand Down Expand Up @@ -559,3 +564,38 @@ def create_random_keys(seed):
args=self.restore_args_cls(abstract_tree),
)
test_utils.assert_tree_equal(self, self.pytree, restored)

def test_save_restore_random_keys_with_jax_eval_shape(self):
  """Round-trips a pytree holding a typed PRNG key via a jax.eval_shape tree."""
  # TODO(b/393160483) investigate Pathways remote Python support for
  # random.keys.
  if utils.is_pathways_backend():
    self.skipTest(
        'Disabled on Pathways because random keys are not supported by'
        ' remote Python.'
    )

  device_mesh = jax.sharding.Mesh(jax.devices(), ('x',))
  replicated = jax.sharding.NamedSharding(
      device_mesh, jax.sharding.PartitionSpec()
  )

  @functools.partial(
      jax.jit, in_shardings=replicated, out_shardings=replicated
  )
  def make_state(root_key):
    # Mixed pytree: an ordinary array alongside a key folded from the root.
    return dict(
        matrix=jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]]),
        rngkey=jax.random.fold_in(root_key, 42),
    )

  state = make_state(jax.random.key(0))
  # The abstract tree comes from jax.eval_shape, so restore must cope with a
  # shape/dtype-only description rather than concrete arrays.
  abstract_state = jax.eval_shape(make_state, jax.random.key(0))

  self.handler.save(self.directory, args=self.save_args_cls(state))
  restored = self.handler.restore(
      self.directory, args=self.restore_args_cls(abstract_state)
  )
  test_utils.assert_tree_equal(self, state, restored)
4 changes: 2 additions & 2 deletions checkpoint/orbax/checkpoint/_src/path/atomicity.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,13 @@ async def _create_tmp_directory(
def _get_tmp_directory(final_path: epath.Path) -> epath.Path:
  """Returns the temporary-directory path corresponding to `final_path`.

  The tmp path lives next to `final_path` and is formed by appending
  `TMP_DIR_SUFFIX` to its name. Note: the path may not be completely unique
  if a preemption occurs; we rely on any existing tmp directory being
  deleted elsewhere.
  """
  # The diff artifact left a duplicate, unreachable return here; keep only
  # the final form, which avoids an unnecessary epath.Path re-wrap.
  return final_path.parent / (final_path.name + TMP_DIR_SUFFIX)


def _get_final_directory(tmp_path: epath.Path) -> epath.Path:
  """Returns the final path for `tmp_path` by stripping `TMP_DIR_SUFFIX`.

  Inverse of `_get_tmp_directory`.

  Raises:
    ValueError: If `tmp_path`'s name does not contain `TMP_DIR_SUFFIX`.
  """
  if (suffix_idx := tmp_path.name.find(TMP_DIR_SUFFIX)) == -1:
    raise ValueError(f'Expected {tmp_path} to end with "{TMP_DIR_SUFFIX}".')
  # Keep only the final form of the return (the diff artifact duplicated it);
  # `tmp_path.parent` is already an epath.Path, no re-wrap needed.
  return tmp_path.parent / tmp_path.name[:suffix_idx]


class TemporaryPathBase(atomicity_types.TemporaryPath):
Expand Down
104 changes: 38 additions & 66 deletions checkpoint/orbax/checkpoint/_src/path/deleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import threading
import time
from typing import Optional, Protocol, Sequence
from urllib import parse

from absl import logging
from etils import epath
Expand All @@ -31,6 +32,7 @@
from orbax.checkpoint._src.path import step as step_lib


urlparse = parse.urlparse
PurePosixPath = pathlib.PurePosixPath

_THREADED_DELETE_DURATION = (
Expand Down Expand Up @@ -183,7 +185,9 @@ def delete(self, step: int) -> None:
# Attempt to rename using GCS HNS API if configured.
if self._todelete_full_path is not None:
if gcs_utils.is_gcs_path(self._directory):
self._rename_gcs_step_with_hns(step, delete_target)
# This is recommended for GCS buckets with HNS enabled and requires
# `_todelete_full_path` to be specified.
self._gcs_rename_step(step, delete_target)
else:
raise NotImplementedError()
# Attempt to rename to local subdirectory using `todelete_subdir`
Expand All @@ -204,88 +208,56 @@ def delete(self, step: int) -> None:
time.time() - start,
)

def _gcs_rename_step(self, step: int, delete_target: epath.Path):
  """Renames a GCS step directory to a temporary location for deletion.

  This method renames the directory using the underlying
  `tf.io.gfile.rename` method (via `epath.Path.rename`). The underlying
  implementation will automatically detect if the bucket is HNS-enabled
  and use a fast atomic rename, or fall back to a legacy copy/delete
  rename if it is not.

  Args:
    step: The checkpoint step number.
    delete_target: The path to the directory to be renamed.

  Raises:
    RuntimeError: If the rename fails for any reason.
  """
  try:
    # Get the bucket name from the source path.
    bucket_name = urlparse(str(delete_target)).netloc
    if not bucket_name:
      raise ValueError(
          f'Could not parse bucket name from path: {delete_target}'
      )

    # Construct the destination path inside the `_todelete_full_path` dir.
    destination_parent_path = epath.Path(
        f'gs://{bucket_name}/{self._todelete_full_path}'
    )
    destination_parent_path.mkdir(parents=True, exist_ok=True)

    # Create a unique name for the destination to avoid collisions.
    now = datetime.datetime.now()
    timestamp_str = now.strftime('%Y%m%d-%H%M%S-%f')
    new_name_with_timestamp = f'{delete_target.name}-{timestamp_str}'
    dest_path = destination_parent_path / new_name_with_timestamp

    logging.info(
        'Executing filesystem-aware rename: Source=`%s`, Destination=`%s`',
        delete_target,
        dest_path,
    )
    # Call the high-level rename method.
    # This will be fast on HNS and slow (but functional) on non-HNS.
    delete_target.rename(dest_path)
    logging.info('Successfully renamed step %d to %s', step, dest_path)
  except Exception as e:
    message = f'Rename failed for step {step}. Error: {e}'
    logging.error(message)
    raise RuntimeError(message) from e

def _rename_step_to_subdir(self, step: int, delete_target: epath.Path):
"""Renames a step directory to its corresponding todelete_subdir."""
Expand Down
50 changes: 50 additions & 0 deletions checkpoint/orbax/checkpoint/_src/path/deleter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
# limitations under the License.

"""To test Orbax in single-host setup."""

import unittest
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
Expand Down Expand Up @@ -64,5 +68,51 @@ def test_checkpoint_deleter_delete(
deleter.close()


class GcsRenameTest(unittest.TestCase):
  """Unit tests for StandardCheckpointDeleter._gcs_rename_step."""

  @mock.patch('orbax.checkpoint._src.path.deleter.epath.Path')
  def test_gcs_rename_logic_directly(self, mock_epath_constructor):
    """Tests path construction and rename call logic."""
    checkpoint_deleter = deleter_lib.StandardCheckpointDeleter(
        directory=mock.MagicMock(),
        name_format=step_lib.standard_name_format(),
        primary_host=None,
        todelete_subdir=None,
        todelete_full_path='trash_bin',
        enable_hns=False,
    )

    # epath.Path(...) calls inside the code under test yield this parent mock.
    parent_mock = mock.MagicMock()
    mock_epath_constructor.return_value = parent_mock

    # (parent / child) yields this specific final-destination mock.
    final_dest_mock = mock.MagicMock()
    final_dest_mock.__str__.return_value = 'gs://mocked/final/destination'
    parent_mock.__truediv__.return_value = final_dest_mock

    # The source mock: the step directory being deleted.
    step_path_mock = mock.MagicMock()
    step_path_mock.__str__.return_value = 'gs://my-bucket/checkpoints/step_10'
    step_path_mock.name = 'step_10'

    checkpoint_deleter._gcs_rename_step(step=10, delete_target=step_path_mock)

    # mkdir must be invoked on the destination parent.
    parent_mock.mkdir.assert_called_with(parents=True, exist_ok=True)

    # The parent path string combines the source bucket and the trash dir:
    # epath.Path(f'gs://{bucket}/{todelete_full_path}').
    (parent_path_arg,), _ = mock_epath_constructor.call_args
    self.assertEqual(parent_path_arg, 'gs://my-bucket/trash_bin')

    # The child name embeds the step name plus a timestamp suffix.
    (child_name_arg,), _ = parent_mock.__truediv__.call_args
    self.assertIn('step_10-', child_name_arg)

    # The source must actually be renamed to the mocked destination.
    step_path_mock.rename.assert_called_with(final_dest_mock)

if __name__ == '__main__':
  # Run the tests in this module via absl's test runner.
  absltest.main()
8 changes: 7 additions & 1 deletion checkpoint/orbax/checkpoint/_src/path/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def get_bucket(bucket_name: str):

def is_hierarchical_namespace_enabled(path: 'epath.PathLike') -> bool:
  """Return whether hierarchical namespace is enabled for `path`'s bucket.

  Non-GCS paths (any scheme other than `gs`) can never have hierarchical
  namespace, so they short-circuit to False without contacting GCS.

  Args:
    path: A path, possibly a `gs://bucket/...` URI.

  Returns:
    True iff `path` is a GCS path whose bucket reports
    `hierarchical_namespace_enabled`.
  """
  parsed = parse.urlparse(str(path))
  if parsed.scheme != 'gs':
    return False
  bucket_name, _ = parse_gcs_path(path)
  bucket = get_bucket(bucket_name)
  # The diff artifact left the bare-attribute return above the guarded one;
  # keep only the final, hasattr-guarded form: older storage client versions
  # may lack the attribute entirely.
  return (
      hasattr(bucket, 'hierarchical_namespace_enabled')
      and bucket.hierarchical_namespace_enabled
  )
9 changes: 1 addition & 8 deletions checkpoint/orbax/checkpoint/_src/path/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,13 +321,11 @@ class _StandardNameFormat(NameFormat[Metadata]):
single_host_load_and_broadcast: If True, the jax process=0 will list all
steps and broadcast them to all other processes. NOTE: Ignored if jax
backend is not multi controller.
enable_hns: Enables HNS-specific path logic.
"""

step_prefix: Optional[str] = None
step_format_fixed_length: Optional[int] = None
single_host_load_and_broadcast: bool = False
enable_hns: bool = False

def __str__(self):
return f'StandardNameFormat("{self.build_name(1234)}")'
Expand Down Expand Up @@ -375,9 +373,7 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]:
"""Returns step paths under `base_path`."""
base_path = epath.Path(base_path)
# <step_prefix>_?<0 padding>?*
if self.enable_hns and gcs_utils.is_hierarchical_namespace_enabled(
base_path
):
if gcs_utils.is_hierarchical_namespace_enabled(base_path):
logging.vlog(
1,
'HNS enabled. Using GCS API to list step paths at %s',
Expand Down Expand Up @@ -560,7 +556,6 @@ def standard_name_format(
step_prefix: Optional[str] = None,
step_format_fixed_length: Optional[int] = None,
single_host_load_and_broadcast: bool = False,
enable_hns: bool = False,
) -> NameFormat[Metadata]:
"""Returns NameFormat for 'standard' steps for common Orbax use cases.

Expand All @@ -580,13 +575,11 @@ def standard_name_format(
single_host_load_and_broadcast: If True, the jax process=0 will list all
steps and broadcast them to all other processes. NOTE: Ignored if jax
backend is not multi controller.
enable_hns: Enables HNS-specific path logic.
"""
return _StandardNameFormat(
step_prefix=step_prefix,
step_format_fixed_length=step_format_fixed_length,
single_host_load_and_broadcast=single_host_load_and_broadcast,
enable_hns=enable_hns,
)


Expand Down
Loading
Loading