
Commit 383225f

fix: Align Request.state transitions with Request lifecycle (#1601)
### Description

- Set the appropriate statuses for `Request.state` at the appropriate stages.
- Set the correct final status for `Request`.

### Testing

- Add new tests
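For orientation, here is a sketch of the lifecycle the transitions now follow, assembled from the hunks below. The state names come from this commit's diffs; the flow summary and the example URL are illustrative, not library documentation.

from crawlee import Request, RequestState

# Happy path, as wired up in this commit:
#   UNPROCESSED      - new default of CrawleeRequestData.state
#   BEFORE_NAV       - set in BasicCrawler._run_request_handler
#   AFTER_NAV        - set right after navigation (HTTP and Playwright crawlers)
#   REQUEST_HANDLER  - set in Router.__call__ before dispatching a handler
#   DONE             - final state on success, set before the request is marked handled
# Failure/skip paths:
#   ERROR_HANDLER    - observed while the user error handler runs
#   ERROR            - final state once retries are exhausted
#   SKIPPED          - set before a skipped request is marked handled

request = Request.from_url('https://example.com')  # illustrative URL
assert request.state == RequestState.UNPROCESSED  # the new default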
1 parent 5b97a79 commit 383225f

9 files changed · +127 −14 lines changed

src/crawlee/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from importlib import metadata
 
-from ._request import Request, RequestOptions
+from ._request import Request, RequestOptions, RequestState
 from ._service_locator import service_locator
 from ._types import ConcurrencySettings, EnqueueStrategy, HttpHeaders, RequestTransformAction, SkippedReason
 from ._utils.globs import Glob
@@ -14,6 +14,7 @@
     'HttpHeaders',
     'Request',
     'RequestOptions',
+    'RequestState',
     'RequestTransformAction',
     'SkippedReason',
     'service_locator',

src/crawlee/_request.py

Lines changed: 2 additions & 2 deletions
@@ -41,7 +41,7 @@ class CrawleeRequestData(BaseModel):
     enqueue_strategy: Annotated[EnqueueStrategy | None, Field(alias='enqueueStrategy')] = None
     """The strategy that was used for enqueuing the request."""
 
-    state: RequestState | None = None
+    state: RequestState = RequestState.UNPROCESSED
     """Describes the request's current lifecycle state."""
 
     session_rotation_count: Annotated[int | None, Field(alias='sessionRotationCount')] = None
@@ -352,7 +352,7 @@ def crawl_depth(self, new_value: int) -> None:
         self.crawlee_data.crawl_depth = new_value
 
     @property
-    def state(self) -> RequestState | None:
+    def state(self) -> RequestState:
         """Crawlee-specific request handling state."""
         return self.crawlee_data.state

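A minimal sketch of what this change means for callers, assuming only the `Request.state` property shown above: the getter can no longer return `None`, so no `None` guard is needed.

from crawlee import Request, RequestState

request = Request.from_url('https://example.com')  # illustrative URL
# Before this commit, `request.state` could be None; it now defaults to
# RequestState.UNPROCESSED, so a plain comparison is safe:
if request.state == RequestState.UNPROCESSED:
    print('request has not been handled yet')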
src/crawlee/crawlers/_abstract_http/_abstract_http_crawler.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 from pydantic import ValidationError
 from typing_extensions import NotRequired, TypeVar
 
-from crawlee._request import Request, RequestOptions
+from crawlee._request import Request, RequestOptions, RequestState
 from crawlee._utils.docs import docs_group
 from crawlee._utils.time import SharedTimeout
 from crawlee._utils.urls import to_absolute_url_iterator
@@ -257,6 +257,7 @@ async def _make_http_request(self, context: BasicCrawlingContext) -> AsyncGenera
             timeout=remaining_timeout,
         )
 
+        context.request.state = RequestState.AFTER_NAV
         yield HttpCrawlingContext.from_basic_crawling_context(context=context, http_response=result.http_response)
 
     async def _handle_status_code_response(

src/crawlee/crawlers/_basic/_basic_crawler.py

Lines changed: 5 additions & 7 deletions
@@ -1152,6 +1152,7 @@ async def _handle_request_retries(
 
             await request_manager.reclaim_request(request)
         else:
+            request.state = RequestState.ERROR
             await self._mark_request_as_handled(request)
             await self._handle_failed_request(context, error)
             self._statistics.record_request_processing_failure(request.unique_key)
@@ -1167,8 +1168,6 @@ async def _handle_request_error(self, context: TCrawlingContext | BasicCrawlingC
                 f'{self._internal_timeout.total_seconds()} seconds',
                 logger=self._logger,
             )
-
-            context.request.state = RequestState.DONE
         except UserDefinedErrorHandlerError:
             context.request.state = RequestState.ERROR
             raise
@@ -1201,8 +1200,8 @@ async def _handle_skipped_request(
         self, request: Request | str, reason: SkippedReason, *, need_mark: bool = False
     ) -> None:
         if need_mark and isinstance(request, Request):
-            await self._mark_request_as_handled(request)
             request.state = RequestState.SKIPPED
+            await self._mark_request_as_handled(request)
 
         url = request.url if isinstance(request, Request) else request
 
@@ -1403,8 +1402,6 @@ async def __run_task_function(self) -> None:
         self._statistics.record_request_processing_start(request.unique_key)
 
         try:
-            request.state = RequestState.REQUEST_HANDLER
-
             self._check_request_collision(context.request, context.session)
 
             try:
@@ -1414,10 +1411,10 @@ async def __run_task_function(self) -> None:
 
                 await self._commit_request_handler_result(context)
 
-                await self._mark_request_as_handled(request)
-
                 request.state = RequestState.DONE
 
+                await self._mark_request_as_handled(request)
+
                 if context.session and context.session.is_usable:
                     context.session.mark_good()
 
@@ -1483,6 +1480,7 @@ async def __run_task_function(self) -> None:
             raise
 
     async def _run_request_handler(self, context: BasicCrawlingContext) -> None:
+        context.request.state = RequestState.BEFORE_NAV
         await self._context_pipeline(
             context,
             lambda final_context: wait_for(

src/crawlee/crawlers/_playwright/_playwright_crawler.py

Lines changed: 2 additions & 1 deletion
@@ -13,7 +13,7 @@
 from typing_extensions import NotRequired, TypedDict, TypeVar
 
 from crawlee import service_locator
-from crawlee._request import Request, RequestOptions
+from crawlee._request import Request, RequestOptions, RequestState
 from crawlee._types import (
     BasicCrawlingContext,
     ConcurrencySettings,
@@ -329,6 +329,7 @@ async def _navigate(
             response = await context.page.goto(
                 context.request.url, timeout=remaining_timeout.total_seconds() * 1000, **context.goto_options
             )
+            context.request.state = RequestState.AFTER_NAV
         except playwright.async_api.TimeoutError as exc:
             raise asyncio.TimeoutError from exc
 
src/crawlee/router.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 from collections.abc import Awaitable, Callable
 from typing import Generic, TypeVar
 
+from crawlee._request import RequestState
 from crawlee._types import BasicCrawlingContext
 from crawlee._utils.docs import docs_group
 
@@ -89,6 +90,7 @@ def wrapper(handler: Callable[[TCrawlingContext], Awaitable]) -> Callable[[TCraw
 
     async def __call__(self, context: TCrawlingContext) -> None:
         """Invoke a request handler that matches the request label (or the default)."""
+        context.request.state = RequestState.REQUEST_HANDLER
         if context.request.label is None or context.request.label not in self._handlers_by_label:
             if self._default_handler is None:
                 raise RuntimeError(
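Because the REQUEST_HANDLER transition now lives in `Router.__call__`, every handler registered through the router observes that state, which the tests below verify. A minimal sketch (the crawler choice and handler body are illustrative):

from crawlee import RequestState
from crawlee.crawlers import HttpCrawler, HttpCrawlingContext

crawler = HttpCrawler()

@crawler.router.default_handler
async def default_handler(context: HttpCrawlingContext) -> None:
    # Router.__call__ set this immediately before dispatching to the handler.
    assert context.request.state == RequestState.REQUEST_HANDLER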

tests/unit/crawlers/_basic/test_basic_crawler.py

Lines changed: 1 addition & 1 deletion
@@ -1829,5 +1829,5 @@ async def error_handler(context: BasicCrawlingContext, error: Exception) -> Requ
     assert original_request.was_already_handled
 
     assert error_request is not None
-    assert error_request.state == RequestState.REQUEST_HANDLER
+    assert error_request.state == RequestState.DONE
     assert error_request.was_already_handled

tests/unit/crawlers/_http/test_http_crawler.py

Lines changed: 56 additions & 1 deletion
@@ -7,10 +7,11 @@
 
 import pytest
 
-from crawlee import ConcurrencySettings, Request
+from crawlee import ConcurrencySettings, Request, RequestState
 from crawlee.crawlers import HttpCrawler
 from crawlee.sessions import SessionPool
 from crawlee.statistics import Statistics
+from crawlee.storages import RequestQueue
 from tests.unit.server_endpoints import HELLO_WORLD
 
 if TYPE_CHECKING:
@@ -577,3 +578,57 @@ async def request_handler(context: HttpCrawlingContext) -> None:
     assert len(kvs_content) == 1
     assert content_key.endswith('.html')
     assert kvs_content[content_key] == HELLO_WORLD.decode('utf8')
+
+
+async def test_request_state(server_url: URL) -> None:
+    queue = await RequestQueue.open(alias='http_request_state')
+    crawler = HttpCrawler(request_manager=queue)
+
+    success_request = Request.from_url(str(server_url))
+    assert success_request.state == RequestState.UNPROCESSED
+
+    error_request = Request.from_url(str(server_url / 'error'), user_data={'cause_error': True})
+
+    requests_states: dict[str, dict[str, RequestState]] = {success_request.unique_key: {}, error_request.unique_key: {}}
+
+    @crawler.pre_navigation_hook
+    async def pre_navigation_hook(context: BasicCrawlingContext) -> None:
+        requests_states[context.request.unique_key]['pre_navigation'] = context.request.state
+
+    @crawler.router.default_handler
+    async def request_handler(context: HttpCrawlingContext) -> None:
+        if context.request.user_data.get('cause_error'):
+            raise ValueError('Caused error as requested')
+        requests_states[context.request.unique_key]['request_handler'] = context.request.state
+
+    @crawler.error_handler
+    async def error_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['error_handler'] = context.request.state
+
+    @crawler.failed_request_handler
+    async def failed_request_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['failed_request_handler'] = context.request.state
+
+    await crawler.run([success_request, error_request])
+
+    handled_success_request = await queue.get_request(success_request.unique_key)
+
+    assert handled_success_request is not None
+    assert handled_success_request.state == RequestState.DONE
+
+    assert requests_states[success_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'request_handler': RequestState.REQUEST_HANDLER,
+    }
+
+    handled_error_request = await queue.get_request(error_request.unique_key)
+    assert handled_error_request is not None
+    assert handled_error_request.state == RequestState.ERROR
+
+    assert requests_states[error_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'error_handler': RequestState.ERROR_HANDLER,
+        'failed_request_handler': RequestState.ERROR,
+    }
+
+    await queue.drop()

tests/unit/crawlers/_playwright/test_playwright_crawler.py

Lines changed: 55 additions & 0 deletions
@@ -19,6 +19,7 @@
     Glob,
     HttpHeaders,
     Request,
+    RequestState,
     RequestTransformAction,
     SkippedReason,
     service_locator,
@@ -991,3 +992,57 @@ async def test_slow_navigation_does_not_count_toward_handler_timeout(server_url:
     assert result.requests_failed == 0
     assert result.requests_finished == 1
     assert request_handler.call_count == 1
+
+
+async def test_request_state(server_url: URL) -> None:
+    queue = await RequestQueue.open(alias='playwright_request_state')
+    crawler = PlaywrightCrawler(request_manager=queue)
+
+    success_request = Request.from_url(str(server_url))
+    assert success_request.state == RequestState.UNPROCESSED
+
+    error_request = Request.from_url(str(server_url / 'error'), user_data={'cause_error': True})
+
+    requests_states: dict[str, dict[str, RequestState]] = {success_request.unique_key: {}, error_request.unique_key: {}}
+
+    @crawler.pre_navigation_hook
+    async def pre_navigation_hook(context: PlaywrightPreNavCrawlingContext) -> None:
+        requests_states[context.request.unique_key]['pre_navigation'] = context.request.state
+
+    @crawler.router.default_handler
+    async def request_handler(context: PlaywrightCrawlingContext) -> None:
+        if context.request.user_data.get('cause_error'):
+            raise ValueError('Caused error as requested')
+        requests_states[context.request.unique_key]['request_handler'] = context.request.state
+
+    @crawler.error_handler
+    async def error_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['error_handler'] = context.request.state
+
+    @crawler.failed_request_handler
+    async def failed_request_handler(context: BasicCrawlingContext, _error: Exception) -> None:
+        requests_states[context.request.unique_key]['failed_request_handler'] = context.request.state
+
+    await crawler.run([success_request, error_request])
+
+    handled_success_request = await queue.get_request(success_request.unique_key)
+
+    assert handled_success_request is not None
+    assert handled_success_request.state == RequestState.DONE
+
+    assert requests_states[success_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'request_handler': RequestState.REQUEST_HANDLER,
+    }
+
+    handled_error_request = await queue.get_request(error_request.unique_key)
+    assert handled_error_request is not None
+    assert handled_error_request.state == RequestState.ERROR
+
+    assert requests_states[error_request.unique_key] == {
+        'pre_navigation': RequestState.BEFORE_NAV,
+        'error_handler': RequestState.ERROR_HANDLER,
+        'failed_request_handler': RequestState.ERROR,
+    }
+
+    await queue.drop()
