Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion app/features/conditions/router.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from collections import OrderedDict
from typing import Any
Expand Down Expand Up @@ -85,7 +86,7 @@ async def conditions_test(request: Request, encoder: Encoder, cache: Cache, conf
return web.json_response({"error": "condition is required."}, status=web.HTTPBadRequest.status_code)

try:
validate_url(url, allow_internal=config.allow_internal_urls)
await asyncio.to_thread(validate_url, url, config.allow_internal_urls)
except ValueError as e:
return web.json_response({"error": str(e)}, status=web.HTTPBadRequest.status_code)

Expand Down
1 change: 1 addition & 0 deletions app/features/tasks/definitions/handlers/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def _generic_id(url):
url=url,
no_archive=True,
no_log=True,
budget_sleep=True,
)

if not info:
Expand Down
1 change: 1 addition & 0 deletions app/features/tasks/definitions/handlers/rss.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ async def extract(task: HandleTask) -> TaskResult | TaskFailure:
url=url,
no_archive=True,
no_log=True,
budget_sleep=True,
)

if not info:
Expand Down
3 changes: 2 additions & 1 deletion app/features/tasks/definitions/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ async def fetch_metadata(self, full: bool = False) -> tuple[dict[str, Any] | Non
no_archive=True,
follow_redirect=False,
sanitize_info=True,
budget_sleep=True,
)

if not ie_info or not isinstance(ie_info, dict):
Expand All @@ -133,7 +134,7 @@ async def _mark_logic(self) -> tuple[bool, str] | dict[str, Any]:

archive_file: Path = Path(archive_file)

(ie_info, _) = await fetch_info(params, self.url, no_archive=True, follow_redirect=True)
(ie_info, _) = await fetch_info(params, self.url, no_archive=True, follow_redirect=True, budget_sleep=True)
if not ie_info or not isinstance(ie_info, dict):
return (False, "Failed to extract information from URL.")

Expand Down
5 changes: 3 additions & 2 deletions app/features/tasks/router.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -469,7 +470,7 @@ async def task_handler_inspect(request: Request, handler: TaskHandle, encoder: E
static_only: bool = data.get("static_only", False) if isinstance(data, dict) else False
if not static_only:
try:
validate_url(url, allow_internal=config.allow_internal_urls)
await asyncio.to_thread(validate_url, url, config.allow_internal_urls)
except ValueError as e:
return web.json_response({"error": str(e)}, status=web.HTTPBadRequest.status_code)

Expand Down Expand Up @@ -710,7 +711,7 @@ async def task_metadata(request: Request, repo: TasksRepository, config: Config,
continue

try:
validate_url(url, allow_internal=config.allow_internal_urls)
await asyncio.to_thread(validate_url, url, config.allow_internal_urls)
except ValueError:
LOG.warning(f"Invalid thumbnail url '{url}'")
continue
Expand Down
21 changes: 19 additions & 2 deletions app/features/ytdlp/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ def __init__(
self.wait_threshold = wait_threshold


def _sleep_timeout(config: dict[str, Any], timeout: float, budget_sleep: bool) -> float:
if not budget_sleep:
return timeout

sleep_requests = config.get("sleep_interval_requests")
if not isinstance(sleep_requests, int | float) or sleep_requests <= 0:
return timeout

return timeout + min(float(sleep_requests) * 20, 300.0)


class ExtractorPool(metaclass=Singleton):
"""
Manages process pool and semaphore for video information extraction.
Expand Down Expand Up @@ -312,6 +323,7 @@ async def fetch_info(
sanitize_info: bool = False,
capture_logs: int | None = None,
extractor_config: ExtractorConfig | None = None,
budget_sleep: bool = False,
**kwargs,
) -> tuple[dict[str, Any] | None, list[dict[str, Any]]]:
"""
Expand All @@ -329,6 +341,7 @@ async def fetch_info(
sanitize_info: Sanitize the extracted information
capture_logs: If provided (e.g., logging.WARNING), capture logs
extractor_config: Configuration for the extractor
budget_sleep: Whether to add extra timeout budget for request-sleep-heavy extraction
**kwargs: Additional arguments

Returns:
Expand All @@ -352,6 +365,7 @@ async def fetch_info(
loop = asyncio.get_running_loop()

safe_config = _sanitize_config(config)
timeout = _sleep_timeout(safe_config, extractor_config.timeout, budget_sleep)

try:
try:
Expand All @@ -372,9 +386,12 @@ async def fetch_info(
**kwargs,
),
),
timeout=extractor_config.timeout,
timeout=timeout,
)

except TimeoutError:
raise

except Exception as exc:
LOG.exception(exc)
LOG.warning("extract_info process pool failed, falling back to thread pool url=%s error=%s", url, exc)
Expand All @@ -393,7 +410,7 @@ async def fetch_info(
**kwargs,
),
),
timeout=extractor_config.timeout,
timeout=timeout,
)
finally:
semaphore.release()
5 changes: 3 additions & 2 deletions app/features/ytdlp/router.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import logging
import time
Expand Down Expand Up @@ -301,7 +302,7 @@ async def get_info(request: Request, cache: Cache, config: Config) -> Response:
)

try:
validate_url(url, allow_internal=config.allow_internal_urls)
await asyncio.to_thread(validate_url, url, config.allow_internal_urls)
except ValueError as e:
return web.json_response(
data={"status": False, "message": str(e), "error": str(e)},
Expand Down Expand Up @@ -453,7 +454,7 @@ async def get_archive_ids(request: Request, config: Config) -> Response:
for i, url in enumerate(data):
dct = {"index": i, "url": url}
try:
validate_url(url, allow_internal=config.allow_internal_urls)
await asyncio.to_thread(validate_url, url, config.allow_internal_urls)
dct.update(get_archive_id(url))
except ValueError as e:
dct.update({"id": None, "ie_key": None, "archive_id": None, "error": str(e)})
Expand Down
1 change: 1 addition & 0 deletions app/library/downloads/item_adder.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ async def add(
no_archive=False,
follow_redirect=True,
capture_logs=logging.WARNING,
budget_sleep=True,
)

if not entry:
Expand Down
1 change: 1 addition & 0 deletions app/routes/api/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ async def item_nfo_generate(request: Request, queue: DownloadQueue) -> Response:
url=item.info.url,
no_archive=True,
follow_redirect=True,
budget_sleep=True,
)

if not info_dict:
Expand Down
4 changes: 3 additions & 1 deletion app/routes/api/images.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
import random
import time
Expand Down Expand Up @@ -39,7 +40,7 @@ async def get_thumbnail(request: Request, config: Config) -> Response:
return web.json_response(data={"error": "URL is required."}, status=web.HTTPForbidden.status_code)

try:
validate_url(url, allow_internal=config.allow_internal_urls)
await asyncio.to_thread(validate_url, url, config.allow_internal_urls)
except ValueError as e:
return web.json_response(data={"error": str(e)}, status=web.HTTPForbidden.status_code)

Expand All @@ -59,6 +60,7 @@ async def get_thumbnail(request: Request, config: Config) -> Response:
url=url,
follow_redirects=True,
headers=request_headers,
timeout=10.0,
)

if response.status_code != web.HTTPOk.status_code:
Expand Down
120 changes: 120 additions & 0 deletions app/tests/test_async_url_validation_routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from __future__ import annotations

import json
from typing import Any, Generator

import pytest

from app.features.conditions import router as conditions_router
from app.features.tasks import router as tasks_router
from app.features.ytdlp import router as ytdlp_router
from app.library.config import Config


@pytest.fixture(autouse=True)
def reset_config() -> Generator[None, None, None]:
    """Reset the ``Config`` singleton before and after every test in this module."""
    # Setup: start each test from a freshly reset singleton.
    Config._reset_singleton()

    yield

    # Teardown: reset again so state set by this test does not leak out.
    Config._reset_singleton()


class _Req:
def __init__(self, payload: Any) -> None:
self._payload = payload
self.body_exists = payload is not None

async def json(self) -> Any:
return self._payload


class _InspectReq(_Req):
    """Request stub with a JSON payload plus empty ``query`` and ``match_info`` mappings."""

    def __init__(self, payload: Any) -> None:
        super().__init__(payload)
        # Per-instance dicts: class-level ``{}`` defaults would be shared by
        # every instance, so a mutation made during one test could leak into
        # another.
        self.query: dict[str, str] = {}
        self.match_info: dict[str, str] = {}


class _QueryReq:
def __init__(self, query: dict[str, str]) -> None:
self.query = query


def _patch_thread(monkeypatch: pytest.MonkeyPatch, module: Any, config: Config, url: str) -> dict[str, bool]:
seen = {"to_thread": False, "validate": False}

def fake_validate_url(next_url: str, allow_internal: bool = False) -> bool:
seen["validate"] = True
assert next_url == url
assert allow_internal is config.allow_internal_urls
raise ValueError("Invalid hostname.")

async def fake_to_thread(func, *args, **kwargs):
seen["to_thread"] = True
return func(*args, **kwargs)

monkeypatch.setattr(module, "validate_url", fake_validate_url)
monkeypatch.setattr(module.asyncio, "to_thread", fake_to_thread)
return seen


@pytest.mark.asyncio
async def test_inspect_thread(monkeypatch: pytest.MonkeyPatch) -> None:
    """Inspect handler validates the URL via ``asyncio.to_thread`` and answers 400 on failure."""
    target_url = "https://bad.example/task"
    config = Config.get_instance()
    request = _InspectReq({"url": target_url})
    seen = _patch_thread(monkeypatch, tasks_router, config, target_url)

    response = await tasks_router.task_handler_inspect(
        request,
        handler=None,
        encoder=None,
        config=config,
    )

    assert response.status == 400
    body = json.loads(response.body.decode("utf-8"))
    assert body == {"error": "Invalid hostname."}
    # Both the thread hop and the validator itself must have been exercised.
    assert seen == {"to_thread": True, "validate": True}


@pytest.mark.asyncio
async def test_conditions_thread(monkeypatch: pytest.MonkeyPatch) -> None:
    """Conditions test endpoint validates the URL via ``asyncio.to_thread`` and answers 400 on failure."""
    target_url = "https://bad.example/cond"
    config = Config.get_instance()
    request = _Req({"url": target_url, "condition": "title ~= 'x'"})
    seen = _patch_thread(monkeypatch, conditions_router, config, target_url)

    response = await conditions_router.conditions_test(
        request,
        encoder=None,
        cache=None,
        config=config,
    )

    assert response.status == 400
    body = json.loads(response.body.decode("utf-8"))
    assert body == {"error": "Invalid hostname."}
    # Both the thread hop and the validator itself must have been exercised.
    assert seen == {"to_thread": True, "validate": True}


@pytest.mark.asyncio
async def test_info_thread(monkeypatch: pytest.MonkeyPatch) -> None:
    """``get_info`` validates the URL via ``asyncio.to_thread`` and answers 400 on failure."""
    target_url = "https://bad.example/info"
    config = Config.get_instance()
    request = _QueryReq({"url": target_url})
    seen = _patch_thread(monkeypatch, ytdlp_router, config, target_url)

    response = await ytdlp_router.get_info(request, cache=None, config=config)

    expected_body = {
        "status": False,
        "message": "Invalid hostname.",
        "error": "Invalid hostname.",
    }
    assert response.status == 400
    assert json.loads(response.body.decode("utf-8")) == expected_body
    # Both the thread hop and the validator itself must have been exercised.
    assert seen == {"to_thread": True, "validate": True}


@pytest.mark.asyncio
async def test_archive_ids_thread(monkeypatch: pytest.MonkeyPatch) -> None:
    """``get_archive_ids`` validates each URL via ``asyncio.to_thread`` and reports failures per item."""
    target_url = "https://bad.example/archive"
    config = Config.get_instance()
    request = _Req([target_url])
    seen = _patch_thread(monkeypatch, ytdlp_router, config, target_url)

    response = await ytdlp_router.get_archive_ids(request, config)

    # A rejected URL does not fail the request; it yields an error entry instead.
    failed_entry = {
        "index": 0,
        "url": "https://bad.example/archive",
        "id": None,
        "ie_key": None,
        "archive_id": None,
        "error": "Invalid hostname.",
    }
    assert response.status == 200
    assert json.loads(response.body.decode("utf-8")) == [failed_entry]
    assert seen == {"to_thread": True, "validate": True}
Loading
Loading