|
28 | 28 | import collections
|
29 | 29 | from collections.abc import Callable, Mapping, Sequence
|
30 | 30 | import copy
|
| 31 | +import functools |
31 | 32 | import itertools
|
32 | 33 | import logging
|
33 | 34 | import time
|
|
44 | 45 |
|
45 | 46 |
|
class ElasticRuntimeError(RuntimeError):
    """Raised when elastic execution is unable to continue.

    Typical causes include too many elastic down events or too many retries.
    """
48 | 52 |
|
49 | 53 |
|
50 | 54 | class Manager:
|
@@ -718,3 +722,68 @@ def wait_for_slices(
|
718 | 722 | )
|
719 | 723 |
|
720 | 724 | return good_slice_indices
|
| 725 | + |
def elastic_retry(
    self,
    max_retries: int,
    wait_period: float | int = 10,
    timeout: float | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Returns a decorator that retries a function with elastic fault tolerance.

    The decorated function is re-run when it raises a
    `jax.errors.JaxRuntimeError` that `is_error_due_to_slice_down` attributes
    to a slice down event. `wait_for_slices` is called before every attempt,
    so the function is not attempted until the slices are available; this
    negates some of the benefits of late-binding. After each slice-down
    failure any ongoing profiler trace is stopped, the JAX caches are cleared
    and every live array is deleted.

    Example:
        @manager.elastic_retry(max_retries=3)
        def train_loop(...):
            ...

    Args:
        max_retries: The maximum number of attempts, counting the first one.
            If it is zero or negative the decorated function is never called
            and `ElasticRuntimeError` is raised on the first call.
        wait_period: The number of seconds to wait between availability
            checks. Defaults to 10 seconds.
        timeout: The maximum number of seconds to wait for slices to become
            available before each attempt. If None, there is no timeout.

    Returns:
        A decorator. Calling the decorated function returns the result of the
        first attempt of the wrapped function that succeeds.

    Raises:
        ElasticRuntimeError: Raised by the decorated function if every
            attempt fails because of a slice down event. The last slice-down
            error is attached as the cause.
        Exception: Any other exception raised by the wrapped function, by
            `wait_for_slices`, or by the cleanup is propagated unchanged.
    """

    def cleanup_after_slice_down() -> None:
        # Stops any ongoing profiler trace, then clears the JAX caches and
        # deletes every live array.
        _logger.info("Cleaning up any ongoing traces")
        try:
            jax.profiler.stop_trace()
        except (RuntimeError, ValueError):
            # Treated as "no trace was running"; nothing to clean up.
            _logger.info("No ongoing traces to clean up")
        except Exception:
            _logger.exception("Error cleaning up ongoing traces")
            raise

        jax.clear_caches()
        for array in jax.live_arrays():
            array.delete()

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            failure_message = (
                f"Elastic attempt {max_retries} out of {max_retries} failed."
            )

            for retry_index in range(max_retries):
                _logger.info(
                    "Elastic attempt %d out of %d", retry_index + 1, max_retries
                )
                try:
                    self.wait_for_slices(
                        wait_period=wait_period, timeout=timeout
                    )
                    return func(*args, **kwargs)
                except jax.errors.JaxRuntimeError as error:
                    if not self.is_error_due_to_slice_down(error):
                        raise

                    # Clean up after every slice-down failure, including the
                    # last one, so no stale JAX state is left behind.
                    cleanup_after_slice_down()

                    if retry_index + 1 >= max_retries:
                        # Chain the triggering error so the root cause stays
                        # visible in the traceback.
                        raise ElasticRuntimeError(failure_message) from error

            # Only reachable when max_retries <= 0 (the loop body never ran).
            raise ElasticRuntimeError(failure_message)

        return wrapper

    return decorator
0 commit comments