Skip to content

Add pause_resume decorator to Manager. #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 22, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion pathwaysutils/elastic/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import collections
from collections.abc import Callable, Mapping, Sequence
import copy
import functools
import itertools
import logging
import time
Expand All @@ -44,7 +45,10 @@


class ElasticRuntimeError(RuntimeError):
"""Error raised when too many elastic down events or reshard retries occur."""
"""Error raised when elasticity cannot continue.

Some causes of this error are due to too many elastic down events or retries.
"""


class Manager:
Expand Down Expand Up @@ -684,3 +688,73 @@ def wait_for_slices(
)

return good_slice_indices

def pause_resume(
self,
max_retries: int,
wait_period: float | int = 10,
timeout: float | None = None,
) -> Any:
"""Retries a function with pause/resume fault tolerance.

This decorator wraps a function to automatically retry execution in case of
`jax.errors.JaxRuntimeError` caused by slice down events. It waits for
available slices before each attempt and cleans up JAX caches on failure.
The function will not be attempted (or reattempted) until all of the slices
are available.

Often, the function will dispatch JAX operations and wait for them to
complete while creating a log message. If using Python logging, it is
recommended to set `logging.raiseExceptions=True` to ensure that the
`jax.errors.JaxRuntimeError` is not silently ignored within the logging
call.

Args:
max_retries: The maximum number of times to retry the function.
wait_period: The number of seconds to wait between availability checks.
Defaults to 10 seconds.
timeout: The maximum number of seconds to wait for slices to become
available before each retry attempt. If None, there is no timeout.

Returns:
The result of the wrapped function.

Raises:
ElasticRuntimeError: If all retry attempts fail.
Exception: Any other exception raised by the wrapped function that is not
due to a slice down event.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
for retry_index in range(max_retries):
try:
_logger.info(
"Elastic attempt %d out of %d", retry_index + 1, max_retries
)

self.wait_for_slices(wait_period=wait_period, timeout=timeout)

return func(*args, **kwargs)
except jax.errors.JaxRuntimeError as error:
if not self.is_error_due_to_slice_down(error):
raise

try:
_logger.info("Cleaning up any ongoing traces")
jax.profiler.stop_trace()
except (RuntimeError, ValueError) as e:
_logger.info("No ongoing traces to clean up")
except Exception:
_logger.exception("Error cleaning up ongoing traces")
raise

jax.clear_caches()
for array in jax.live_arrays():
array.delete()
raise ElasticRuntimeError(
f"Elastic attempt {max_retries} out of {max_retries} failed."
)

return wrapper
return decorator