38 changes: 38 additions & 0 deletions scrapy_spider_metadata/_params.py
@@ -4,6 +4,7 @@
from pydantic import BaseModel, ValidationError

from ._utils import get_generic_param, normalize_param_schema
from .defaults import FromSetting

ParamSpecT = TypeVar("ParamSpecT", bound=BaseModel)
logger = getLogger(__name__)
@@ -30,6 +31,43 @@ def __init__(self, *args: Any, **kwargs: Any):
            raise
        super().__init__(*args, **kwargs)

    def _set_crawler(self, crawler):

Contributor: Sorry, where is this called (normally)?

Author: […]

Contributor: Oh.

I think this should override from_crawler() instead, though I'm not 100% sure if it's possible.

        super()._set_crawler(crawler)

        if not hasattr(self, "args") or self.args is None:
            return

        param_model = get_generic_param(self.__class__, Args)
        assert param_model is not None
        assert issubclass(param_model, BaseModel)

        # Pydantic v1/v2 compatibility: keep only explicitly-set values, so
        # that unset fields can fall back to FromSetting resolution below.
        if hasattr(self.args, "model_dump"):
            data = self.args.model_dump(exclude_unset=True)
        else:
            data = self.args.dict(exclude_unset=True)

        fields = getattr(param_model, "model_fields", None) or getattr(
            param_model, "__fields__", {}
        )

        # Fill unset fields whose declared default is a FromSetting by reading
        # the named setting with the requested getter.
        for field_name, field in fields.items():
            default_val = getattr(field, "default", None)
            if field_name in data and data[field_name] is not None:
                continue
            if isinstance(default_val, FromSetting):
                getter_name = default_val.getter or "get"
                getter = getattr(crawler.settings, getter_name, crawler.settings.get)
                value = getter(default_val.name, default_val.default)
                if value is not None:
                    data[field_name] = value

        # Re-validate so that resolved values pass through the parameter model.
        try:
            self.args = param_model(**data)
        except ValidationError as e:
            logger.error(f"Spider parameter validation failed: {e}")
            raise

    @classmethod
    def get_param_schema(cls, normalize: bool = False) -> dict[Any, Any]:
        """Return a :class:`dict` with the :ref:`parameter definition
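A minimal sketch of the from_crawler() override the reviewer suggests, assuming the FromSetting resolution is factored out of _set_crawler() into a hypothetical _resolve_from_settings() helper:

    @classmethod
    def from_crawler(cls, crawler, *args, **kwargs):
        # Let Scrapy create the spider and attach the crawler as usual.
        spider = super().from_crawler(crawler, *args, **kwargs)
        # Hypothetical helper holding the FromSetting-filling logic that
        # _set_crawler() currently inlines.
        spider._resolve_from_settings(crawler.settings)
        return spider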
9 changes: 9 additions & 0 deletions scrapy_spider_metadata/defaults.py
@@ -0,0 +1,9 @@
from dataclasses import dataclass
from typing import Any, Literal


@dataclass(frozen=True)
class FromSetting:
    name: str  # name of the Scrapy setting to read the value from
    default: Any = None  # fallback when the setting is not set
    getter: Literal["get", "getint", "getbool", "getfloat"] = "get"  # Settings method used
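For illustration, here is how a parameter model might combine the typed getters; MyParams and the MYSPIDER_VERBOSE setting are made up for this sketch:

from pydantic import BaseModel

from scrapy_spider_metadata.defaults import FromSetting


class MyParams(BaseModel):
    # Read Scrapy's built-in DOWNLOAD_TIMEOUT setting as a float.
    timeout: float = FromSetting("DOWNLOAD_TIMEOUT", default=30.0, getter="getfloat")
    # MYSPIDER_VERBOSE is a hypothetical project setting, read as a bool.
    verbose: bool = FromSetting("MYSPIDER_VERBOSE", default=False, getter="getbool")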
12 changes: 12 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,12 @@
import os

# Ensure Scrapy picks the asyncio reactor before anything imports
# twisted.internet.reactor.
os.environ.setdefault(
    "TWISTED_REACTOR",
    "twisted.internet.asyncioreactor.AsyncioSelectorReactor",
)

try:
    from twisted.internet import asyncioreactor

    asyncioreactor.install()
except Exception:
    # A reactor may already be installed (e.g. by another plugin); installing
    # twice raises, so keep whichever reactor is already in place.
    pass
39 changes: 39 additions & 0 deletions tests/test_fromsetting.py
@@ -0,0 +1,39 @@
from pydantic import BaseModel
from scrapy import Spider
from scrapy.utils.test import get_crawler

from scrapy_spider_metadata._params import Args
from scrapy_spider_metadata.defaults import FromSetting


class Params(BaseModel):
pages: int = FromSetting("MAX_PAGES_SETTING", default=5, getter="getint")
lang: str = "en"


class S(Args[Params], Spider):
name = "s"


def test_fromsetting_reads_setting():
crawler = get_crawler(S, settings_dict={"MAX_PAGES_SETTING": 10})
s = S()
s._set_crawler(crawler)
assert s.args.pages == 10
assert s.args.lang == "en"


def test_fromsetting_uses_default_when_missing():
crawler = get_crawler(S, settings_dict={})
s = S()
s._set_crawler(crawler)
assert s.args.pages == 5
assert s.args.lang == "en"


def test_cli_overrides_everything():
crawler = get_crawler(S, settings_dict={"MAX_PAGES_SETTING": 10})
s = S(pages=99)
s._set_crawler(crawler)
assert s.args.pages == 99
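
Assuming the spider lives in a regular Scrapy project, the same precedence (explicit argument over setting over declared default) would apply from the command line:

scrapy crawl s -a pages=99 -s MAX_PAGES_SETTING=10   # args.pages == 99
scrapy crawl s -s MAX_PAGES_SETTING=10               # args.pages == 10
scrapy crawl s                                       # args.pages == 5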