diff --git a/.github/workflows/native-build.yml b/.github/workflows/native-build.yml index 4e8a5b03..8840d7d7 100644 --- a/.github/workflows/native-build.yml +++ b/.github/workflows/native-build.yml @@ -113,7 +113,9 @@ jobs: shell: bash run: | pip install uv - uv venv --system-site-packages --relocatable + if [ ! -d .venv ]; then + uv venv --system-site-packages --relocatable + fi uv sync --link-mode=copy --active --extra installer - name: Cache bun installation diff --git a/API.md b/API.md index 14beca58..8b625811 100644 --- a/API.md +++ b/API.md @@ -97,6 +97,10 @@ This document describes the available endpoints and their usage. All endpoints r - [GET /api/yt-dlp/options](#get-apiyt-dlpoptions) - [GET /api/system/configuration](#get-apisystemconfiguration) - [POST /api/system/terminal](#post-apisystemterminal) + - [GET /api/system/terminal/active](#get-apisystemterminalactive) + - [GET /api/system/terminal/{session\_id}](#get-apisystemterminalsession_id) + - [DELETE /api/system/terminal/{session\_id}](#delete-apisystemterminalsession_id) + - [GET /api/system/terminal/{session\_id}/stream](#get-apisystemterminalsession_idstream) - [POST /api/system/pause](#post-apisystempause) - [POST /api/system/resume](#post-apisystemresume) - [POST /api/system/shutdown](#post-apisystemshutdown) @@ -2476,7 +2480,7 @@ or an error: --- ### POST /api/system/terminal -**Purpose**: Stream yt-dlp CLI output via Server-Sent Events (SSE). Requires `YTP_CONSOLE_ENABLED=true`. +**Purpose**: Start a yt-dlp terminal session. Requires `YTP_CONSOLE_ENABLED=true`. **Body**: ```json @@ -2485,9 +2489,118 @@ or an error: } ``` +**Response**: +```json +{ + "session_id": "3a8c5f7e2d3b4a8f9c0d1e2f3a4b5c6d", + "command": "--help", + "status": "starting", + "created_at": 1713000000.0, + "started_at": 1713000000.0, + "finished_at": null, + "expires_at": null, + "exit_code": null, + "last_sequence": 0 +} +``` + +**Notes**: +- Starts the command in an application-owned background task. +- Only one active terminal session is allowed at a time. +- The command continues running if the frontend disconnects or reloads. + +- `403 Forbidden` if console is disabled. +- `400 Bad Request` if the request body is invalid. +- `409 Conflict` if another terminal session is already active. + +--- + +### GET /api/system/terminal/active +**Purpose**: Return the currently active terminal session metadata, or `null` if no session is active. + +**Response**: +```json +null +``` +or +```json +{ + "session_id": "3a8c5f7e2d3b4a8f9c0d1e2f3a4b5c6d", + "command": "--help", + "status": "running", + "created_at": 1713000000.0, + "started_at": 1713000000.0, + "finished_at": null, + "expires_at": null, + "exit_code": null, + "last_sequence": 12 +} +``` + +- `403 Forbidden` if console is disabled. + +--- + +### GET /api/system/terminal/{session_id} +**Purpose**: Return metadata for a specific terminal session while it is active or still within the replay/drain window. + +**Response**: +```json +{ + "session_id": "3a8c5f7e2d3b4a8f9c0d1e2f3a4b5c6d", + "command": "--help", + "status": "completed", + "created_at": 1713000000.0, + "started_at": 1713000000.0, + "finished_at": 1713000004.0, + "expires_at": 1713000034.0, + "exit_code": 0, + "last_sequence": 15 +} +``` + +- `403 Forbidden` if console is disabled. +- `404 Not Found` if the session does not exist or has already expired. + +--- + +### DELETE /api/system/terminal/{session_id} +**Purpose**: Request cancellation for the active terminal session. + +**Response**: +```json +{ + "message": "Terminal session cancellation requested.", + "session_id": "3a8c5f7e2d3b4a8f9c0d1e2f3a4b5c6d" +} +``` + +**Notes**: +- This only applies to the currently active session. +- The client should stay attached to the stream to receive the final `close` event and refreshed terminal status. +- Cancelled sessions finalize as `interrupted` and remain replayable until the drain window expires. + +- `403 Forbidden` if console is disabled. +- `404 Not Found` if the session does not exist or has already expired. +- `409 Conflict` if the session exists but is no longer active. + +--- + +### GET /api/system/terminal/{session_id}/stream +**Purpose**: Replay a terminal session transcript over SSE and tail live output when the session is still running. + +**Query Parameters**: +- `since` (optional): Resume after the provided integer event id. + +**Headers**: +- `Last-Event-ID` (optional): Resume after the provided integer event id. + +If both `since` and `Last-Event-ID` are present, the larger value is used. + **Response**: - `Content-Type: text/event-stream` -- Emits `output` events for stdout/stderr and a final `close` event when the process exits. +- Replays transcript events with monotonic integer SSE `id` values. +- Emits `output` events for stdout/stderr and a final `close` event when available. **Event Payloads**: ```json @@ -2497,8 +2610,13 @@ or an error: { "exitcode": 0 } ``` +**Notes**: +- Replay/restore works while the session is still running or until the finished session expires. +- Finished sessions are removed lazily after the transcript drain window elapses. + - `403 Forbidden` if console is disabled. -- `400 Bad Request` if the request body is invalid. +- `400 Bad Request` if the replay cursor is invalid. +- `404 Not Found` if the session does not exist or has already expired. --- diff --git a/app/features/ytdlp/patches.py b/app/features/ytdlp/patches.py new file mode 100644 index 00000000..d38b2a64 --- /dev/null +++ b/app/features/ytdlp/patches.py @@ -0,0 +1,98 @@ +import logging +import subprocess +import sys +from typing import Any + +LOG: logging.Logger = logging.getLogger("ytdlp.utils") + + +def patch_metadataparser() -> None: + """ + Patches yt_dlp MetadataParserPP action to handle subprocess pickling issues. + """ + try: + from yt_dlp.postprocessor.metadataparser import MetadataParserPP + from yt_dlp.utils import Namespace + except Exception as exc: + LOG.warning(f"Unable to import yt_dlp metadata parser for patching: {exc!s}") + return + + if getattr(MetadataParserPP.Actions, "_ytptube_patched", False): + return + + class _ActionNS(Namespace): + _ACTIONS_STR: list[str] = [] + + @staticmethod + def _get_name(func) -> str | None: + if not callable(func): + return None + + target = getattr(func, "__func__", func) + module_name = getattr(target, "__module__", None) + qual_name = getattr(target, "__qualname__", getattr(target, "__name__", None)) + + return f"{module_name}.{qual_name}" if module_name and qual_name else None + + def __contains__(self, candidate: object) -> bool: + if candidate in self.__dict__.values(): + return True + + if func_name := _ActionNS._get_name(candidate): + if len(_ActionNS._ACTIONS_STR) < 1: + _ActionNS._ACTIONS_STR.extend( + [value for value in (_ActionNS._get_name(value) for value in self.__dict__.values()) if value] + ) + + return func_name in _ActionNS._ACTIONS_STR + + return False + + actions_dict: dict[str, Any] = dict(MetadataParserPP.Actions.items_) + MetadataParserPP.Actions = _ActionNS(**actions_dict) + MetadataParserPP.Actions._ytptube_patched = True + LOG.debug("MetadataParserPP action namespace patch applied successfully.") + + +def patch_windows_popen_wait() -> None: + if sys.platform != "win32": + return + + try: + from yt_dlp.utils import Popen + except Exception as exc: + LOG.warning(f"Unable to import yt_dlp Popen for patching: {exc!s}") + return + + if getattr(Popen, "_ytptube_wait_patched", False): + return + + original_wait = Popen.wait + + # Windows subprocess waits can swallow the synthetic interrupt we use to + # stop live downloads, especially while yt-dlp is blocked on ffmpeg. + def interruptible_wait(self, timeout=None): + if timeout is not None: + return original_wait(self, timeout=timeout) + + while True: + try: + return original_wait(self, timeout=0.1) + except subprocess.TimeoutExpired: + continue + + Popen.wait = interruptible_wait + Popen._ytptube_wait_patched = True + LOG.debug("yt_dlp Popen.wait Windows patch applied successfully.") + + +def apply_ytdlp_patches() -> None: + try: + patch_metadataparser() + except Exception as exc: + LOG.debug("Metadata parser patch failed to apply: %s", exc) + + try: + patch_windows_popen_wait() + except Exception as exc: + LOG.debug("Windows Popen wait patch failed to apply: %s", exc) diff --git a/app/features/ytdlp/tests/test_ytdlp_module.py b/app/features/ytdlp/tests/test_ytdlp_module.py index 1adf9112..66215795 100644 --- a/app/features/ytdlp/tests/test_ytdlp_module.py +++ b/app/features/ytdlp/tests/test_ytdlp_module.py @@ -1,5 +1,6 @@ from unittest.mock import MagicMock, Mock, patch +from app.features.ytdlp.patches import patch_windows_popen_wait from app.features.ytdlp.utils import _DATA from app.features.ytdlp.ytdlp import YTDLP, _ArchiveProxy, ytdlp_options @@ -145,6 +146,61 @@ def test_init_handles_none_params(self, mock_super_init) -> None: assert isinstance(ytdlp.archive, _ArchiveProxy) assert not ytdlp.archive + @patch("app.features.ytdlp.ytdlp.yt_dlp.YoutubeDL.__init__") + def test_init_patches_windows_popen_wait_once(self, mock_super_init) -> None: + mock_super_init.return_value = None + + class FakePopen: + def wait(self, timeout=None): + return timeout + + with patch("app.features.ytdlp.patches.sys.platform", "win32"): + with patch("yt_dlp.utils.Popen", FakePopen): + YTDLP(params={}) + + assert getattr(FakePopen, "_ytptube_wait_patched", False) is True + + def test_windows_wait_patch_uses_polling_for_blocking_wait(self) -> None: + calls: list[float | None] = [] + + class FakePopen: + _ytptube_wait_patched = False + + def wait(self, timeout=None): + calls.append(timeout) + if len(calls) < 3: + raise TimeoutError + return 0 + + with patch("app.features.ytdlp.patches.sys.platform", "win32"): + with ( + patch("yt_dlp.utils.Popen", FakePopen), + patch("app.features.ytdlp.patches.subprocess.TimeoutExpired", TimeoutError), + ): + patch_windows_popen_wait() + result = FakePopen().wait() + + assert result == 0 + assert calls == [0.1, 0.1, 0.1] + + def test_windows_wait_patch_preserves_explicit_timeout(self) -> None: + calls: list[float | None] = [] + + class FakePopen: + _ytptube_wait_patched = False + + def wait(self, timeout=None): + calls.append(timeout) + return 0 + + with patch("app.features.ytdlp.patches.sys.platform", "win32"): + with patch("yt_dlp.utils.Popen", FakePopen): + patch_windows_popen_wait() + result = FakePopen().wait(timeout=5) + + assert result == 0 + assert calls == [5] + @patch("app.features.ytdlp.ytdlp.yt_dlp.YoutubeDL._delete_downloaded_files") def test_delete_downloaded_files_skips_when_interrupted(self, mock_super_delete) -> None: """Test _delete_downloaded_files skips cleanup when _interrupted is True.""" diff --git a/app/features/ytdlp/utils.py b/app/features/ytdlp/utils.py index 50e1ab2f..93034d89 100644 --- a/app/features/ytdlp/utils.py +++ b/app/features/ytdlp/utils.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import Any +from app.features.ytdlp.patches import apply_ytdlp_patches from app.features.ytdlp.ytdlp import YTDLP from app.library.Utils import merge_dict, timed_lru_cache @@ -143,52 +144,6 @@ def critical(self, msg, *args, **kwargs): self._log(logging.CRITICAL, msg, *args, **kwargs) -def patch_metadataparser() -> None: - """ - Patches yt_dlp MetadataParserPP action to handle subprocess pickling issues. - """ - try: - from yt_dlp.postprocessor.metadataparser import MetadataParserPP - from yt_dlp.utils import Namespace - except Exception as exc: - LOG.warning(f"Unable to import yt_dlp metadata parser for patching: {exc!s}") - return - - if getattr(MetadataParserPP.Actions, "_ytptube_patched", False): - return - - class _ActionNS(Namespace): - _ACTIONS_STR: list[str] = [] - - @staticmethod - def _get_name(func) -> str | None: - if not callable(func): - return None - - target = getattr(func, "__func__", func) - module_name = getattr(target, "__module__", None) - qual_name = getattr(target, "__qualname__", getattr(target, "__name__", None)) - - return f"{module_name}.{qual_name}" if module_name and qual_name else None - - def __contains__(self, candidate: object) -> bool: - if candidate in self.__dict__.values(): - return True - - if func_name := _ActionNS._get_name(candidate): - if len(_ActionNS._ACTIONS_STR) < 1: - _ActionNS._ACTIONS_STR.extend([_ActionNS._get_name(value) for value in self.__dict__.values()]) - - return func_name in _ActionNS._ACTIONS_STR - - return False - - actions_dict: dict[str, Any] = dict(MetadataParserPP.Actions.items_) - MetadataParserPP.Actions = _ActionNS(**actions_dict) - MetadataParserPP.Actions._ytptube_patched = True - LOG.debug("MetadataParserPP action namespace patch applied successfully.") - - def arg_converter( args: str, level: int | bool | None = None, @@ -222,10 +177,7 @@ def _default_opts(args: str): finally: yt_dlp.options.create_parser = create_parser - try: - patch_metadataparser() - except Exception as exc: - LOG.debug("Metadata parser patch failed to apply: %s", exc) + apply_ytdlp_patches() default_opts = _default_opts([]).ydl_opts diff --git a/app/features/ytdlp/ytdlp.py b/app/features/ytdlp/ytdlp.py index 1a1f0a84..cccba7e5 100644 --- a/app/features/ytdlp/ytdlp.py +++ b/app/features/ytdlp/ytdlp.py @@ -5,6 +5,7 @@ import yt_dlp from yt_dlp.utils import make_archive_id +from app.features.ytdlp.patches import apply_ytdlp_patches from app.library.cf_solver_handler import set_cf_handler @@ -52,13 +53,15 @@ class YTDLP(yt_dlp.YoutubeDL): _registered = False def __init__(self, params=None, auto_init=True): + apply_ytdlp_patches() + # Avoid yt-dlp preloading the archive file by stripping the param first orig_file = None - patched_params = None + patched_params: dict[str, Any] | None = None if params is not None: try: orig_file: str | None = params.get("download_archive") - patched_params: dict = dict(params) + patched_params = dict(params) if "download_archive" in patched_params: patched_params.pop("download_archive", None) except Exception: diff --git a/app/library/TerminalSessionManager.py b/app/library/TerminalSessionManager.py new file mode 100644 index 00000000..423dd131 --- /dev/null +++ b/app/library/TerminalSessionManager.py @@ -0,0 +1,698 @@ +from __future__ import annotations + +import asyncio +import errno +import json +import logging +import os +import shlex +import shutil +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from aiohttp import web + +from app.library.config import Config +from app.library.Services import Services +from app.library.Singleton import Singleton + +if TYPE_CHECKING: + from asyncio.events import AbstractEventLoop + from asyncio.subprocess import Process + + from aiohttp.web import Request + +LOG: logging.Logger = logging.getLogger("terminal_manager") + +ACTIVE_FILE_NAME = "active.json" +METADATA_FILE_NAME = "metadata.json" +TRANSCRIPT_FILE_NAME = "transcript.jsonl" +DEFAULT_DRAIN_TTL = 30.0 +DEFAULT_KEEPALIVE_INTERVAL = 15.0 +DEFAULT_SHUTDOWN_TIMEOUT = 5.0 + + +class TerminalSessionConflictError(RuntimeError): + pass + + +@dataclass(slots=True) +class ActiveTerminalSession: + session_id: str + task: asyncio.Task[None] + process: Process | None = None + subscribers: set[asyncio.Queue[dict[str, Any] | None]] = field(default_factory=set) + interrupted: bool = False + + +class TerminalSessionManager(metaclass=Singleton): + def __init__(self) -> None: + self.config: Config = Config.get_instance() + self.root_path: Path = Path(self.config.config_path) / "runtime" / "terminal" + self._lock = asyncio.Lock() + self._active: ActiveTerminalSession | None = None + self._drain_ttl: float = DEFAULT_DRAIN_TTL + self._keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL + self._shutdown_timeout: float = DEFAULT_SHUTDOWN_TIMEOUT + + @staticmethod + def get_instance() -> TerminalSessionManager: + return TerminalSessionManager() + + def attach(self, app: web.Application) -> None: + self._ensure_root() + Services.get_instance().add("terminal_manager", self) + app.on_startup.append(self.on_startup) + app.on_shutdown.append(self.on_shutdown) + + async def on_startup(self, _: web.Application) -> None: + await self.initialize() + + async def on_shutdown(self, _: web.Application) -> None: + session_id: str | None = None + task: asyncio.Task[None] | None = None + + async with self._lock: + runtime = self._active + if runtime is None: + return + + runtime.interrupted = True + session_id = runtime.session_id + task = runtime.task + + self._signal_process(runtime.process) + if runtime.process is None and task.done() is False: + task.cancel() + + for subscriber in list(runtime.subscribers): + subscriber.put_nowait(None) + + assert session_id is not None + assert task is not None + + try: + await asyncio.wait_for(asyncio.shield(task), timeout=self._shutdown_timeout) + return + except TimeoutError: + async with self._lock: + runtime = self._active + if runtime is not None and runtime.session_id == session_id: + self._signal_process(runtime.process, force=True) + if task.done() is False: + task.cancel() + + try: + await asyncio.wait_for(asyncio.shield(task), timeout=self._shutdown_timeout) + return + except TimeoutError: + LOG.warning("Terminal session '%s' did not finish during shutdown.", session_id) + + await self._force_interrupt_session(session_id) + + async def initialize(self) -> None: + async with self._lock: + self._ensure_root() + self._cleanup_expired_sessions(time.time()) + self._recover_orphaned_active_session(time.time()) + + async def cleanup(self) -> None: + async with self._lock: + self._cleanup_expired_sessions(time.time()) + + async def create_session(self, command: str) -> dict[str, Any]: + async with self._lock: + now = time.time() + self._cleanup_expired_sessions(now) + + active_session = self._get_active_session_locked() + if active_session is not None: + msg = "A terminal session is already active." + raise TerminalSessionConflictError(msg) + + session_id = uuid.uuid4().hex + session_dir = self._session_dir(session_id) + session_dir.mkdir(parents=True, exist_ok=True) + + metadata = { + "session_id": session_id, + "command": command, + "status": "starting", + "created_at": now, + "started_at": now, + "finished_at": None, + "expires_at": None, + "exit_code": None, + "last_sequence": 0, + } + self._write_json(self._metadata_path(session_id), metadata) + self._transcript_path(session_id).touch(exist_ok=True) + self._set_active_marker(session_id) + + task = asyncio.create_task( + self._run_session(session_id=session_id, command=command), name=f"terminal_{session_id}" + ) + self._active = ActiveTerminalSession(session_id=session_id, task=task) + return dict(metadata) + + async def get_active_session(self) -> dict[str, Any] | None: + async with self._lock: + self._cleanup_expired_sessions(time.time()) + metadata = self._get_active_session_locked() + return None if metadata is None else dict(metadata) + + async def get_session(self, session_id: str) -> dict[str, Any] | None: + async with self._lock: + self._cleanup_expired_sessions(time.time()) + metadata = self._load_metadata(session_id) + return None if metadata is None else dict(metadata) + + async def cancel_session(self, session_id: str) -> dict[str, Any]: + async with self._lock: + self._cleanup_expired_sessions(time.time()) + + metadata = self._load_metadata(session_id) + if metadata is None: + msg = f"Unknown terminal session '{session_id}'." + raise FileNotFoundError(msg) + + runtime = self._active + if runtime is None or runtime.session_id != session_id: + msg = "Terminal session is not active." + raise RuntimeError(msg) + + runtime.interrupted = True + self._signal_process(runtime.process) + if runtime.process is None and runtime.task.done() is False: + runtime.task.cancel() + + return dict(metadata) + + async def stream_session(self, session_id: str, request: Request) -> web.StreamResponse: + since = self._parse_since(request) + + response = web.StreamResponse( + status=web.HTTPOk.status_code, + headers={ + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + await response.prepare(request) + + queue: asyncio.Queue[dict[str, Any] | None] | None = None + last_sent = since + + try: + replay_events = self._read_transcript(session_id=session_id, since=since) + for event in replay_events: + if await self._emit_sse(request=request, response=response, event=event) is False: + return response + last_sent = event["seq"] + + replay_until = last_sent + async with self._lock: + metadata = self._load_metadata(session_id) + if metadata is not None: + replay_until = int(metadata.get("last_sequence", last_sent)) + + runtime = self._active + if runtime is not None and runtime.session_id == session_id: + queue = asyncio.Queue() + runtime.subscribers.add(queue) + + if replay_until > last_sent: + replay_events = self._read_transcript(session_id=session_id, since=last_sent, until=replay_until) + for event in replay_events: + if await self._emit_sse(request=request, response=response, event=event) is False: + return response + last_sent = event["seq"] + + if queue is None: + return response + + while True: + if self._is_request_disconnected(request): + break + + try: + event = await asyncio.wait_for(queue.get(), timeout=self._keepalive_interval) + except TimeoutError: + if await self._emit_keepalive(request=request, response=response) is False: + break + continue + + if event is None: + break + + if event["seq"] <= last_sent: + continue + + if await self._emit_sse(request=request, response=response, event=event) is False: + break + last_sent = event["seq"] + finally: + if queue is not None: + async with self._lock: + if self._active is not None and self._active.session_id == session_id: + self._active.subscribers.discard(queue) + + try: + await response.write_eof() + except ConnectionResetError: + pass + + return response + + async def _run_session(self, session_id: str, command: str) -> None: + return_code = -1 + final_status = "completed" + proc: Process | None = None + read_task: asyncio.Task[None] | None = None + master_fd: int | None = None + + try: + LOG.info("Cli command from client. '%s'", command) + args = ["yt-dlp", *shlex.split(command, posix=os.name != "nt")] + env_vars = self._build_env() + + pty_handles = self._open_pty() + if pty_handles is None: + stdin_arg = asyncio.subprocess.DEVNULL + stdout_arg = asyncio.subprocess.PIPE + stderr_arg = asyncio.subprocess.STDOUT + use_pty = False + slave_fd = None + else: + master_fd, slave_fd = pty_handles + stdin_arg = asyncio.subprocess.DEVNULL + stdout_arg = slave_fd + stderr_arg = slave_fd + use_pty = True + + creationflags = 0 + if os.name == "nt": + import subprocess + + creationflags = subprocess.CREATE_NO_WINDOW + + proc = await asyncio.create_subprocess_exec( + *args, + cwd=self.config.download_path, + stdin=stdin_arg, + stdout=stdout_arg, + stderr=stderr_arg, + env=env_vars, + creationflags=creationflags, + ) + + async with self._lock: + metadata = self._load_metadata(session_id) + if metadata is not None: + metadata["status"] = "running" + self._write_json(self._metadata_path(session_id), metadata) + + if self._active is not None and self._active.session_id == session_id: + self._active.process = proc + + if use_pty is True and slave_fd is not None: + try: + os.close(slave_fd) + except Exception as exc: + LOG.error("Error closing PTY. '%s'.", str(exc)) + + read_task = asyncio.create_task( + self._read_process_output(session_id=session_id, proc=proc, use_pty=use_pty, master_fd=master_fd), + name=f"terminal_reader_{session_id}", + ) + + return_code = await proc.wait() + await read_task + except asyncio.CancelledError: + final_status = "interrupted" + except Exception as exc: + final_status = "failed" + LOG.error("CLI execute exception was thrown.") + LOG.exception(exc) + await self._append_event(session_id, "output", {"type": "stderr", "line": str(exc)}) + finally: + final_status = await self._resolve_final_status(session_id=session_id, status=final_status) + + if final_status == "interrupted" and proc is not None and getattr(proc, "returncode", None) is None: + self._signal_process(proc) + try: + return_code = await asyncio.wait_for(proc.wait(), timeout=self._shutdown_timeout) + except TimeoutError: + self._signal_process(proc, force=True) + try: + return_code = await asyncio.wait_for(proc.wait(), timeout=self._shutdown_timeout) + except TimeoutError: + LOG.warning("Terminal session '%s' process did not exit cleanly.", session_id) + + if proc is not None: + proc_returncode = getattr(proc, "returncode", None) + if proc_returncode is not None: + return_code = int(proc_returncode) + + await self._append_event(session_id, "close", {"exitcode": return_code}) + await self._finalize_session(session_id=session_id, status=final_status, exit_code=return_code) + + if read_task is not None and not read_task.done(): + read_task.cancel() + + if master_fd is not None: + try: + os.close(master_fd) + except OSError: + pass + + async def _read_process_output(self, session_id: str, proc: Process, use_pty: bool, master_fd: int | None) -> None: + if use_pty is False: + assert proc.stdout is not None + async for raw_line in proc.stdout: + line = raw_line.rstrip(b"\n") + await self._append_event( + session_id, + "output", + {"type": "stdout", "line": line.decode("utf-8", errors="replace")}, + ) + return + + assert master_fd is not None + loop: AbstractEventLoop = asyncio.get_running_loop() + buffer = b"" + + while True: + try: + chunk = await loop.run_in_executor(None, lambda: os.read(master_fd, 1024)) + except OSError as exc: + if exc.errno == errno.EIO: + break + raise + + if not chunk: + if buffer: + await self._append_event( + session_id, + "output", + {"type": "stdout", "line": buffer.decode("utf-8", errors="replace")}, + ) + break + + buffer += chunk + *lines, buffer = buffer.split(b"\n") + for line in lines: + await self._append_event( + session_id, + "output", + {"type": "stdout", "line": line.decode("utf-8", errors="replace")}, + ) + + async def _append_event(self, session_id: str, event: str, data: dict[str, Any]) -> dict[str, Any]: + async with self._lock: + metadata = self._load_metadata(session_id) + if metadata is None: + msg = f"Unknown terminal session '{session_id}'." + raise FileNotFoundError(msg) + + next_sequence = int(metadata.get("last_sequence", 0)) + 1 + metadata["last_sequence"] = next_sequence + self._write_json(self._metadata_path(session_id), metadata) + + record = {"seq": next_sequence, "event": event, "data": data} + with self._transcript_path(session_id).open("a", encoding="utf-8") as handle: + handle.write(json.dumps(record) + "\n") + + if self._active is not None and self._active.session_id == session_id: + for subscriber in list(self._active.subscribers): + subscriber.put_nowait(record) + + return record + + async def _finalize_session(self, session_id: str, status: str, exit_code: int) -> None: + async with self._lock: + metadata = self._load_metadata(session_id) + if metadata is None: + return + + now = time.time() + metadata["status"] = status + metadata["exit_code"] = exit_code + metadata["finished_at"] = now + metadata["expires_at"] = now + self._drain_ttl + self._write_json(self._metadata_path(session_id), metadata) + + if self._load_active_marker() == session_id: + self._clear_active_marker() + + runtime = self._active + if runtime is not None and runtime.session_id == session_id: + subscribers = list(runtime.subscribers) + self._active = None + for subscriber in subscribers: + subscriber.put_nowait(None) + + def _get_active_session_locked(self) -> dict[str, Any] | None: + session_id = self._load_active_marker() + if session_id is None: + return None + + metadata = self._load_metadata(session_id) + if metadata is None: + self._clear_active_marker() + return None + + return metadata + + def _recover_orphaned_active_session(self, now: float) -> None: + session_id = self._load_active_marker() + if session_id is None: + return + + metadata = self._load_metadata(session_id) + if metadata is None: + self._clear_active_marker() + return + + metadata["status"] = "interrupted" + metadata["finished_at"] = now + metadata["expires_at"] = now + self._drain_ttl + metadata["exit_code"] = -1 if metadata.get("exit_code") is None else metadata["exit_code"] + self._write_json(self._metadata_path(session_id), metadata) + self._clear_active_marker() + + def _cleanup_expired_sessions(self, now: float) -> None: + self._ensure_root() + active_session_id = self._load_active_marker() + + for path in self.root_path.iterdir(): + if path.name == ACTIVE_FILE_NAME or path.is_dir() is False: + continue + + metadata = self._load_metadata(path.name) + if metadata is None: + shutil.rmtree(path, ignore_errors=True) + continue + + expires_at = metadata.get("expires_at") + if expires_at is None or float(expires_at) > now: + continue + + if active_session_id == path.name: + self._clear_active_marker() + active_session_id = None + + shutil.rmtree(path, ignore_errors=True) + + def _parse_since(self, request: Request) -> int: + values: list[int] = [] + candidates = [request.query.get("since"), request.headers.get("Last-Event-ID")] + for candidate in candidates: + if candidate in (None, ""): + continue + try: + value = int(candidate) + except ValueError as exc: + msg = "Resume cursor must be an integer." + raise ValueError(msg) from exc + + if value < 0: + msg = "Resume cursor must be zero or greater." + raise ValueError(msg) + + values.append(value) + + return max(values, default=0) + + async def _force_interrupt_session(self, session_id: str) -> None: + async with self._lock: + metadata = self._load_metadata(session_id) + if metadata is None: + return + + now = time.time() + metadata["status"] = "interrupted" + metadata["finished_at"] = now + metadata["expires_at"] = now + self._drain_ttl + metadata["exit_code"] = -1 if metadata.get("exit_code") is None else metadata["exit_code"] + self._write_json(self._metadata_path(session_id), metadata) + + if self._load_active_marker() == session_id: + self._clear_active_marker() + + runtime = self._active + if runtime is not None and runtime.session_id == session_id: + subscribers = list(runtime.subscribers) + self._active = None + for subscriber in subscribers: + subscriber.put_nowait(None) + + async def _resolve_final_status(self, session_id: str, status: str) -> str: + async with self._lock: + runtime = self._active + if runtime is not None and runtime.session_id == session_id and runtime.interrupted: + return "interrupted" + + metadata = self._load_metadata(session_id) + if metadata is not None and metadata.get("status") == "interrupted": + return "interrupted" + + return status + + def _signal_process(self, process: Process | None, *, force: bool = False) -> None: + if process is None or getattr(process, "returncode", None) is not None: + return + + action = process.kill if force else process.terminate + try: + action() + except ProcessLookupError: + return + + async def _emit_sse(self, request: Request, response: web.StreamResponse, event: dict[str, Any]) -> bool: + if self._is_request_disconnected(request): + return False + + payload = f"id: {event['seq']}\nevent: {event['event']}\ndata: {json.dumps(event['data'])}\n\n" + try: + await response.write(payload.encode("utf-8")) + except ConnectionResetError: + return False + + return True + + async def _emit_keepalive(self, request: Request, response: web.StreamResponse) -> bool: + if self._is_request_disconnected(request): + return False + + try: + await response.write(b": keepalive\n\n") + except ConnectionResetError: + return False + + return True + + def _read_transcript(self, session_id: str, since: int, until: int | None = None) -> list[dict[str, Any]]: + transcript_path = self._transcript_path(session_id) + if transcript_path.exists() is False: + return [] + + events: list[dict[str, Any]] = [] + with transcript_path.open("r", encoding="utf-8") as handle: + for raw_line in handle: + line = raw_line.strip() + if not line: + continue + + event = json.loads(line) + sequence = int(event["seq"]) + if sequence <= since: + continue + if until is not None and sequence > until: + break + events.append(event) + + return events + + def _build_env(self) -> dict[str, str]: + env_vars = os.environ.copy() + env_vars.update( + { + "PWD": self.config.download_path, + "FORCE_COLOR": "1", + "PYTHONUNBUFFERED": "1", + } + ) + + if os.name != "nt": + env_vars.update( + { + "TERM": "xterm-256color", + "LANG": "en_US.UTF-8", + "LC_ALL": "en_US.UTF-8", + "SHELL": "/bin/bash", + } + ) + + return env_vars + + def _open_pty(self) -> tuple[int, int] | None: + try: + import pty + + return pty.openpty() + except ImportError: + return None + + def _is_request_disconnected(self, request: Request) -> bool: + return request.transport is None or request.transport.is_closing() + + def _ensure_root(self) -> None: + self.root_path.mkdir(parents=True, exist_ok=True) + + def _session_dir(self, session_id: str) -> Path: + return self.root_path / session_id + + def _metadata_path(self, session_id: str) -> Path: + return self._session_dir(session_id) / METADATA_FILE_NAME + + def _transcript_path(self, session_id: str) -> Path: + return self._session_dir(session_id) / TRANSCRIPT_FILE_NAME + + def _active_marker_path(self) -> Path: + return self.root_path / ACTIVE_FILE_NAME + + def _set_active_marker(self, session_id: str) -> None: + self._write_json(self._active_marker_path(), {"session_id": session_id}) + + def _clear_active_marker(self) -> None: + self._active_marker_path().unlink(missing_ok=True) + + def _load_active_marker(self) -> str | None: + data = self._read_json(self._active_marker_path()) + if not isinstance(data, dict): + return None + session_id = data.get("session_id") + return session_id if isinstance(session_id, str) and session_id else None + + def _load_metadata(self, session_id: str) -> dict[str, Any] | None: + data = self._read_json(self._metadata_path(session_id)) + return data if isinstance(data, dict) else None + + def _read_json(self, path: Path) -> dict[str, Any] | None: + if path.exists() is False: + return None + + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + + def _write_json(self, path: Path, payload: dict[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + temp_path = path.with_suffix(f"{path.suffix}.tmp") + with temp_path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle) + temp_path.replace(path) diff --git a/app/library/downloads/core.py b/app/library/downloads/core.py index 52911f05..778e1906 100644 --- a/app/library/downloads/core.py +++ b/app/library/downloads/core.py @@ -2,6 +2,7 @@ from __future__ import annotations +import _thread import asyncio import logging import os @@ -217,7 +218,7 @@ def trigger_live_cancel() -> None: if "posix" == os.name: os.kill(os.getpid(), signal.SIGINT) else: - signal.raise_signal(signal.SIGINT) + _thread.interrupt_main() threading.Thread(target=trigger_live_cancel, name=f"cancel-watch-{self.id}", daemon=True).start() diff --git a/app/main.py b/app/main.py index 78444443..f2a6dca0 100644 --- a/app/main.py +++ b/app/main.py @@ -36,6 +36,7 @@ from app.library.Scheduler import Scheduler from app.library.Services import Services from app.library.sqlite_store import SqliteStore +from app.library.TerminalSessionManager import TerminalSessionManager from app.library.UpdateChecker import UpdateChecker LOG = logging.getLogger("app") @@ -118,6 +119,7 @@ def start(self, host: str | None = None, port: int | None = None, cb=None): BackgroundWorker.get_instance().attach(self._app) Scheduler.get_instance().attach(self._app) Cache.get_instance().attach(self._app) + TerminalSessionManager.get_instance().attach(self._app) self._socket.attach(self._app) self._http.attach(self._app) diff --git a/app/routes/api/system.py b/app/routes/api/system.py index e1744a4e..59fb89fb 100644 --- a/app/routes/api/system.py +++ b/app/routes/api/system.py @@ -1,11 +1,7 @@ import asyncio -import errno import logging -import os -import shlex import time from pathlib import Path -from typing import TYPE_CHECKING from aiohttp import web from aiohttp.web import Request, Response @@ -18,14 +14,10 @@ from app.library.encoder import Encoder from app.library.Events import EventBus, Events from app.library.router import route +from app.library.TerminalSessionManager import TerminalSessionConflictError, TerminalSessionManager from app.library.UpdateChecker import UpdateChecker from app.library.Utils import list_folders -if TYPE_CHECKING: - from asyncio import Task - from asyncio.events import AbstractEventLoop - from asyncio.subprocess import Process - LOG: logging.Logger = logging.getLogger(__name__) @@ -231,14 +223,7 @@ async def check_updates(config: Config, encoder: Encoder, update_checker: Update ) -@route("POST", "api/system/terminal", "system.terminal") -async def stream_terminal(request: Request, config: Config, encoder: Encoder) -> Response | web.StreamResponse: - if not config.console_enabled: - return web.json_response( - {"error": "Console feature is disabled."}, - status=web.HTTPForbidden.status_code, - ) - +async def _validate_terminal_command_request(request: Request) -> str | Response: if not request.can_read_body: return web.json_response( {"error": "Request body is required."}, @@ -259,136 +244,122 @@ async def stream_terminal(request: Request, config: Config, encoder: Encoder) -> status=web.HTTPBadRequest.status_code, ) - response = web.StreamResponse( + return raw_command + + +@route("POST", "api/system/terminal", "system.terminal") +async def create_terminal_session( + request: Request, config: Config, encoder: Encoder, terminal_manager: TerminalSessionManager +) -> Response: + if not config.console_enabled: + return web.json_response( + {"error": "Console feature is disabled."}, + status=web.HTTPForbidden.status_code, + ) + + raw_command = await _validate_terminal_command_request(request) + if isinstance(raw_command, Response): + return raw_command + + try: + metadata = await terminal_manager.create_session(raw_command) + except TerminalSessionConflictError as exc: + return web.json_response( + {"error": str(exc)}, + status=web.HTTPConflict.status_code, + ) + + return web.json_response(data=metadata, status=web.HTTPOk.status_code, dumps=encoder.encode) + + +@route("GET", "api/system/terminal/active", "system.terminal.active") +async def get_active_terminal_session( + config: Config, encoder: Encoder, terminal_manager: TerminalSessionManager +) -> Response: + if not config.console_enabled: + return web.json_response( + {"error": "Console feature is disabled."}, + status=web.HTTPForbidden.status_code, + ) + + metadata = await terminal_manager.get_active_session() + return web.json_response(data=metadata, status=web.HTTPOk.status_code, dumps=encoder.encode) + + +@route("GET", "api/system/terminal/{session_id}", "system.terminal.session") +async def get_terminal_session( + request: Request, config: Config, encoder: Encoder, terminal_manager: TerminalSessionManager +) -> Response: + if not config.console_enabled: + return web.json_response( + {"error": "Console feature is disabled."}, + status=web.HTTPForbidden.status_code, + ) + + session_id = request.match_info.get("session_id", "") + metadata = await terminal_manager.get_session(session_id) + if metadata is None: + return web.json_response( + {"error": "Terminal session not found."}, + status=web.HTTPNotFound.status_code, + ) + + return web.json_response(data=metadata, status=web.HTTPOk.status_code, dumps=encoder.encode) + + +@route("DELETE", "api/system/terminal/{session_id}", "system.terminal.cancel") +async def cancel_terminal_session( + request: Request, config: Config, encoder: Encoder, terminal_manager: TerminalSessionManager +) -> Response: + if not config.console_enabled: + return web.json_response( + {"error": "Console feature is disabled."}, + status=web.HTTPForbidden.status_code, + ) + + session_id = request.match_info.get("session_id", "") + try: + await terminal_manager.cancel_session(session_id) + except FileNotFoundError: + return web.json_response( + {"error": "Terminal session not found."}, + status=web.HTTPNotFound.status_code, + ) + except RuntimeError as exc: + return web.json_response( + {"error": str(exc)}, + status=web.HTTPConflict.status_code, + ) + + return web.json_response( + data={"message": "Terminal session cancellation requested.", "session_id": session_id}, status=web.HTTPOk.status_code, - headers={ - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - }, + dumps=encoder.encode, ) - await response.prepare(request) - async def emit_event(event: str, data: dict) -> None: - if request.transport is None or request.transport.is_closing(): - return - payload: str = f"event: {event}\ndata: {encoder.encode(data)}\n\n" - await response.write(payload.encode("utf-8")) - returncode: int = -1 - try: - LOG.info("Cli command from client. '%s'", raw_command) - - args: list[str] = ["yt-dlp", *shlex.split(raw_command, posix=os.name != "nt")] - env_vars: dict[str, str] = os.environ.copy() - env_vars.update( - { - "PWD": config.download_path, - "FORCE_COLOR": "1", - "PYTHONUNBUFFERED": "1", - } +@route("GET", "api/system/terminal/{session_id}/stream", "system.terminal.stream") +async def stream_terminal_session( + request: Request, config: Config, terminal_manager: TerminalSessionManager +) -> Response | web.StreamResponse: + if not config.console_enabled: + return web.json_response( + {"error": "Console feature is disabled."}, + status=web.HTTPForbidden.status_code, ) - if "nt" != os.name: - env_vars.update( - { - "TERM": "xterm-256color", - "LANG": "en_US.UTF-8", - "LC_ALL": "en_US.UTF-8", - "SHELL": "/bin/bash", - } - ) - - try: - import pty - - master_fd, slave_fd = pty.openpty() - stdin_arg = asyncio.subprocess.DEVNULL - stdout_arg = stderr_arg = slave_fd - use_pty = True - except ImportError: - use_pty = False - master_fd = slave_fd = None - stdin_arg = asyncio.subprocess.DEVNULL - stdout_arg = asyncio.subprocess.PIPE - stderr_arg = asyncio.subprocess.STDOUT - - creationflags = 0 - if os.name == "nt": - import subprocess - - creationflags = subprocess.CREATE_NO_WINDOW - - proc: Process = await asyncio.create_subprocess_exec( - *args, - cwd=config.download_path, - stdin=stdin_arg, - stdout=stdout_arg, - stderr=stderr_arg, - env=env_vars, - creationflags=creationflags, + session_id = request.match_info.get("session_id", "") + metadata = await terminal_manager.get_session(session_id) + if metadata is None: + return web.json_response( + {"error": "Terminal session not found."}, + status=web.HTTPNotFound.status_code, ) - if use_pty: - assert slave_fd is not None - try: - os.close(slave_fd) - except Exception as e: - LOG.error("Error closing PTY. '%s'.", str(e)) - - async def reader() -> None: - if use_pty is False: - assert proc.stdout is not None - async for raw_line in proc.stdout: - line = raw_line.rstrip(b"\n") - await emit_event("output", {"type": "stdout", "line": line.decode("utf-8", errors="replace")}) - return - - assert master_fd is not None - read_fd = master_fd - loop: AbstractEventLoop = asyncio.get_running_loop() - buffer: bytes = b"" - while True: - try: - chunk: bytes = await loop.run_in_executor(None, lambda: os.read(read_fd, 1024)) - except OSError as e: - if e.errno == errno.EIO: - break - raise - - if not chunk: - if buffer: - await emit_event( - "output", - {"type": "stdout", "line": buffer.decode("utf-8", errors="replace")}, - ) - break - - buffer += chunk - *lines, buffer = buffer.split(b"\n") - - for line in lines: - await emit_event( - "output", - {"type": "stdout", "line": line.decode("utf-8", errors="replace")}, - ) - if master_fd is None: - return - try: - os.close(master_fd) - except Exception as e: - LOG.error("Error closing PTY. '%s'.", str(e)) - - read_task: Task = asyncio.create_task(reader(), name="cli_reader") - - returncode = await proc.wait() - await read_task - except Exception as e: - LOG.error("CLI execute exception was thrown.") - LOG.exception(e) - await emit_event("output", {"type": "stderr", "line": str(e)}) - finally: - await emit_event("close", {"exitcode": returncode}) - await response.write_eof() - - return response + try: + return await terminal_manager.stream_session(session_id=session_id, request=request) + except ValueError as exc: + return web.json_response( + {"error": str(exc)}, + status=web.HTTPBadRequest.status_code, + ) diff --git a/app/tests/test_download.py b/app/tests/test_download.py index 96c03f73..c1601a7f 100644 --- a/app/tests/test_download.py +++ b/app/tests/test_download.py @@ -371,20 +371,32 @@ def __init__(self, params): def process_ie_result(self, ie_result, download): return ie_result, download + def download(self, url_list): + return 0 + signal_mock = Mock() - thread_instance = Mock(start=Mock()) - thread_mock = Mock(return_value=thread_instance) + thread_instances: list[Mock] = [] + + def build_thread(*_args, **_kwargs): + thread = Mock(start=Mock()) + thread_instances.append(thread) + return thread + + thread_mock = Mock(side_effect=build_thread) monkeypatch.setattr("app.library.downloads.core.YTDLP", FakeYTDLP) monkeypatch.setattr("app.library.downloads.core.signal.signal", signal_mock) monkeypatch.setattr("app.library.downloads.core.threading.Thread", thread_mock) download._download() - thread_mock.assert_called_once() - thread_instance.start.assert_called_once() signal_mock.assert_any_call(signal.SIGINT, signal.default_int_handler) - target = thread_mock.call_args.kwargs["target"] + live_cancel_thread = next( + call for call in thread_mock.call_args_list if call.kwargs.get("name", "").startswith("cancel-watch-") + ) + live_cancel_thread_index = thread_mock.call_args_list.index(live_cancel_thread) + thread_instances[live_cancel_thread_index].start.assert_called_once() + target = live_cancel_thread.kwargs["target"] ydl = created_ydl[0] if "posix" == os.name: @@ -395,9 +407,9 @@ def process_ie_result(self, ie_result, download): target() mock_kill.assert_called_once_with(12345, signal.SIGINT) else: - with patch("app.library.downloads.core.signal.raise_signal") as mock_raise_signal: + with patch("app.library.downloads.core._thread.interrupt_main") as mock_interrupt_main: target() - mock_raise_signal.assert_called_once_with(signal.SIGINT) + mock_interrupt_main.assert_called_once_with() assert ydl._interrupted is True ydl.to_screen.assert_called_once_with("[info] Interrupt received, exiting cleanly...") diff --git a/app/tests/test_terminal_session_manager.py b/app/tests/test_terminal_session_manager.py new file mode 100644 index 00000000..f2962db4 --- /dev/null +++ b/app/tests/test_terminal_session_manager.py @@ -0,0 +1,449 @@ +import asyncio +import json +from pathlib import Path +from typing import Any, cast + +import pytest + +from app.library.Services import Services +from app.library.TerminalSessionManager import TerminalSessionManager +from app.library.config import Config +from app.library.encoder import Encoder +from app.routes.api.system import ( + cancel_terminal_session, + create_terminal_session, + get_active_terminal_session, + get_terminal_session, + stream_terminal_session, +) + + +class _FakeTransport: + def __init__(self) -> None: + self._closing = False + + def is_closing(self) -> bool: + return self._closing + + +class _FakeRequest: + def __init__( + self, + *, + payload: dict | None = None, + match_info: dict[str, str] | None = None, + query: dict[str, str] | None = None, + headers: dict[str, str] | None = None, + can_read_body: bool = False, + ) -> None: + self._payload = payload + self.match_info = match_info or {} + self.query = query or {} + self.headers = headers or {} + self.can_read_body = can_read_body + self.transport = _FakeTransport() + + async def json(self) -> dict | None: + return self._payload + + +class _FakeStreamResponse: + def __init__(self, *, status: int, headers: dict[str, str]) -> None: + self.status = status + self.headers = headers + self.payload = bytearray() + self.prepared = False + self.closed = False + + async def prepare(self, _request: _FakeRequest) -> "_FakeStreamResponse": + self.prepared = True + return self + + async def write(self, data: bytes) -> None: + self.payload.extend(data) + + async def write_eof(self) -> None: + self.closed = True + + +class _FakeStdout: + def __init__(self, lines: list[bytes]) -> None: + self._lines = lines + + def __aiter__(self) -> "_FakeStdout": + return self + + async def __anext__(self) -> bytes: + if not self._lines: + raise StopAsyncIteration + return self._lines.pop(0) + + +class _BlockingProc: + def __init__(self, done_event: asyncio.Event) -> None: + self.stdout = _FakeStdout([]) + self._done_event = done_event + self.returncode: int | None = None + + async def wait(self) -> int: + await self._done_event.wait() + self.returncode = 0 + return 0 + + +class _CompletedProc: + def __init__(self, lines: list[bytes], exit_code: int = 0) -> None: + self.stdout = _FakeStdout(lines) + self._exit_code = exit_code + self.returncode: int | None = None + + async def wait(self) -> int: + await asyncio.sleep(0) + self.returncode = self._exit_code + return self._exit_code + + +class _TerminableProc: + def __init__(self) -> None: + self.stdout = _FakeStdout([]) + self.returncode: int | None = None + self.terminate_calls = 0 + self.kill_calls = 0 + self.wait_started = asyncio.Event() + self._done_event = asyncio.Event() + + def terminate(self) -> None: + self.terminate_calls += 1 + if self.returncode is None: + self.returncode = -15 + self._done_event.set() + + def kill(self) -> None: + self.kill_calls += 1 + if self.returncode is None: + self.returncode = -9 + self._done_event.set() + + async def wait(self) -> int: + self.wait_started.set() + await self._done_event.wait() + assert self.returncode is not None + return self.returncode + + +@pytest.fixture +def terminal_setup(tmp_path: Path) -> tuple[Config, TerminalSessionManager, Encoder]: + Services._reset_singleton() + Config._reset_singleton() + TerminalSessionManager._reset_singleton() + + config = Config.get_instance() + config.console_enabled = True + config.config_path = str(tmp_path / "config") + config.download_path = str(tmp_path / "downloads") + Path(config.config_path).mkdir(parents=True, exist_ok=True) + Path(config.download_path).mkdir(parents=True, exist_ok=True) + + manager = TerminalSessionManager.get_instance() + encoder = Encoder() + return config, manager, encoder + + +class TestTerminalSessionRoutes: + @pytest.mark.asyncio + async def test_start_returns_session_metadata_and_active_conflict( + self, terminal_setup: tuple[Config, TerminalSessionManager, Encoder], monkeypatch: pytest.MonkeyPatch + ) -> None: + config, manager, encoder = terminal_setup + await manager.initialize() + + done_event = asyncio.Event() + + async def fake_create_subprocess_exec(*_args, **_kwargs): + return _BlockingProc(done_event) + + monkeypatch.setattr( + "app.library.TerminalSessionManager.asyncio.create_subprocess_exec", fake_create_subprocess_exec + ) + monkeypatch.setattr(manager, "_open_pty", lambda: None) + + request = _FakeRequest(payload={"command": "--help"}, can_read_body=True) + response = await create_terminal_session(request, config, encoder, manager) + payload = json.loads(response.body.decode("utf-8")) + + assert 200 == response.status + assert payload["session_id"] + assert "starting" == payload["status"] + + await asyncio.sleep(0) + + conflict = await create_terminal_session(request, config, encoder, manager) + assert 409 == conflict.status + assert b"already active" in conflict.body.lower() + + active = await get_active_terminal_session(config, encoder, manager) + active_payload = json.loads(active.body.decode("utf-8")) + assert payload["session_id"] == active_payload["session_id"] + + assert manager._active is not None + task = manager._active.task + done_event.set() + await task + + @pytest.mark.asyncio + async def test_stream_endpoint_replays_persisted_events_and_resume_cursor( + self, terminal_setup: tuple[Config, TerminalSessionManager, Encoder], monkeypatch: pytest.MonkeyPatch + ) -> None: + config, manager, encoder = terminal_setup + await manager.initialize() + + async def fake_create_subprocess_exec(*_args, **_kwargs): + return _CompletedProc([b"first\n", b"second\n"]) + + monkeypatch.setattr( + "app.library.TerminalSessionManager.asyncio.create_subprocess_exec", fake_create_subprocess_exec + ) + monkeypatch.setattr(manager, "_open_pty", lambda: None) + monkeypatch.setattr("app.library.TerminalSessionManager.web.StreamResponse", _FakeStreamResponse) + + start_request = _FakeRequest(payload={"command": "--version"}, can_read_body=True) + start_response = await create_terminal_session(start_request, config, encoder, manager) + session_id = json.loads(start_response.body.decode("utf-8"))["session_id"] + + assert manager._active is not None + task = manager._active.task + await task + + status_response = await get_terminal_session( + _FakeRequest(match_info={"session_id": session_id}), config, encoder, manager + ) + status_payload = json.loads(status_response.body.decode("utf-8")) + assert "completed" == status_payload["status"] + assert 3 == status_payload["last_sequence"] + assert 0 == status_payload["exit_code"] + + stream_request = _FakeRequest(match_info={"session_id": session_id}) + stream_response = await stream_terminal_session(stream_request, config, manager) + stream_payload = stream_response.payload.decode("utf-8") + + assert "id: 1" in stream_payload + assert "id: 2" in stream_payload + assert "id: 3" in stream_payload + assert 'data: {"type": "stdout", "line": "first"}' in stream_payload + assert 'data: {"exitcode": 0}' in stream_payload + + resumed_request = _FakeRequest(match_info={"session_id": session_id}, query={"since": "1"}) + resumed_response = await stream_terminal_session(resumed_request, config, manager) + resumed_payload = resumed_response.payload.decode("utf-8") + + assert "id: 1" not in resumed_payload + assert "id: 2" in resumed_payload + assert "id: 3" in resumed_payload + + @pytest.mark.asyncio + async def test_completed_session_expires_after_drain_window( + self, terminal_setup: tuple[Config, TerminalSessionManager, Encoder], monkeypatch: pytest.MonkeyPatch + ) -> None: + config, manager, encoder = terminal_setup + manager._drain_ttl = 0.05 + await manager.initialize() + + async def fake_create_subprocess_exec(*_args, **_kwargs): + return _CompletedProc([b"done\n"]) + + monkeypatch.setattr( + "app.library.TerminalSessionManager.asyncio.create_subprocess_exec", fake_create_subprocess_exec + ) + monkeypatch.setattr(manager, "_open_pty", lambda: None) + + start_request = _FakeRequest(payload={"command": "--help"}, can_read_body=True) + start_response = await create_terminal_session(start_request, config, encoder, manager) + session_id = json.loads(start_response.body.decode("utf-8"))["session_id"] + + assert manager._active is not None + task = manager._active.task + await task + + before_expiry = await manager.get_session(session_id) + assert before_expiry is not None + + await asyncio.sleep(0.06) + + expired = await manager.get_session(session_id) + assert expired is None + assert not (manager.root_path / session_id).exists() + + @pytest.mark.asyncio + async def test_shutdown_interrupts_active_session_and_clears_active_marker( + self, terminal_setup: tuple[Config, TerminalSessionManager, Encoder], monkeypatch: pytest.MonkeyPatch + ) -> None: + config, manager, encoder = terminal_setup + manager._shutdown_timeout = 0.05 + await manager.initialize() + + proc = _TerminableProc() + + async def fake_create_subprocess_exec(*_args, **_kwargs): + return proc + + monkeypatch.setattr( + "app.library.TerminalSessionManager.asyncio.create_subprocess_exec", fake_create_subprocess_exec + ) + monkeypatch.setattr(manager, "_open_pty", lambda: None) + + start_request = _FakeRequest(payload={"command": "--help"}, can_read_body=True) + start_response = await create_terminal_session(start_request, config, encoder, manager) + session_id = json.loads(start_response.body.decode("utf-8"))["session_id"] + + await proc.wait_started.wait() + await manager.on_shutdown(cast(Any, None)) + + metadata = await manager.get_session(session_id) + transcript = manager._read_transcript(session_id=session_id, since=0) + + assert metadata is not None + assert "interrupted" == metadata["status"] + assert -15 == metadata["exit_code"] + assert metadata["finished_at"] is not None + assert metadata["expires_at"] is not None + assert 1 == proc.terminate_calls + assert 0 == proc.kill_calls + assert manager._active is None + assert manager._load_active_marker() is None + assert "close" == transcript[-1]["event"] + assert -15 == transcript[-1]["data"]["exitcode"] + + @pytest.mark.asyncio + async def test_stream_endpoint_emits_keepalive_for_silent_active_session( + self, terminal_setup: tuple[Config, TerminalSessionManager, Encoder], monkeypatch: pytest.MonkeyPatch + ) -> None: + config, manager, encoder = terminal_setup + manager._keepalive_interval = 0.01 + await manager.initialize() + + done_event = asyncio.Event() + + async def fake_create_subprocess_exec(*_args, **_kwargs): + return _BlockingProc(done_event) + + monkeypatch.setattr( + "app.library.TerminalSessionManager.asyncio.create_subprocess_exec", fake_create_subprocess_exec + ) + monkeypatch.setattr(manager, "_open_pty", lambda: None) + monkeypatch.setattr("app.library.TerminalSessionManager.web.StreamResponse", _FakeStreamResponse) + + start_request = _FakeRequest(payload={"command": "--version"}, can_read_body=True) + start_response = await create_terminal_session(start_request, config, encoder, manager) + session_id = json.loads(start_response.body.decode("utf-8"))["session_id"] + assert manager._active is not None + session_task = manager._active.task + + stream_request = _FakeRequest(match_info={"session_id": session_id}) + stream_task = asyncio.create_task(_stream_session(stream_request, config, manager)) + + await asyncio.sleep(0.03) + done_event.set() + await session_task + stream_response = await stream_task + stream_payload = stream_response.payload.decode("utf-8") + + assert ": keepalive" in stream_payload + assert "id: 1" in stream_payload + assert 'data: {"exitcode": 0}' in stream_payload + + @pytest.mark.asyncio + async def test_cancel_endpoint_interrupts_active_session( + self, terminal_setup: tuple[Config, TerminalSessionManager, Encoder], monkeypatch: pytest.MonkeyPatch + ) -> None: + config, manager, encoder = terminal_setup + await manager.initialize() + + proc = _TerminableProc() + + async def fake_create_subprocess_exec(*_args, **_kwargs): + return proc + + monkeypatch.setattr( + "app.library.TerminalSessionManager.asyncio.create_subprocess_exec", fake_create_subprocess_exec + ) + monkeypatch.setattr(manager, "_open_pty", lambda: None) + + start_request = _FakeRequest(payload={"command": "--help"}, can_read_body=True) + start_response = await create_terminal_session(start_request, config, encoder, manager) + session_id = json.loads(start_response.body.decode("utf-8"))["session_id"] + + await proc.wait_started.wait() + + cancel_response = await cancel_terminal_session( + _FakeRequest(match_info={"session_id": session_id}), config, encoder, manager + ) + cancel_payload = json.loads(cancel_response.body.decode("utf-8")) + + assert 200 == cancel_response.status + assert session_id == cancel_payload["session_id"] + + assert manager._active is not None + active_task = manager._active.task + await active_task + + metadata = await manager.get_session(session_id) + transcript = manager._read_transcript(session_id=session_id, since=0) + + assert metadata is not None + assert "interrupted" == metadata["status"] + assert -15 == metadata["exit_code"] + assert 1 == proc.terminate_calls + assert 0 == proc.kill_calls + assert transcript[-1]["event"] == "close" + assert -15 == transcript[-1]["data"]["exitcode"] + + @pytest.mark.asyncio + async def test_cancel_endpoint_returns_conflict_for_inactive_session( + self, terminal_setup: tuple[Config, TerminalSessionManager, Encoder], monkeypatch: pytest.MonkeyPatch + ) -> None: + config, manager, encoder = terminal_setup + await manager.initialize() + + async def fake_create_subprocess_exec(*_args, **_kwargs): + return _CompletedProc([b"done\n"]) + + monkeypatch.setattr( + "app.library.TerminalSessionManager.asyncio.create_subprocess_exec", fake_create_subprocess_exec + ) + monkeypatch.setattr(manager, "_open_pty", lambda: None) + + start_request = _FakeRequest(payload={"command": "--version"}, can_read_body=True) + start_response = await create_terminal_session(start_request, config, encoder, manager) + session_id = json.loads(start_response.body.decode("utf-8"))["session_id"] + + assert manager._active is not None + await manager._active.task + + cancel_response = await cancel_terminal_session( + _FakeRequest(match_info={"session_id": session_id}), config, encoder, manager + ) + + assert 409 == cancel_response.status + assert b"not active" in cancel_response.body.lower() + + @pytest.mark.asyncio + async def test_cancel_endpoint_returns_not_found_for_unknown_session( + self, terminal_setup: tuple[Config, TerminalSessionManager, Encoder] + ) -> None: + config, manager, encoder = terminal_setup + await manager.initialize() + + cancel_response = await cancel_terminal_session( + _FakeRequest(match_info={"session_id": "missing"}), config, encoder, manager + ) + + assert 404 == cancel_response.status + assert b"not found" in cancel_response.body.lower() + + +async def _stream_session( + request: _FakeRequest, config: Config, manager: TerminalSessionManager +) -> _FakeStreamResponse: + response = await stream_terminal_session(request, config, manager) + assert isinstance(response, _FakeStreamResponse) + return response diff --git a/pyproject.toml b/pyproject.toml index 456d7f3d..2b96d829 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -194,6 +194,7 @@ testpaths = ["app/tests", "app/features"] addopts = "-v --tb=short" filterwarnings = [ "ignore:Parsing dates involving a day of month without a year:DeprecationWarning", + "ignore:This process \\(pid\\=.*\\) is multi-threaded, use of fork\\(\\) may lead to deadlocks in the child\\.:DeprecationWarning", ] [dependency-groups] diff --git a/ui/app/components/InputAutocomplete.vue b/ui/app/components/InputAutocomplete.vue index d94477a8..40d43aa8 100644 --- a/ui/app/components/InputAutocomplete.vue +++ b/ui/app/components/InputAutocomplete.vue @@ -7,11 +7,12 @@ :placeholder="placeholder" autocomplete="new-password" :disabled="disabled" + :icon="icon" size="lg" variant="outline" color="neutral" class="w-full" - :ui="{ root: 'w-full', base: 'w-full bg-default/90' }" + :ui="{ root: 'w-full', base: 'w-full bg-default/90', leadingIcon: iconClass }" @focus="onFocus" @blur="hideList" @input="onInput" @@ -63,6 +64,8 @@ const props = withDefaults( placeholder?: string; disabled?: boolean; id?: string; + icon?: string; + iconClass?: string; multiple?: boolean; openOnFocus?: boolean; allowShortFlags?: boolean; @@ -71,13 +74,16 @@ const props = withDefaults( placeholder: '', disabled: false, id: '', + icon: undefined, + iconClass: '', multiple: true, openOnFocus: false, allowShortFlags: false, }, ); -const { placeholder, disabled, id, multiple, openOnFocus, allowShortFlags } = toRefs(props); +const { placeholder, disabled, id, icon, iconClass, multiple, openOnFocus, allowShortFlags } = + toRefs(props); const model = defineModel(); diff --git a/ui/app/components/TextareaAutocomplete.vue b/ui/app/components/TextareaAutocomplete.vue index 211e8567..e7030e1e 100644 --- a/ui/app/components/TextareaAutocomplete.vue +++ b/ui/app/components/TextareaAutocomplete.vue @@ -25,6 +25,13 @@ @mouseup="updateCaret" /> + +
(); +const props = withDefaults( + defineProps<{ + options: AutoCompleteOptions; + placeholder?: string; + disabled?: boolean; + id?: string; + rows?: number; + icon?: string; + iconClass?: string; + }>(), + { + placeholder: '', + disabled: false, + id: '', + rows: 4, + icon: undefined, + iconClass: '', + }, +); const model = defineModel(); const localValue = ref(model.value || ''); diff --git a/ui/app/composables/useConsoleSession.ts b/ui/app/composables/useConsoleSession.ts new file mode 100644 index 00000000..2b4ed6b0 --- /dev/null +++ b/ui/app/composables/useConsoleSession.ts @@ -0,0 +1,661 @@ +import { fetchEventSource } from '@microsoft/fetch-event-source'; +import type { EventSourceMessage } from '@microsoft/fetch-event-source'; +import { useStorage } from '@vueuse/core'; +import { computed, ref } from 'vue'; + +import { parse_api_error, parse_api_response, request, uri } from '~/utils'; + +type ConsoleSessionStatus = + | 'idle' + | 'starting' + | 'running' + | 'reconnecting' + | 'finished' + | 'interrupted' + | 'expired' + | 'error'; + +type ConsoleSessionState = { + sessionId: string | null; + command: string; + status: ConsoleSessionStatus; + lastEventId: string | null; + exitCode: number | null; + transcript: Array; + error: string; +}; + +type ConsoleSessionResponse = { + session_id?: string | null; + sessionId?: string | null; + command?: string | null; + status?: string | null; + last_event_id?: string | number | null; + lastEventId?: string | number | null; + exit_code?: number | null; + exitCode?: number | null; + expired?: boolean | null; + not_found?: boolean | null; +}; + +type StartConsoleSessionInput = { + command: string; + displayCommand?: string; +}; + +type CancelConsoleSessionResult = { + status: 'cancelled' | 'missing' | 'error'; + message?: string; +}; + +const STORAGE_KEY = 'console_session_state'; +const MAX_TRANSCRIPT_CHUNKS = 1500; +const MAX_TRANSCRIPT_CHARS = 120000; +const RECONNECT_DELAY_MS = 1500; + +const DEFAULT_STATE: ConsoleSessionState = { + sessionId: null, + command: '', + status: 'idle', + lastEventId: null, + exitCode: null, + transcript: [], + error: '', +}; + +class ConsoleSessionExpiredError extends Error {} + +const trimTranscript = (transcript: Array): Array => { + const next = [...transcript]; + + while (next.length > MAX_TRANSCRIPT_CHUNKS) { + next.shift(); + } + + let totalChars = next.reduce((sum, chunk) => sum + chunk.length, 0); + while (next.length > 1 && totalChars > MAX_TRANSCRIPT_CHARS) { + totalChars -= next[0]?.length || 0; + next.shift(); + } + + return next; +}; + +const isActiveSessionStatus = (status: ConsoleSessionStatus | null | undefined): boolean => { + return ['starting', 'running', 'reconnecting'].includes(status || ''); +}; + +const normalizeStatus = ( + status: string | null | undefined, + fallback: ConsoleSessionStatus = 'idle', +): ConsoleSessionStatus => { + switch ((status || '').trim().toLowerCase()) { + case 'starting': + case 'pending': + case 'queued': + return 'starting'; + + case 'active': + case 'open': + case 'running': + case 'streaming': + return 'running'; + + case 'reconnecting': + return 'reconnecting'; + + case 'complete': + case 'completed': + case 'closed': + case 'done': + case 'exited': + case 'finished': + return 'finished'; + + case 'interrupted': + return 'interrupted'; + + case 'failed': + case 'error': + return 'error'; + + case 'expired': + case 'missing': + case 'not_found': + return 'expired'; + + case 'idle': + return 'idle'; + + default: + return fallback; + } +}; + +const normalizePersistedState = ( + input: Partial | null | undefined, + { allowActiveWithoutSessionId = false }: { allowActiveWithoutSessionId?: boolean } = {}, +): ConsoleSessionState => { + const sessionId = + typeof input?.sessionId === 'string' && input.sessionId.trim() ? input.sessionId : null; + const status = normalizeStatus(input?.status, DEFAULT_STATE.status); + const hasDetachedActiveStatus = isActiveSessionStatus(status) && sessionId === null; + + return { + sessionId, + command: typeof input?.command === 'string' ? input.command : DEFAULT_STATE.command, + status: hasDetachedActiveStatus && !allowActiveWithoutSessionId ? 'idle' : status, + lastEventId: + input?.lastEventId === null || input?.lastEventId === undefined + ? null + : String(input.lastEventId), + exitCode: typeof input?.exitCode === 'number' ? input.exitCode : null, + transcript: trimTranscript( + Array.isArray(input?.transcript) + ? input.transcript.filter((chunk): chunk is string => typeof chunk === 'string') + : DEFAULT_STATE.transcript, + ), + error: typeof input?.error === 'string' ? input.error : DEFAULT_STATE.error, + }; +}; + +const sessionState = useStorage(STORAGE_KEY, { ...DEFAULT_STATE }); +const streamController = ref(null); +const reconnectTimer = ref | null>(null); +const connectionNonce = ref(0); + +sessionState.value = normalizePersistedState(sessionState.value); + +const updateState = (patch: Partial): void => { + sessionState.value = normalizePersistedState( + { + ...sessionState.value, + ...patch, + }, + { allowActiveWithoutSessionId: true }, + ); +}; + +const appendTranscript = (chunk: string): void => { + if (!chunk) { + return; + } + + updateState({ transcript: trimTranscript([...sessionState.value.transcript, chunk]) }); +}; + +const stopReconnectTimer = (): void => { + if (!reconnectTimer.value) { + return; + } + + clearTimeout(reconnectTimer.value); + reconnectTimer.value = null; +}; + +const stopStream = (): void => { + stopReconnectTimer(); + streamController.value?.abort(); + streamController.value = null; +}; + +const clearTranscript = (): void => { + const isActive = isActiveSessionStatus(sessionState.value.status); + + updateState({ + transcript: [], + sessionId: isActive ? sessionState.value.sessionId : null, + status: isActive ? sessionState.value.status : 'idle', + lastEventId: isActive ? sessionState.value.lastEventId : null, + exitCode: isActive ? sessionState.value.exitCode : null, + error: '', + }); +}; + +const markExpired = (): void => { + stopStream(); + updateState({ status: 'expired', error: '' }); +}; + +const finishSession = ( + status: ConsoleSessionStatus = 'finished', + exitCode: number | null = sessionState.value.exitCode, +): void => { + stopStream(); + updateState({ + status, + exitCode, + error: '', + }); +}; + +const scheduleReconnect = (): void => { + if (!sessionState.value.sessionId) { + return; + } + + stopReconnectTimer(); + updateState({ status: 'reconnecting', error: '' }); + reconnectTimer.value = setTimeout(() => { + reconnectTimer.value = null; + if (!sessionState.value.sessionId || !isActiveSessionStatus(sessionState.value.status)) { + return; + } + + void connectStream(); + }, RECONNECT_DELAY_MS); +}; + +const shouldSkipEvent = (eventId: string): boolean => { + const lastEventId = sessionState.value.lastEventId; + if (!lastEventId) { + return false; + } + + if (eventId === lastEventId) { + return true; + } + + if (/^\d+$/.test(eventId) && /^\d+$/.test(lastEventId)) { + return Number(eventId) <= Number(lastEventId); + } + + return false; +}; + +const readJson = async (response: Response): Promise => { + try { + return await response.clone().json(); + } catch { + return null; + } +}; + +const parseResponseError = async (response: Response): Promise => { + const payload = await readJson(response); + if (payload) { + return await parse_api_error(payload); + } + + try { + const text = await response.text(); + if (text) { + return text; + } + } catch { + return response.statusText || 'Request failed.'; + } + + return response.statusText || 'Request failed.'; +}; + +const refreshSessionMetadata = async (): Promise => { + if (!sessionState.value.sessionId) { + return; + } + + try { + const payload = await readSessionResponse( + `/api/system/terminal/${encodeURIComponent(sessionState.value.sessionId)}`, + ); + if (!payload) { + return; + } + + updateState(normalizeResponse(payload)); + } catch (error) { + if (error instanceof ConsoleSessionExpiredError) { + markExpired(); + } + } +}; + +const normalizeResponse = (payload: ConsoleSessionResponse): Partial => { + const status = + payload.expired || payload.not_found + ? 'expired' + : normalizeStatus(payload.status, sessionState.value.status); + + const sessionId = payload.session_id ?? payload.sessionId; + const lastEventId = payload.last_event_id ?? payload.lastEventId; + const exitCode = payload.exit_code ?? payload.exitCode; + + return { + sessionId: sessionId === undefined ? sessionState.value.sessionId : sessionId, + command: payload.command ?? sessionState.value.command, + status, + lastEventId: + lastEventId === undefined || lastEventId === null + ? sessionState.value.lastEventId + : String(lastEventId), + exitCode: exitCode === undefined ? sessionState.value.exitCode : exitCode, + error: + status === 'error' + ? sessionState.value.error || + (typeof exitCode === 'number' && exitCode !== 0 + ? `Command exited with code ${exitCode}.` + : '') + : '', + }; +}; + +const readSessionResponse = async ( + path: string, + { allowMissing = false }: { allowMissing?: boolean } = {}, +): Promise => { + const response = await request(path); + + if ([404, 410].includes(response.status)) { + if (allowMissing) { + return null; + } + + throw new ConsoleSessionExpiredError(await parseResponseError(response)); + } + + if (!response.ok) { + throw new Error(await parseResponseError(response)); + } + + return await parse_api_response(response.json()); +}; + +const connectStream = async (): Promise => { + if (!sessionState.value.sessionId) { + return; + } + + stopStream(); + + const controller = new AbortController(); + let finalMetadataRefresh: Promise | null = null; + const nonce = connectionNonce.value + 1; + connectionNonce.value = nonce; + streamController.value = controller; + + const search = new URLSearchParams(); + if (sessionState.value.lastEventId) { + search.set('since', sessionState.value.lastEventId); + } + + const url = uri( + `/api/system/terminal/${encodeURIComponent(sessionState.value.sessionId)}/stream${search.size > 0 ? `?${search.toString()}` : ''}`, + ); + + const headers: Record = { + Accept: 'text/event-stream', + }; + + if (sessionState.value.lastEventId) { + headers['Last-Event-ID'] = sessionState.value.lastEventId; + } + + try { + await fetchEventSource(url, { + method: 'GET', + headers, + credentials: 'same-origin', + signal: controller.signal, + openWhenHidden: true, + onopen: async (response) => { + if (response.ok) { + updateState({ status: 'running', error: '' }); + return; + } + + if ([404, 410].includes(response.status)) { + throw new ConsoleSessionExpiredError(await parseResponseError(response)); + } + + throw new Error(await parseResponseError(response)); + }, + onmessage: (event: EventSourceMessage) => { + if (event.id && shouldSkipEvent(event.id)) { + return; + } + + let payload: Record | null = null; + if (event.data) { + try { + payload = JSON.parse(event.data) as Record; + } catch { + payload = null; + } + } + + if (event.id) { + updateState({ lastEventId: event.id }); + } + + if (event.event === 'expired' || payload?.type === 'expired' || payload?.expired === true) { + markExpired(); + return; + } + + if (event.event === 'status') { + updateState(normalizeResponse(payload as ConsoleSessionResponse)); + return; + } + + if (event.event === 'output') { + const line = typeof payload?.line === 'string' ? payload.line : event.data; + appendTranscript(`${line || ''}\n`); + return; + } + + if (event.event === 'close') { + const nextExitCode = + typeof payload?.exitcode === 'number' + ? payload.exitcode + : typeof payload?.exit_code === 'number' + ? payload.exit_code + : sessionState.value.exitCode; + const nextStatus = normalizeStatus( + typeof payload?.status === 'string' ? payload.status : null, + 'finished', + ); + + if (payload?.expired === true || payload?.status === 'expired') { + markExpired(); + return; + } + + finishSession(nextStatus === 'error' ? 'error' : nextStatus, nextExitCode); + finalMetadataRefresh = refreshSessionMetadata(); + } + }, + onclose: () => { + if (controller.signal.aborted || nonce !== connectionNonce.value) { + return; + } + + if (!sessionState.value.sessionId || !isActiveSessionStatus(sessionState.value.status)) { + return; + } + + scheduleReconnect(); + }, + onerror: (error) => { + throw error instanceof Error ? error : new Error(String(error)); + }, + }); + + if (finalMetadataRefresh) { + await finalMetadataRefresh; + } + } catch (error) { + if (controller.signal.aborted || nonce !== connectionNonce.value) { + return; + } + + if (error instanceof ConsoleSessionExpiredError) { + markExpired(); + return; + } + + if (sessionState.value.sessionId && isActiveSessionStatus(sessionState.value.status)) { + scheduleReconnect(); + return; + } + + const message = error instanceof Error ? error.message : String(error); + appendTranscript(`Error: ${message}\n`); + updateState({ status: 'error', error: message }); + } finally { + if (streamController.value === controller) { + streamController.value = null; + } + } +}; + +const restoreSession = async (): Promise => { + try { + if (sessionState.value.sessionId) { + const payload = await readSessionResponse( + `/api/system/terminal/${encodeURIComponent(sessionState.value.sessionId)}`, + ); + + if (payload) { + updateState(normalizeResponse(payload)); + } + + if (sessionState.value.sessionId && isActiveSessionStatus(sessionState.value.status)) { + await connectStream(); + } + + return; + } + + if ( + sessionState.value.sessionId || + sessionState.value.command || + sessionState.value.transcript.length > 0 + ) { + return; + } + + const active = await readSessionResponse('/api/system/terminal/active', { allowMissing: true }); + if (!active) { + return; + } + + updateState(normalizeResponse(active)); + + if (sessionState.value.sessionId && isActiveSessionStatus(sessionState.value.status)) { + await connectStream(); + } + } catch (error) { + if (error instanceof ConsoleSessionExpiredError) { + markExpired(); + return; + } + + if (sessionState.value.sessionId && isActiveSessionStatus(sessionState.value.status)) { + await connectStream(); + } + } +}; + +const startSession = async ({ + command, + displayCommand = command, +}: StartConsoleSessionInput): Promise => { + stopStream(); + updateState({ + sessionId: null, + command: displayCommand, + status: 'starting', + lastEventId: null, + exitCode: null, + error: '', + }); + + try { + const response = await request('/api/system/terminal', { + method: 'POST', + body: JSON.stringify({ command }), + }); + + if (!response.ok) { + throw new Error(await parseResponseError(response)); + } + + const payload = await parse_api_response(response.json()); + updateState({ + ...normalizeResponse(payload), + command: displayCommand, + status: normalizeStatus(payload.status, 'running'), + lastEventId: null, + exitCode: null, + transcript: trimTranscript([ + ...sessionState.value.transcript, + 'user@YTPTube ~\n', + `$ yt-dlp ${displayCommand}\n`, + ]), + error: '', + }); + + if (sessionState.value.sessionId && isActiveSessionStatus(sessionState.value.status)) { + await connectStream(); + } + + return Boolean(sessionState.value.sessionId); + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + appendTranscript(`Error: ${message}\n`); + updateState({ status: 'error', error: message }); + return false; + } +}; + +const cancelSession = async (): Promise => { + if (!sessionState.value.sessionId) { + return { status: 'missing', message: 'No active terminal session found.' }; + } + + try { + const response = await request( + `/api/system/terminal/${encodeURIComponent(sessionState.value.sessionId)}`, + { + method: 'DELETE', + }, + ); + + if (response.status === 404) { + await refreshSessionMetadata(); + return { status: 'missing', message: 'Terminal session not found.' }; + } + + if (!response.ok) { + throw new Error(await parseResponseError(response)); + } + + updateState({ error: '' }); + return { status: 'cancelled' }; + } catch (error) { + const message = error instanceof Error ? error.message : String(error); + updateState({ error: message }); + return { status: 'error', message }; + } +}; + +const resetState = (): void => { + stopStream(); + sessionState.value = { ...DEFAULT_STATE }; +}; + +const useConsoleSession = () => { + return { + state: sessionState, + bufferedTranscript: computed(() => sessionState.value.transcript.slice()), + isLoading: computed(() => isActiveSessionStatus(sessionState.value.status)), + clearTranscript, + restoreSession, + startSession, + cancelSession, + disconnect: stopStream, + __resetForTesting: resetState, + }; +}; + +export { useConsoleSession }; diff --git a/ui/app/pages/console.vue b/ui/app/pages/console.vue index 6bb1e84c..aee75e77 100644 --- a/ui/app/pages/console.vue +++ b/ui/app/pages/console.vue @@ -6,8 +6,28 @@ Console - - {{ isLoading ? 'Streaming output' : 'Idle' }} + + + + {{ sessionStatusLabel }} + + + + + Session {{ shortSessionId }} + + + + Exit {{ sessionExitCode }} @@ -15,10 +35,21 @@
-

Run yt-dlp commands directly in a non-interactive session.

+

{{ sessionStatusDescription }}

+ + Help + + + + diff --git a/ui/tests/composables/useConsoleSession.test.ts b/ui/tests/composables/useConsoleSession.test.ts new file mode 100644 index 00000000..7a08456e --- /dev/null +++ b/ui/tests/composables/useConsoleSession.test.ts @@ -0,0 +1,505 @@ +import { describe, it, expect, beforeAll, beforeEach, afterEach, mock, spyOn } from 'bun:test' +import { ref } from 'vue' + +type StorageEntry = ReturnType> + +type MockEventSourceMessage = { + event: string + id?: string + data: string +} + +type MockFetchEventSourceOptions = { + method: string + headers: Record + credentials: RequestCredentials + signal: AbortSignal + openWhenHidden: boolean + onopen: (response: Response) => Promise | void + onmessage: (event: MockEventSourceMessage) => Promise | void + onclose: () => Promise | void + onerror: (error: unknown) => unknown +} + +const runtimeConfig = { + app: { + baseURL: '/base-path', + }, +} + +;(globalThis as typeof globalThis & { useRuntimeConfig?: () => typeof runtimeConfig }).useRuntimeConfig = () => runtimeConfig + +mock.module('#imports', () => ({ + useRuntimeConfig: () => runtimeConfig, +})) + +const storageMap = new Map>() + +const cloneValue = (value: T): T => { + return JSON.parse(JSON.stringify(value)) as T +} + +const useStorageFn = mock((key: string, defaultValue: T) => { + if (!storageMap.has(key)) { + storageMap.set(key, ref(cloneValue(defaultValue))) + } + + return storageMap.get(key) as StorageEntry +}) + +mock.module('@vueuse/core', () => ({ + useStorage: useStorageFn, +})) + +const fetchEventSourceMock = mock( + async (_url: string, _options: MockFetchEventSourceOptions): Promise => {}, +) + +mock.module('@microsoft/fetch-event-source', () => ({ + fetchEventSource: fetchEventSourceMock, +})) + +type MockResponseInput = { + ok: boolean + status: number + jsonData: unknown +} + +const createMockResponse = ({ ok, status, jsonData }: MockResponseInput): Response => { + return { + ok, + status, + headers: new Headers({ 'Content-Type': 'application/json' }), + redirected: false, + statusText: ok ? 'OK' : 'Error', + type: 'basic', + url: '', + body: null, + bodyUsed: false, + clone() { + return this + }, + async json() { + return jsonData + }, + text: async () => JSON.stringify(jsonData), + arrayBuffer: async () => new ArrayBuffer(0), + blob: async () => new Blob(), + formData: async () => new FormData(), + } as Response +} + +let utils: Awaited +let useConsoleSession: typeof import('~/composables/useConsoleSession').useConsoleSession + +beforeAll(async () => { + utils = await import('~/utils/index') + ;({ useConsoleSession } = await import('~/composables/useConsoleSession')) +}) + +beforeEach(() => { + storageMap.clear() + useStorageFn.mockClear() + fetchEventSourceMock.mockClear() +}) + +afterEach(() => { + const session = useConsoleSession() + session.__resetForTesting() +}) + +describe('useConsoleSession', () => { + it('starts a session, persists prompt transcript, and stores streamed output', async () => { + const requestSpy = spyOn(utils, 'request') + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { session_id: 'sess-1', command: '--help', status: 'running' }, + }), + ) + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { + session_id: 'sess-1', + command: '--help', + status: 'completed', + last_event_id: '2', + exit_code: 0, + }, + }), + ) + + fetchEventSourceMock.mockImplementationOnce(async (_url, options) => { + await options.onopen(new Response('', { status: 200 })) + options.onmessage({ + event: 'output', + id: '1', + data: JSON.stringify({ line: 'line one' }), + }) + options.onmessage({ + event: 'close', + id: '2', + data: JSON.stringify({ exitcode: 0 }), + }) + }) + + const session = useConsoleSession() + const started = await session.startSession({ command: '--help', displayCommand: '--help' }) + await Promise.resolve() + await Promise.resolve() + + expect(started).toBe(true) + expect(session.state.value.sessionId).toBe('sess-1') + expect(session.state.value.command).toBe('--help') + expect(session.state.value.status).toBe('finished') + expect(session.state.value.lastEventId).toBe('2') + expect(session.state.value.exitCode).toBe(0) + expect(session.state.value.transcript).toEqual([ + 'user@YTPTube ~\n', + '$ yt-dlp --help\n', + 'line one\n', + ]) + + requestSpy.mockRestore() + }) + + it('restores a running session using both since and Last-Event-ID', async () => { + const session = useConsoleSession() + session.state.value = { + sessionId: 'sess-restore', + command: '--version', + status: 'running', + lastEventId: '42', + exitCode: null, + transcript: ['user@YTPTube ~\n', '$ yt-dlp --version\n'], + error: '', + } + + const requestSpy = spyOn(utils, 'request') + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { session_id: 'sess-restore', command: '--version', status: 'running', last_event_id: '42' }, + }), + ) + + fetchEventSourceMock.mockImplementationOnce(async (_url, options) => { + await options.onopen(new Response('', { status: 200 })) + }) + + await session.restoreSession() + + expect(fetchEventSourceMock).toHaveBeenCalledTimes(1) + + const streamCall = fetchEventSourceMock.mock.calls[0] + expect(streamCall).toBeDefined() + if (!streamCall) { + throw new Error('Expected fetchEventSource to be called once.') + } + + const [streamUrl, streamOptions] = streamCall + expect(streamUrl).toContain('/base-path/api/system/terminal/sess-restore/stream?since=42') + expect(streamOptions.headers['Last-Event-ID']).toBe('42') + expect(streamOptions.method).toBe('GET') + + requestSpy.mockRestore() + }) + + it('marks the session expired and stops retry setup on stream not found', async () => { + const requestSpy = spyOn(utils, 'request') + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { session_id: 'sess-expired', command: '--help', status: 'running' }, + }), + ) + + fetchEventSourceMock.mockImplementationOnce(async (_url, options) => { + await options.onopen( + new Response(JSON.stringify({ error: 'Session expired' }), { + status: 404, + headers: { 'Content-Type': 'application/json' }, + }), + ) + }) + + const session = useConsoleSession() + const started = await session.startSession({ command: '--help', displayCommand: '--help' }) + + expect(started).toBe(true) + expect(session.state.value.status).toBe('expired') + expect(session.state.value.command).toBe('--help') + expect(session.state.value.transcript).toEqual([ + 'user@YTPTube ~\n', + '$ yt-dlp --help\n', + ]) + expect(fetchEventSourceMock).toHaveBeenCalledTimes(1) + + requestSpy.mockRestore() + }) + + it('restores a persisted session even after a local stream error state', async () => { + const session = useConsoleSession() + session.state.value = { + sessionId: 'sess-reconnect', + command: '--version', + status: 'error', + lastEventId: '7', + exitCode: null, + transcript: ['user@YTPTube ~\n', '$ yt-dlp --version\n'], + error: 'Stream failed', + } + + const requestSpy = spyOn(utils, 'request') + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { session_id: 'sess-reconnect', command: '--version', status: 'running', last_event_id: '7' }, + }), + ) + + fetchEventSourceMock.mockImplementationOnce(async (_url, options) => { + await options.onopen(new Response('', { status: 200 })) + }) + + await session.restoreSession() + + expect(fetchEventSourceMock).toHaveBeenCalledTimes(1) + expect(session.state.value.status).toBe('running') + expect(session.state.value.error).toBe('') + + requestSpy.mockRestore() + }) + + it('deduplicates replayed events when reconnect resumes from the last event id', async () => { + const requestSpy = spyOn(utils, 'request') + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { session_id: 'sess-dupe', command: '--help', status: 'running' }, + }), + ) + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { + session_id: 'sess-dupe', + command: '--help', + status: 'completed', + last_event_id: '9', + exit_code: 0, + }, + }), + ) + + fetchEventSourceMock.mockImplementationOnce(async (_url, options) => { + await options.onopen(new Response('', { status: 200 })) + options.onmessage({ + event: 'output', + id: '8', + data: JSON.stringify({ line: 'first line' }), + }) + options.onmessage({ + event: 'output', + id: '8', + data: JSON.stringify({ line: 'first line' }), + }) + options.onmessage({ + event: 'close', + id: '9', + data: JSON.stringify({ exitcode: 0 }), + }) + }) + + const session = useConsoleSession() + await session.startSession({ command: '--help', displayCommand: '--help' }) + await Promise.resolve() + await Promise.resolve() + + expect(session.state.value.transcript).toEqual([ + 'user@YTPTube ~\n', + '$ yt-dlp --help\n', + 'first line\n', + ]) + + requestSpy.mockRestore() + }) + + it('refreshes final metadata after close so interrupted sessions keep their backend status', async () => { + const requestSpy = spyOn(utils, 'request') + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { session_id: 'sess-int', command: '--help', status: 'running' }, + }), + ) + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { + session_id: 'sess-int', + command: '--help', + status: 'interrupted', + last_event_id: '2', + exit_code: -15, + }, + }), + ) + + fetchEventSourceMock.mockImplementationOnce(async (_url, options) => { + await options.onopen(new Response('', { status: 200 })) + options.onmessage({ + event: 'output', + id: '1', + data: JSON.stringify({ line: 'partial output' }), + }) + options.onmessage({ + event: 'close', + id: '2', + data: JSON.stringify({ exitcode: -15 }), + }) + }) + + const session = useConsoleSession() + await session.startSession({ command: '--help', displayCommand: '--help' }) + await Promise.resolve() + await Promise.resolve() + + expect(session.state.value.status).toBe('interrupted') + expect(session.state.value.exitCode).toBe(-15) + expect(session.state.value.error).toBe('') + + requestSpy.mockRestore() + }) + + it('requests cancellation for the active session without dropping local state early', async () => { + const session = useConsoleSession() + session.state.value = { + sessionId: 'sess-stop', + command: '--help', + status: 'running', + lastEventId: '3', + exitCode: null, + transcript: ['user@YTPTube ~\n', '$ yt-dlp --help\n', 'chunk\n'], + error: '', + } + + const requestSpy = spyOn(utils, 'request') + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: true, + status: 200, + jsonData: { message: 'Terminal session cancellation requested.', session_id: 'sess-stop' }, + }), + ) + + const result = await session.cancelSession() + + expect(result).toEqual({ status: 'cancelled' }) + expect(session.state.value.sessionId).toBe('sess-stop') + expect(session.state.value.status).toBe('running') + expect(session.state.value.error).toBe('') + + const requestCall = requestSpy.mock.calls.at(-1) + expect(requestCall).toBeDefined() + if (!requestCall) { + throw new Error('Expected request to be called for session cancellation.') + } + + const [path, options] = requestCall as [string, { method: string }] + expect(path).toBe('/api/system/terminal/sess-stop') + expect(options.method).toBe('DELETE') + + requestSpy.mockRestore() + }) + + it('disconnects the local stream without issuing a cancel request', async () => { + const session = useConsoleSession() + session.state.value = { + sessionId: 'sess-detach', + command: '--help', + status: 'running', + lastEventId: '3', + exitCode: null, + transcript: ['user@YTPTube ~\n', '$ yt-dlp --help\n'], + error: '', + } + + fetchEventSourceMock.mockImplementationOnce(async (_url, options) => { + await options.onopen(new Response('', { status: 200 })) + return new Promise(() => {}) + }) + + void session.restoreSession() + await Promise.resolve() + await Promise.resolve() + + const requestSpy = spyOn(utils, 'request') + requestSpy.mockClear() + session.disconnect() + + expect(requestSpy).toHaveBeenCalledTimes(0) + expect(session.state.value.sessionId).toBe('sess-detach') + + requestSpy.mockRestore() + }) + + it('refreshes metadata when cancel targets a missing session', async () => { + const session = useConsoleSession() + session.state.value = { + sessionId: 'sess-missing', + command: '--help', + status: 'running', + lastEventId: '3', + exitCode: null, + transcript: ['user@YTPTube ~\n', '$ yt-dlp --help\n'], + error: '', + } + + const requestSpy = spyOn(utils, 'request') + requestSpy.mockResolvedValueOnce( + createMockResponse({ + ok: false, + status: 404, + jsonData: { error: 'Terminal session not found.' }, + }), + ) + + const result = await session.cancelSession() + + expect(result).toEqual({ status: 'missing', message: 'Terminal session not found.' }) + expect(session.state.value.status).toBe('running') + + requestSpy.mockRestore() + }) + + it('clears visible transcript without dropping the active session cursor', () => { + const session = useConsoleSession() + session.state.value = { + sessionId: 'sess-active', + command: '--help', + status: 'running', + lastEventId: '15', + exitCode: null, + transcript: ['user@YTPTube ~\n', '$ yt-dlp --help\n', 'chunk\n'], + error: '', + } + + session.clearTranscript() + + expect(session.state.value.transcript).toEqual([]) + expect(session.state.value.sessionId).toBe('sess-active') + expect(session.state.value.status).toBe('running') + expect(session.state.value.lastEventId).toBe('15') + }) +})