Skip to content

Commit 021ae3d

Browse files
author
Orbax Authors
committed
This PR decouples Orbax from HNS rename API dependencies by delegating the filesystem operations to the underlying TensorFlow gfile implementation. This change improves testability and makes Orbax less susceptible to regressions from this api client changes.
PiperOrigin-RevId: 833996717
1 parent 79b3896 commit 021ae3d

File tree

3 files changed

+37
-77
lines changed

3 files changed

+37
-77
lines changed

checkpoint/orbax/checkpoint/_src/path/deleter.py

Lines changed: 36 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import threading
2222
import time
2323
from typing import Optional, Protocol, Sequence
24+
from urllib import parse
2425

2526
from absl import logging
2627
from etils import epath
@@ -31,6 +32,7 @@
3132
from orbax.checkpoint._src.path import step as step_lib
3233

3334

35+
urlparse = parse.urlparse
3436
PurePosixPath = pathlib.PurePosixPath
3537

3638
_THREADED_DELETE_DURATION = (
@@ -183,7 +185,8 @@ def delete(self, step: int) -> None:
183185
# Attempt to rename using GCS HNS API if configured.
184186
if self._todelete_full_path is not None:
185187
if gcs_utils.is_gcs_path(self._directory):
186-
self._rename_gcs_step_with_hns(step, delete_target)
188+
# This is recommended for GCS buckets with HNS enabled.
189+
self._gcs_rename_step(step, delete_target)
187190
else:
188191
raise NotImplementedError()
189192
# Attempt to rename to local subdirectory using `todelete_subdir`
@@ -204,88 +207,55 @@ def delete(self, step: int) -> None:
204207
time.time() - start,
205208
)
206209

207-
def _rename_gcs_step_with_hns(
210+
def _gcs_rename_step(
208211
self, step: int, delete_target: epath.Path
209212
):
210-
"""Renames a GCS directory using the Storage Control API.
213+
"""Renames a GCS directory to a temporary location for deletion.
214+
215+
This method renames the directory using the
216+
underlying `tf.io.gfile.rename` method. This underlying
217+
implementation will automatically detect if the bucket is HNS-enabled
218+
and use a fast atomic rename, or fall back to a legacy
219+
copy/delete rename if it is not.
211220
212221
Args:
213222
step: The checkpoint step number.
214223
delete_target: The path to the directory to be renamed.
215-
216-
Raises:
217-
ValueError: If the GCS bucket is not HNS-enabled, as this is a
218-
hard requirement for this operation.
219224
"""
220-
logging.info(
221-
'Condition: GCS path with `todelete_full_path` set. Checking for HNS.'
222-
)
223-
bucket_name, _ = gcs_utils.parse_gcs_path(self._directory)
224-
if not gcs_utils.is_hierarchical_namespace_enabled(self._directory):
225-
raise ValueError(
226-
f'Bucket "{bucket_name}" does not have Hierarchical Namespace'
227-
' enabled, which is required when _todelete_full_path is set.'
228-
)
229-
230-
logging.info('HNS bucket detected. Attempting to rename step %d.', step)
231-
# pylint: disable=g-import-not-at-top
232-
from google.api_core import exceptions as google_exceptions # pytype: disable=import-error
233225
try:
234-
from google.cloud import storage_control_v2 # pytype: disable=import-error
235-
import google.auth # pytype: disable=import-error
236-
237-
# Use default credentials, but without a quota project to avoid
238-
# quota issues with this API.
239-
credentials, _ = google.auth.default()
240-
creds_without_quota_project = credentials.with_quota_project(None)
241-
client = storage_control_v2.StorageControlClient(
242-
credentials=creds_without_quota_project
243-
)
244-
# Destination parent is the absolute path to the bucket.
245-
destination_parent_dir_str = (
226+
# Get the bucket name from the source path
227+
bucket_name = urlparse(str(delete_target)).netloc
228+
if not bucket_name:
229+
raise ValueError(
230+
f'Could not parse bucket name from path: {delete_target}'
231+
)
232+
233+
# Construct the destination path inside the `_todelete_full_path` dir.
234+
destination_parent_path = epath.Path(
246235
f'gs://{bucket_name}/{self._todelete_full_path}'
247236
)
248-
destination_parent_path = PurePosixPath(destination_parent_dir_str)
249-
logging.info(
250-
'Ensuring destination parent folder exists via HNS API: %s',
251-
destination_parent_dir_str,
252-
)
253-
try:
254-
parent_folder_id = str(
255-
destination_parent_path.relative_to(f'gs://{bucket_name}')
256-
)
257-
bucket_resource_name = f'projects/_/buckets/{bucket_name}'
258-
client.create_folder(
259-
request=storage_control_v2.CreateFolderRequest(
260-
parent=bucket_resource_name,
261-
folder_id=parent_folder_id,
262-
recursive=True,
263-
)
264-
)
265-
logging.info('HNS parent folder creation request sent.')
266-
except google_exceptions.AlreadyExists:
267-
logging.info('HNS parent folder already exists, proceeding.')
237+
destination_parent_path.mkdir(parents=True, exist_ok=True)
268238

239+
# Create a unique name for the destination to avoid collisions.
269240
now = datetime.datetime.now()
270241
timestamp_str = now.strftime('%Y%m%d-%H%M%S-%f')
271242
new_name_with_timestamp = f'{delete_target.name}-{timestamp_str}'
272243
dest_path = destination_parent_path / new_name_with_timestamp
273-
source_folder_id = str(delete_target.relative_to(f'gs://{bucket_name}'))
274-
destination_folder_id = str(dest_path.relative_to(f'gs://{bucket_name}'))
275-
source_resource_name = (
276-
f'projects/_/buckets/{bucket_name}/folders/{source_folder_id}'
277-
)
278-
logging.info('Rename API call: Source: %s', source_resource_name)
279-
logging.info('Rename API call: Destination ID: %s', destination_folder_id)
280-
request = storage_control_v2.RenameFolderRequest(
281-
name=source_resource_name,
282-
destination_folder_id=destination_folder_id,
244+
245+
logging.info(
246+
'Executing filesystem-aware rename: Source=`%s`, Destination=`%s`',
247+
delete_target,
248+
dest_path,
283249
)
284-
op = client.rename_folder(request=request)
285-
op.result()
250+
251+
# Call the high-level rename method.
252+
# This will be fast on HNS and slow (but functional) on non-HNS.
253+
delete_target.rename(dest_path)
286254
logging.info('Successfully renamed step %d to %s', step, dest_path)
287-
except google_exceptions.GoogleAPIError as e:
288-
logging.error('HNS rename failed for step %d. Error: %s', step, e)
255+
256+
except Exception as e:
257+
logging.error('Rename failed for step %d. Error: %s', step, e)
258+
raise
289259

290260
def _rename_step_to_subdir(self, step: int, delete_target: epath.Path):
291261
"""Renames a step directory to its corresponding todelete_subdir."""

checkpoint/orbax/checkpoint/_src/path/step.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,11 @@ class _StandardNameFormat(NameFormat[Metadata]):
321321
single_host_load_and_broadcast: If True, the jax process=0 will list all
322322
steps and broadcast them to all other processes. NOTE: Ignored if jax
323323
backend is not multi controller.
324-
enable_hns: Enables HNS-specific path logic.
325324
"""
326325

327326
step_prefix: Optional[str] = None
328327
step_format_fixed_length: Optional[int] = None
329328
single_host_load_and_broadcast: bool = False
330-
enable_hns: bool = False
331329

332330
def __str__(self):
333331
return f'StandardNameFormat("{self.build_name(1234)}")'
@@ -375,7 +373,7 @@ def _glob_step_paths(self, base_path: epath.PathLike) -> list[epath.Path]:
375373
"""Returns step paths under `base_path`."""
376374
base_path = epath.Path(base_path)
377375
# <step_prefix>_?<0 padding>?*
378-
if self.enable_hns and gcs_utils.is_hierarchical_namespace_enabled(
376+
if gcs_utils.is_hierarchical_namespace_enabled(
379377
base_path
380378
):
381379
logging.vlog(
@@ -560,7 +558,6 @@ def standard_name_format(
560558
step_prefix: Optional[str] = None,
561559
step_format_fixed_length: Optional[int] = None,
562560
single_host_load_and_broadcast: bool = False,
563-
enable_hns: bool = False,
564561
) -> NameFormat[Metadata]:
565562
"""Returns NameFormat for 'standard' steps for common Orbax use cases.
566563
@@ -580,13 +577,11 @@ def standard_name_format(
580577
single_host_load_and_broadcast: If True, the jax process=0 will list all
581578
steps and broadcast them to all other processes. NOTE: Ignored if jax
582579
backend is not multi controller.
583-
enable_hns: Enables HNS-specific path logic.
584580
"""
585581
return _StandardNameFormat(
586582
step_prefix=step_prefix,
587583
step_format_fixed_length=step_format_fixed_length,
588584
single_host_load_and_broadcast=single_host_load_and_broadcast,
589-
enable_hns=enable_hns,
590585
)
591586

592587

checkpoint/orbax/checkpoint/checkpoint_manager.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,6 @@ class CheckpointManagerOptions:
318318
gs://my-bucket/trash/<step_id>. Useful when direct deletion is time
319319
consuming. It gathers all deleted items in a centralized path for
320320
future cleanup.
321-
enable_hns: If True, enables HNS-specific path manipulation logic.
322-
Experimental feature.
323321
enable_background_delete: If True, old checkpoint deletions will be done in a
324322
background thread, otherwise, it will be done at the end of each save. When
325323
it's enabled, make sure to call CheckpointManager.close() or use context to
@@ -386,7 +384,6 @@ class CheckpointManagerOptions:
386384
single_host_load_and_broadcast: bool = False
387385
todelete_subdir: Optional[str] = None
388386
todelete_full_path: Optional[str] = None
389-
enable_hns: bool = False
390387
enable_background_delete: bool = False
391388
read_only: bool = False
392389
enable_async_checkpointing: bool = True
@@ -874,7 +871,6 @@ def __init__(
874871
single_host_load_and_broadcast=(
875872
self._options.single_host_load_and_broadcast
876873
),
877-
enable_hns=self._options.enable_hns,
878874
)
879875
)
880876

@@ -905,7 +901,6 @@ def __init__(
905901
primary_host=self._multiprocessing_options.primary_host,
906902
todelete_subdir=self._options.todelete_subdir,
907903
todelete_full_path=self._options.todelete_full_path,
908-
enable_hns=self._options.enable_hns,
909904
enable_background_delete=self._options.enable_background_delete,
910905
)
911906
)

0 commit comments

Comments
 (0)