Skip to content

Commit f88c956

Browse files
lukebaumann and copybara-github
authored and committed
Add pause_resume decorator to Manager.
This change introduces a `pause_resume` decorator to the `Manager` class. This decorator wraps a function to automatically retry execution when a `jax.errors.JaxRuntimeError` occurs due to a slice down event. Before each attempt, it waits for all slices to be available and performs necessary cleanup of JAX caches and live arrays upon failure. PiperOrigin-RevId: 796970321
1 parent 7193f72 commit f88c956

File tree

1 file changed

+75
-1
lines changed

1 file changed

+75
-1
lines changed

pathwaysutils/elastic/manager.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import collections
2929
from collections.abc import Callable, Mapping, Sequence
3030
import copy
31+
import functools
3132
import itertools
3233
import logging
3334
import time
@@ -44,7 +45,10 @@
4445

4546

4647
class ElasticRuntimeError(RuntimeError):
  """Raised when elastic execution cannot make further progress.

  Typical causes include exceeding the allowed number of elastic down
  events or exhausting the configured retry budget.
  """
4852

4953

5054
class Manager:
@@ -684,3 +688,73 @@ def wait_for_slices(
684688
)
685689

686690
return good_slice_indices
691+
692+
def pause_resume(
    self,
    max_retries: int,
    wait_period: float | int = 10,
    timeout: float | None = None,
) -> Any:
  """Returns a decorator adding pause/resume fault tolerance to a function.

  The decorator wraps a function to automatically retry execution in case of
  `jax.errors.JaxRuntimeError` caused by slice down events. It waits for
  available slices before each attempt and cleans up JAX caches on failure.
  The function will not be attempted (or reattempted) until all of the slices
  are available.

  Often, the function will dispatch JAX operations and wait for them to
  complete while creating a log message. If using Python logging, it is
  recommended to set `logging.raiseExceptions=True` to ensure that the
  `jax.errors.JaxRuntimeError` is not silently ignored within the logging
  call.

  Args:
    max_retries: The maximum number of times to retry the function.
    wait_period: The number of seconds to wait between availability checks.
      Defaults to 10 seconds.
    timeout: The maximum number of seconds to wait for slices to become
      available before each retry attempt. If None, there is no timeout.

  Returns:
    A decorator; the wrapped function returns the wrapped function's result.

  Raises:
    ElasticRuntimeError: If all retry attempts fail. The last slice-down
      error is attached as `__cause__`.
    Exception: Any other exception raised by the wrapped function that is not
      due to a slice down event.
  """

  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      # Keep the most recent slice-down error so the terminal
      # ElasticRuntimeError can be chained to its cause; the `except`-scoped
      # name is deleted when the handler exits, so it must be captured here.
      last_error = None
      for retry_index in range(max_retries):
        try:
          _logger.info(
              "Elastic attempt %d out of %d", retry_index + 1, max_retries
          )

          # Block until every slice is healthy before (re)dispatching work.
          self.wait_for_slices(wait_period=wait_period, timeout=timeout)

          return func(*args, **kwargs)
        except jax.errors.JaxRuntimeError as error:
          # Only slice-down events are retryable; anything else propagates.
          if not self.is_error_due_to_slice_down(error):
            raise
          last_error = error

          try:
            _logger.info("Cleaning up any ongoing traces")
            jax.profiler.stop_trace()
          except (RuntimeError, ValueError):
            # No profiler trace was active; nothing to clean up.
            _logger.info("No ongoing traces to clean up")
          except Exception:
            _logger.exception("Error cleaning up ongoing traces")
            raise

          # Drop stale compiled computations and device buffers so the next
          # attempt starts from a clean slate on the surviving slices.
          jax.clear_caches()
          for array in jax.live_arrays():
            array.delete()

      raise ElasticRuntimeError(
          f"Elastic attempt {max_retries} out of {max_retries} failed."
      ) from last_error

    return wrapper

  return decorator

0 commit comments

Comments
 (0)