@@ -7,10 +7,11 @@
 
 import pytest
 
-from crawlee import ConcurrencySettings, Request
+from crawlee import ConcurrencySettings, Request, RequestState
 from crawlee.crawlers import HttpCrawler
 from crawlee.sessions import SessionPool
 from crawlee.statistics import Statistics
+from crawlee.storages import RequestQueue
 from tests.unit.server_endpoints import HELLO_WORLD
 
 if TYPE_CHECKING:
@@ -577,3 +578,63 @@ async def request_handler(context: HttpCrawlingContext) -> None:
     assert len(kvs_content) == 1
     assert content_key.endswith('.html')
     assert kvs_content[content_key] == HELLO_WORLD.decode('utf8')
+
+
+async def test_request_state(server_url: URL) -> None:
+    queue = await RequestQueue.open(alias='http_request_state')
+    crawler = HttpCrawler(request_manager=queue)
+
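+    # A request expected to complete successfully; fresh requests start as UNPROCESSED.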
+    success_request = Request.from_url(str(server_url))
+    assert success_request.state == RequestState.UNPROCESSED
+
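+    # A request whose user_data tells the handler to raise, exercising the error path.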
+    error_request = Request.from_url(str(server_url / 'error'), user_data={'cause_error': True})
+
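+    # Record the state observed for each request in every lifecycle callback.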
+    requests_states: dict[str, dict[str, RequestState]] = {success_request.unique_key: {}, error_request.unique_key: {}}
+
+    @crawler.pre_navigation_hook
+    async def pre_navigation_hook(context: BasicCrawlingContext) -> None:
+        requests_states[context.request.unique_key]['pre_navigation'] = context.request.state
+
+    @crawler.router.default_handler
+    async def request_handler(context: HttpCrawlingContext) -> None:
+        if context.request.user_data.get('cause_error'):
+            raise ValueError('Caused error as requested')
+        requests_states[context.request.unique_key]['request_handler'] = context.request.state
+
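+    # error_handler runs on a failed attempt before any retry; failed_request_handler runs
+    # once retries are exhausted.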
+    @crawler.error_handler
+    async def error_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['error_handler'] = context.request.state
+
+    @crawler.failed_request_handler
+    async def failed_request_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['failed_request_handler'] = context.request.state
+
+    await crawler.run([success_request, error_request])
+
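+    # The successful request should be persisted in the queue with state DONE.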
+    handled_success_request = await queue.get_request(success_request.unique_key)
+
+    assert handled_success_request is not None
+    assert handled_success_request.state == RequestState.DONE
+
+    assert requests_states[success_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'request_handler': RequestState.REQUEST_HANDLER,
+    }
+
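+    # The failing request should end up persisted with state ERROR after retries run out.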
+    handled_error_request = await queue.get_request(error_request.unique_key)
+    assert handled_error_request is not None
+    assert handled_error_request.state == RequestState.ERROR
+
+    assert requests_states[error_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'error_handler': RequestState.ERROR_HANDLER,
+        'failed_request_handler': RequestState.ERROR,
+    }
+
+    await queue.drop()