diff --git a/pathwaysutils/elastic/manager.py b/pathwaysutils/elastic/manager.py index 642dcb2..5d3f9d9 100644 --- a/pathwaysutils/elastic/manager.py +++ b/pathwaysutils/elastic/manager.py @@ -28,6 +28,7 @@ import collections from collections.abc import Callable, Mapping, Sequence import copy +import functools import itertools import logging import time @@ -44,7 +45,10 @@ class ElasticRuntimeError(RuntimeError): - """Error raised when too many elastic down events or reshard retries occur.""" + """Error raised when elasticity cannot continue. + + Some causes of this error are due to too many elastic down events or retries. + """ class Manager: @@ -684,3 +688,73 @@ def wait_for_slices( ) return good_slice_indices + + def pause_resume( + self, + max_retries: int, + wait_period: float | int = 10, + timeout: float | None = None, + ) -> Any: + """Retries a function with pause/resume fault tolerance. + + This decorator wraps a function to automatically retry execution in case of + `jax.errors.JaxRuntimeError` caused by slice down events. It waits for + available slices before each attempt and cleans up JAX caches on failure. + The function will not be attempted (or reattempted) until all of the slices + are available. + + Often, the function will dispatch JAX operations and wait for them to + complete while creating a log message. If using Python logging, it is + recommended to set `logging.raiseExceptions=True` to ensure that the + `jax.errors.JaxRuntimeError` is not silently ignored within the logging + call. + + Args: + max_retries: The maximum number of times to retry the function. + wait_period: The number of seconds to wait between availability checks. + Defaults to 10 seconds. + timeout: The maximum number of seconds to wait for slices to become + available before each retry attempt. If None, there is no timeout. + + Returns: + The result of the wrapped function. + + Raises: + ElasticRuntimeError: If all retry attempts fail. + Exception: Any other exception raised by the wrapped function that is not + due to a slice down event. + """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + for retry_index in range(max_retries): + try: + _logger.info( + "Elastic attempt %d out of %d", retry_index + 1, max_retries + ) + + self.wait_for_slices(wait_period=wait_period, timeout=timeout) + + return func(*args, **kwargs) + except jax.errors.JaxRuntimeError as error: + if not self.is_error_due_to_slice_down(error): + raise + + try: + _logger.info("Cleaning up any ongoing traces") + jax.profiler.stop_trace() + except (RuntimeError, ValueError) as e: + _logger.info("No ongoing traces to clean up") + except Exception: + _logger.exception("Error cleaning up ongoing traces") + raise + + jax.clear_caches() + for array in jax.live_arrays(): + array.delete() + raise ElasticRuntimeError( + f"Elastic attempt {max_retries} out of {max_retries} failed." + ) + + return wrapper + return decorator