Skip to content

Commit 4d9f40d

Browse files
lukebaumann authored; copybara-github committed
Add elastic_retry decorator to Manager.
This change introduces an `elastic_retry` decorator to the `Manager` class. This decorator wraps a function to automatically retry execution when a `jax.errors.JaxRuntimeError` occurs due to a slice down event. Before each attempt, it waits for all slices to be available, and upon failure it performs the necessary cleanup of JAX caches and live arrays.

PiperOrigin-RevId: 796970321
1 parent 5756f63 commit 4d9f40d

File tree

1 file changed

+70
-1
lines changed

1 file changed

+70
-1
lines changed

pathwaysutils/elastic/manager.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import collections
2929
from collections.abc import Callable, Mapping, Sequence
3030
import copy
31+
import functools
3132
import itertools
3233
import logging
3334
import time
@@ -44,7 +45,10 @@
4445

4546

4647
class ElasticRuntimeError(RuntimeError):
  """Error raised when elasticity cannot continue.

  Some causes of this error are due to too many elastic down events or retries.
  """
4852

4953

5054
class Manager:
@@ -718,3 +722,68 @@ def wait_for_slices(
718722
)
719723

720724
return good_slice_indices
725+
726+
def elastic_retry(
    self,
    max_retries: int,
    wait_period: float | int = 10,
    timeout: float | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
  """Returns a decorator that retries a function with elastic fault tolerance.

  The decorator wraps a function so that it is automatically re-executed
  when a `jax.errors.JaxRuntimeError` caused by a slice down event occurs.
  Before each attempt, it waits until all of the slices are available
  (which negates some of the benefits of late-binding) and, after a
  slice-down failure, cleans up any ongoing profiler trace, JAX caches,
  and live arrays.

  Args:
    max_retries: The maximum number of times to attempt the function.
    wait_period: The number of seconds to wait between availability checks.
      Defaults to 10 seconds.
    timeout: The maximum number of seconds to wait for slices to become
      available before each retry attempt. If None, there is no timeout.

  Returns:
    A decorator; the wrapped function takes the same arguments as the
    original function and returns its result.

  Raises:
    ElasticRuntimeError: If all retry attempts fail due to slice down
      events. The last slice-down error is attached as the cause.
    Exception: Any other exception raised by the wrapped function that is
      not due to a slice down event is propagated unchanged.
  """

  def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      # Remember the most recent slice-down error so the final
      # ElasticRuntimeError can be chained to it for debuggability.
      last_error: jax.errors.JaxRuntimeError | None = None
      for retry_index in range(max_retries):
        try:
          _logger.info(
              "Elastic attempt %d out of %d", retry_index + 1, max_retries
          )

          self.wait_for_slices(wait_period=wait_period, timeout=timeout)

          return func(*args, **kwargs)
        except jax.errors.JaxRuntimeError as error:
          # Only slice-down errors are retryable; any other runtime error
          # is a genuine failure and must propagate to the caller.
          if not self.is_error_due_to_slice_down(error):
            raise
          last_error = error

          try:
            _logger.info("Cleaning up any ongoing traces")
            jax.profiler.stop_trace()
          except (RuntimeError, ValueError):
            # stop_trace raises when no trace is active; that is benign.
            _logger.info("No ongoing traces to clean up")
          except Exception:
            _logger.exception("Error cleaning up ongoing traces")
            raise

          # Drop stale compiled computations and device buffers that may
          # still reference the lost slice.
          jax.clear_caches()
          for array in jax.live_arrays():
            array.delete()
      raise ElasticRuntimeError(
          f"Elastic attempt {max_retries} out of {max_retries} failed."
      ) from last_error

    return wrapper

  return decorator

0 commit comments

Comments
 (0)