|
1 | 1 | """Unit tests for sky.server.requests.executor module.""" |
2 | 2 | import asyncio |
| 3 | +import concurrent.futures |
3 | 4 | import functools |
4 | 5 | import os |
| 6 | +import queue as queue_lib |
5 | 7 | import time |
6 | 8 | from typing import List |
7 | 9 | from unittest import mock |
|
10 | 12 |
|
11 | 13 | from sky import exceptions |
12 | 14 | from sky import skypilot_config |
| 15 | +from sky.server import config as server_config |
13 | 16 | from sky.server import constants as server_constants |
14 | 17 | from sky.server.requests import executor |
15 | 18 | from sky.server.requests import payloads |
@@ -452,6 +455,11 @@ def _keyboard_interrupt_entrypoint(): |
452 | 455 | raise KeyboardInterrupt() |
453 | 456 |
|
454 | 457 |
|
| 458 | +def _dummy_entrypoint_for_retry_test(): |
| 459 | + """Dummy entrypoint for retry test that can be pickled.""" |
| 460 | + return None |
| 461 | + |
| 462 | + |
455 | 463 | @pytest.mark.asyncio |
456 | 464 | @pytest.mark.parametrize('test_case', [ |
457 | 465 | pytest.param( |
@@ -518,3 +526,119 @@ async def test_stdout_stderr_restoration(mock_fd_operations, test_case): |
518 | 526 | # Verify no double-close |
519 | 527 | _assert_no_double_close(mock_fd_operations['close_calls'], |
520 | 528 | mock_fd_operations['created_fds']) |
| 529 | + |
| 530 | + |
@pytest.mark.asyncio
async def test_request_worker_retry_execution_retryable_error(
        isolated_database, monkeypatch):
    """Test that RequestWorker retries requests on ExecutionRetryableError.

    When a request future fails with ExecutionRetryableError, the worker must:
    1. wait ``retry_wait_seconds`` before rescheduling;
    2. reset the request status to PENDING;
    3. put the request back on the queue so process_request resubmits it.
    """
    # Create a request in the database.
    request_id = 'test-retry-request'
    request = requests_lib.Request(
        request_id=request_id,
        name='test-request',
        # Module-level so it can be pickled; not actually invoked in this test.
        entrypoint=_dummy_entrypoint_for_retry_test,
        request_body=payloads.RequestBody(),
        status=requests_lib.RequestStatus.RUNNING,
        created_at=time.time(),
        user_id='test-user',
    )
    await requests_lib.create_if_not_exists_async(request)

    # Wrap a real queue so we can record every item that gets (re-)queued.
    queue_items = []
    mock_queue = queue_lib.Queue()

    class MockRequestQueue:

        def __init__(self, queue):
            self.queue = queue

        def get(self):
            # Non-blocking get; None mirrors "queue empty" for the worker.
            try:
                return self.queue.get(block=False)
            except queue_lib.Empty:
                return None

        def put(self, item):
            queue_items.append(item)
            self.queue.put(item)

    request_queue = MockRequestQueue(mock_queue)

    # Make the worker see our instrumented queue for any schedule type.
    def mock_get_queue(schedule_type):
        return request_queue

    monkeypatch.setattr(executor, '_get_queue', mock_get_queue)

    # Record retry waits without actually sleeping.
    sleep_calls = []

    def mock_sleep(seconds):
        sleep_calls.append(seconds)

    monkeypatch.setattr('time.sleep', mock_sleep)

    # Mock executor that records submit_until_success calls and returns an
    # already-completed future.
    submit_calls = []

    class MockExecutor:

        def submit_until_success(self, fn, *args, **kwargs):
            submit_calls.append((fn, args, kwargs))
            fut = concurrent.futures.Future()
            fut.set_result(None)
            return fut

    mock_executor = MockExecutor()

    # Create a RequestWorker.
    # NOTE: 'garanteed_parallelism' matches the (misspelled) field name in
    # server_config.WorkerConfig — do not "correct" it here.
    worker = executor.RequestWorker(
        schedule_type=requests_lib.ScheduleType.LONG,
        config=server_config.WorkerConfig(garanteed_parallelism=1,
                                          burstable_parallelism=0,
                                          num_db_connections_per_worker=0))

    # A future that failed with ExecutionRetryableError.
    retryable_error = exceptions.ExecutionRetryableError(
        'Failed to provision all possible launchable resources.',
        hint='Retry after 30s',
        retry_wait_seconds=30)
    fut = concurrent.futures.Future()
    fut.set_exception(retryable_error)

    # (request_id, ignore_return_value, retryable)
    request_element = (request_id, False, True)

    # handle_task_result should catch the exception and reschedule.
    worker.handle_task_result(fut, request_element)

    # Verify the request was put back on the queue exactly once.
    # Interpolate the whole list (not queue_items[0]) so an empty list yields
    # the assertion message instead of an IndexError.
    assert queue_items == [request_element], (
        f'Expected {request_element} to be put on queue, got {queue_items}')

    # Verify the worker waited retry_wait_seconds before re-queuing.
    assert sleep_calls == [30], (
        f'Expected a single time.sleep(30) call, got {sleep_calls}')

    # Verify the request status was reset to PENDING.
    updated_request = requests_lib.get_request(request_id, fields=['status'])
    assert updated_request is not None
    assert updated_request.status == requests_lib.RequestStatus.PENDING, (
        f'Expected request status to be PENDING, got {updated_request.status}')

    # process_request should pick the re-queued request up from the queue and
    # call submit_until_success to resubmit it.
    worker.process_request(mock_executor, request_queue)

    assert len(submit_calls) == 1, (
        f'Expected submit_until_success to be called once, got '
        f'{len(submit_calls)} calls')
0 commit comments