|
28 | 28 | import collections
|
29 | 29 | from collections.abc import Callable, Mapping, Sequence
|
30 | 30 | import copy
|
| 31 | +import functools |
31 | 32 | import itertools
|
32 | 33 | import logging
|
33 | 34 | import time
|
|
44 | 45 |
|
45 | 46 |
|
class ElasticRuntimeError(RuntimeError):
  """Error raised when elasticity cannot continue.

  Typical causes are too many elastic down events or too many retries.
  """
48 | 52 |
|
49 | 53 |
|
50 | 54 | class Manager:
|
@@ -684,3 +688,73 @@ def wait_for_slices(
|
684 | 688 | )
|
685 | 689 |
|
686 | 690 | return good_slice_indices
|
| 691 | + |
| 692 | + def pause_resume( |
| 693 | + self, |
| 694 | + max_retries: int, |
| 695 | + wait_period: float | int = 10, |
| 696 | + timeout: float | None = None, |
| 697 | + ) -> Any: |
| 698 | + """Retries a function with pause/resume fault tolerance. |
| 699 | +
|
| 700 | + This decorator wraps a function to automatically retry execution in case of |
| 701 | + `jax.errors.JaxRuntimeError` caused by slice down events. It waits for |
| 702 | + available slices before each attempt and cleans up JAX caches on failure. |
| 703 | + The function will not be attempted (or reattempted) until all of the slices |
| 704 | + are available. |
| 705 | +
|
| 706 | + Often, the function will dispatch JAX operations and wait for them to |
| 707 | + complete while creating a log message. If using Python logging, it is |
| 708 | + recommended to set `logging.raiseExceptions=True` to ensure that the |
| 709 | + `jax.errors.JaxRuntimeError` is not silently ignored within the logging |
| 710 | + call. |
| 711 | +
|
| 712 | + Args: |
| 713 | + max_retries: The maximum number of times to retry the function. |
| 714 | + wait_period: The number of seconds to wait between availability checks. |
| 715 | + Defaults to 10 seconds. |
| 716 | + timeout: The maximum number of seconds to wait for slices to become |
| 717 | + available before each retry attempt. If None, there is no timeout. |
| 718 | +
|
| 719 | + Returns: |
| 720 | + The result of the wrapped function. |
| 721 | +
|
| 722 | + Raises: |
| 723 | + ElasticRuntimeError: If all retry attempts fail. |
| 724 | + Exception: Any other exception raised by the wrapped function that is not |
| 725 | + due to a slice down event. |
| 726 | + """ |
| 727 | + def decorator(func): |
| 728 | + @functools.wraps(func) |
| 729 | + def wrapper(*args, **kwargs): |
| 730 | + for retry_index in range(max_retries): |
| 731 | + try: |
| 732 | + _logger.info( |
| 733 | + "Elastic attempt %d out of %d", retry_index + 1, max_retries |
| 734 | + ) |
| 735 | + |
| 736 | + self.wait_for_slices(wait_period=wait_period, timeout=timeout) |
| 737 | + |
| 738 | + return func(*args, **kwargs) |
| 739 | + except jax.errors.JaxRuntimeError as error: |
| 740 | + if not self.is_error_due_to_slice_down(error): |
| 741 | + raise |
| 742 | + |
| 743 | + try: |
| 744 | + _logger.info("Cleaning up any ongoing traces") |
| 745 | + jax.profiler.stop_trace() |
| 746 | + except (RuntimeError, ValueError) as e: |
| 747 | + _logger.info("No ongoing traces to clean up") |
| 748 | + except Exception: |
| 749 | + _logger.exception("Error cleaning up ongoing traces") |
| 750 | + raise |
| 751 | + |
| 752 | + jax.clear_caches() |
| 753 | + for array in jax.live_arrays(): |
| 754 | + array.delete() |
| 755 | + raise ElasticRuntimeError( |
| 756 | + f"Elastic attempt {max_retries} out of {max_retries} failed." |
| 757 | + ) |
| 758 | + |
| 759 | + return wrapper |
| 760 | + return decorator |
0 commit comments