7 changes: 7 additions & 0 deletions sky/server/requests/executor.py
@@ -270,9 +270,16 @@ def handle_task_result(self, fut: concurrent.futures.Future,
queue.put(request_element)
except exceptions.ExecutionRetryableError as e:
time.sleep(e.retry_wait_seconds)
# Reset the request status to PENDING so it can be picked up again.
# Assume retryable since the error is ExecutionRetryableError.
request_id, _, _ = request_element
with api_requests.update_request(request_id) as request_task:
assert request_task is not None, request_id
request_task.status = api_requests.RequestStatus.PENDING
# Reschedule the request.
queue = _get_queue(self.schedule_type)
queue.put(request_element)
logger.info(f'Rescheduled request {request_id} for retry')
finally:
# Increment the free executor count when a request finishes
if metrics_utils.METRICS_ENABLED:
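For readers skimming the diff: on ExecutionRetryableError the worker now waits the error's retry_wait_seconds, flips the request back to PENDING in the requests table, and puts the same element back on the scheduling queue so a worker can pick it up again. A minimal, self-contained sketch of that control flow, using simplified stand-ins rather than SkyPilot's actual Request and queue classes:

import queue
import time


class ExecutionRetryableError(Exception):
    """Simplified stand-in for sky.exceptions.ExecutionRetryableError."""

    def __init__(self, message: str, retry_wait_seconds: float):
        super().__init__(message)
        self.retry_wait_seconds = retry_wait_seconds


def handle_retryable_failure(error, request_element, request_queue, statuses):
    """Mirror of the new except-branch: wait, reset to PENDING, re-enqueue."""
    time.sleep(error.retry_wait_seconds)
    request_id, _, _ = request_element
    statuses[request_id] = 'PENDING'    # so a worker can pick it up again
    request_queue.put(request_element)  # reschedule the same element


# Usage: one retryable failure puts the request back on the queue as PENDING.
q = queue.Queue()
statuses = {'req-1': 'RUNNING'}
handle_retryable_failure(
    ExecutionRetryableError('no capacity in any zone', retry_wait_seconds=0.01),
    ('req-1', False, True),  # (request_id, ignore_return_value, retryable)
    q, statuses)
assert statuses['req-1'] == 'PENDING' and q.qsize() == 1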
24 changes: 24 additions & 0 deletions tests/smoke_tests/test_basic.py
@@ -1775,3 +1775,27 @@ def test_cluster_setup_num_gpus():
teardown=f'sky down -y {name}',
)
smoke_tests_utils.run_one_test(test)


@pytest.mark.aws
def test_launch_retry_until_up():
"""Test that retry until up considers more resources after trying all zones."""
cluster_name = smoke_tests_utils.get_cluster_name()
timeout = 180
test = smoke_tests_utils.Test(
'launch-retry-until-up',
[
# Launch something we'll never get.
f's=$(timeout {timeout} sky launch -c {cluster_name} --gpus B200:8 --infra aws echo hi -y -d --retry-until-up --use-spot 2>&1 || true) && '
Collaborator: I'm wondering why this works; if we will never get B200:8, wouldn't we block on this command?

Collaborator (Author): The command will time out and we will get the logs back; we just parse them to figure out the outcome!

Collaborator: Ah, I see, love this idea!

# Check that "Retry after" appears in the output
'echo "$s" | grep -q "Retry after" && '
# Find the first occurrence of "Retry after" and get its line number
'RETRY_LINE=$(echo "$s" | grep -n "Retry after" | head -1 | cut -d: -f1) && '
# Check that "Considered resources" appears after the first "Retry after"
# We do this by extracting all lines after RETRY_LINE and checking if "Considered resources" appears
'echo "$s" | tail -n +$((RETRY_LINE + 1)) | grep -q "Considered resources"'
],
timeout=200, # Slightly more than 180 to account for test overhead
teardown=f'sky down -y {cluster_name}',
)
smoke_tests_utils.run_one_test(test)
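As the review thread above explains, the launch in this test is never expected to succeed: timeout kills sky launch after 180 seconds, "|| true" discards the non-zero exit status, and the assertions run purely against the captured log text. A generic, self-contained sketch of that capture-under-timeout-then-grep pattern (plain bash driven from Python; the echoed message is a placeholder standing in for the real launch output):

import subprocess

# Run a command that would never finish on its own, bound it with `timeout`,
# keep whatever it printed, and assert on the captured text
# (requires bash and coreutils' timeout).
cmd = ('s=$(timeout 2 bash -c "echo Retry after 30s; sleep 999" 2>&1 || true) && '
       'echo "$s" | grep -q "Retry after"')
result = subprocess.run(['bash', '-c', cmd], check=False)
assert result.returncode == 0  # the log line was captured before the timeout hit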
124 changes: 124 additions & 0 deletions tests/unit_tests/test_sky/server/requests/test_executor.py
@@ -1,7 +1,9 @@
"""Unit tests for sky.server.requests.executor module."""
import asyncio
import concurrent.futures
import functools
import os
import queue as queue_lib
import time
from typing import List
from unittest import mock
@@ -10,6 +12,7 @@

from sky import exceptions
from sky import skypilot_config
from sky.server import config as server_config
from sky.server import constants as server_constants
from sky.server.requests import executor
from sky.server.requests import payloads
@@ -452,6 +455,11 @@ def _keyboard_interrupt_entrypoint():
raise KeyboardInterrupt()


def _dummy_entrypoint_for_retry_test():
"""Dummy entrypoint for retry test that can be pickled."""
return None


@pytest.mark.asyncio
@pytest.mark.parametrize('test_case', [
pytest.param(
@@ -518,3 +526,119 @@ async def test_stdout_stderr_restoration(mock_fd_operations, test_case):
# Verify no double-close
_assert_no_double_close(mock_fd_operations['close_calls'],
mock_fd_operations['created_fds'])


@pytest.mark.asyncio
async def test_request_worker_retry_execution_retryable_error(
isolated_database, monkeypatch):
"""Test that RequestWorker retries requests when ExecutionRetryableError is raised."""
# Create a request in the database
request_id = 'test-retry-request'
request = requests_lib.Request(
request_id=request_id,
name='test-request',
entrypoint=
_dummy_entrypoint_for_retry_test, # Won't be called in this test
request_body=payloads.RequestBody(),
status=requests_lib.RequestStatus.RUNNING,
created_at=time.time(),
user_id='test-user',
)
await requests_lib.create_if_not_exists_async(request)

# Create a mock queue that tracks puts
queue_items = []
mock_queue = queue_lib.Queue()

class MockRequestQueue:

def __init__(self, queue):
self.queue = queue

def get(self):
try:
return self.queue.get(block=False)
except queue_lib.Empty:
return None

def put(self, item):
queue_items.append(item)
self.queue.put(item)

request_queue = MockRequestQueue(mock_queue)

# Mock _get_queue to return our mock queue
def mock_get_queue(schedule_type):
return request_queue

monkeypatch.setattr(executor, '_get_queue', mock_get_queue)

# Mock time.sleep to track calls (but still sleep for very short waits)
sleep_calls = []

def mock_sleep(seconds):
sleep_calls.append(seconds)

monkeypatch.setattr('time.sleep', mock_sleep)

# Create a mock executor that tracks submit_until_success calls
submit_calls = []

class MockExecutor:

def submit_until_success(self, fn, *args, **kwargs):
submit_calls.append((fn, args, kwargs))
# Return a future that immediately completes (does nothing)
fut = concurrent.futures.Future()
fut.set_result(None)
return fut

mock_executor = MockExecutor()

# Create a RequestWorker
worker = executor.RequestWorker(
schedule_type=requests_lib.ScheduleType.LONG,
config=server_config.WorkerConfig(garanteed_parallelism=1,
burstable_parallelism=0,
num_db_connections_per_worker=0))

# Create a future that raises ExecutionRetryableError
retryable_error = exceptions.ExecutionRetryableError(
'Failed to provision all possible launchable resources.',
hint='Retry after 30s',
retry_wait_seconds=30)
fut = concurrent.futures.Future()
fut.set_exception(retryable_error)

# Create request_element tuple
request_element = (request_id, False, True
) # (request_id, ignore_return_value, retryable)

# Call handle_task_result - this should catch the exception and reschedule
worker.handle_task_result(fut, request_element)

# Verify the request was put back on the queue
assert queue_items == [
    request_element
], (f'Expected {request_element} to be put on queue, got {queue_items}')

# Verify time.sleep was called once with the retry wait time (30 seconds)
assert sleep_calls == [
    30
], (f'Expected a single time.sleep call of 30 seconds, got {sleep_calls}')

# Verify the request status was reset to PENDING
updated_request = requests_lib.get_request(request_id, fields=['status'])
assert updated_request is not None
assert updated_request.status == requests_lib.RequestStatus.PENDING, (
f'Expected request status to be PENDING, got {updated_request.status}')

# Call process_request - it should pick up the request from the queue
# and call submit_until_success
worker.process_request(mock_executor, request_queue)

# Verify submit_until_success was called
assert len(submit_calls) == 1, (
f'Expected submit_until_success to be called once, got {len(submit_calls)} calls'
)