From 98c5e17b3c40e137718ed868aa97c5197d7b58cd Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 2 Mar 2024 15:53:13 -0800 Subject: [PATCH 1/3] Improved H3 for hypercorn. --- src/hypercorn/asyncio/task_group.py | 45 +++++++++++++- src/hypercorn/protocol/quic.py | 87 +++++++++++++++++++--------- src/hypercorn/trio/task_group.py | 41 ++++++++++++- src/hypercorn/trio/worker_context.py | 2 +- src/hypercorn/typing.py | 17 ++++++ 5 files changed, 161 insertions(+), 31 deletions(-) diff --git a/src/hypercorn/asyncio/task_group.py b/src/hypercorn/asyncio/task_group.py index 2e589035..bc1b7e06 100644 --- a/src/hypercorn/asyncio/task_group.py +++ b/src/hypercorn/asyncio/task_group.py @@ -6,7 +6,7 @@ from typing import Any, Awaitable, Callable, Optional from ..config import Config -from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope +from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope, Timer try: from asyncio import TaskGroup as AsyncioTaskGroup @@ -33,6 +33,44 @@ async def _handle( await send(None) +LONG_SLEEP = 86400.0 + +class AsyncioTimer(Timer): + def __init__(self, action: Callable) -> None: + self._action = action + self._done = False + self._wake_up = asyncio.Condition() + self._when: Optional[float] = None + + async def schedule(self, when: Optional[float]) -> None: + self._when = when + async with self._wake_up: + self._wake_up.notify() + + async def stop(self) -> None: + self._done = True + async with self._wake_up: + self._wake_up.notify() + + async def _wait_for_wake_up(self) -> None: + async with self._wake_up: + await self._wake_up.wait() + + async def run(self) -> None: + while not self._done: + if self._when is not None and asyncio.get_event_loop().time() >= self._when: + self._when = None + await self._action() + if self._when is not None: + timeout = max(self._when - asyncio.get_event_loop().time(), 0.0) + else: + timeout = LONG_SLEEP + if not self._done: + try: + await asyncio.wait_for(self._wait_for_wake_up(), timeout) + except TimeoutError: + pass + class TaskGroup: def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop = loop @@ -66,6 +104,11 @@ def _call_soon(func: Callable, *args: Any) -> Any: def spawn(self, func: Callable, *args: Any) -> None: self._task_group.create_task(func(*args)) + def create_timer(self, action: Callable) -> Timer: + timer = AsyncioTimer(action) + self._task_group.create_task(timer.run()) + return timer + async def __aenter__(self) -> "TaskGroup": await self._task_group.__aenter__() return self diff --git a/src/hypercorn/protocol/quic.py b/src/hypercorn/protocol/quic.py index 3d16e54d..a4908848 100644 --- a/src/hypercorn/protocol/quic.py +++ b/src/hypercorn/protocol/quic.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import Awaitable, Callable, Dict, Optional, Tuple +from typing import Awaitable, Callable, Dict, Optional, Set, Tuple from aioquic.buffer import Buffer from aioquic.h3.connection import H3_ALPN @@ -22,7 +22,21 @@ from .h3 import H3Protocol from ..config import Config from ..events import Closed, Event, RawData -from ..typing import AppWrapper, TaskGroup, WorkerContext +from ..typing import AppWrapper, TaskGroup, WorkerContext, Timer + + +class ConnectionState: + def __init__(self, connection: QuicConnection): + self.connection = connection + self.timer: Optional[Timer] = None + self.cids: Set[bytes] = set() + self.h3_protocol: Optional[H3Protocol] = None + + def add_cid(self, cid: bytes) -> None: + self.cids.add(cid) + + def remove_cid(self, cid: bytes) -> None: + self.cids.remove(cid) class QuicProtocol: @@ -38,8 +52,7 @@ def __init__( self.app = app self.config = config self.context = context - self.connections: Dict[bytes, QuicConnection] = {} - self.http_connections: Dict[QuicConnection, H3Protocol] = {} + self.connections: Dict[bytes, ConnectionState] = {} self.send = send self.server = server self.task_group = task_group @@ -49,7 +62,7 @@ def __init__( @property def idle(self) -> bool: - return len(self.connections) == 0 and len(self.http_connections) == 0 + return len(self.connections) == 0 async def handle(self, event: Event) -> None: if isinstance(event, RawData): @@ -69,9 +82,13 @@ async def handle(self, event: Event) -> None: await self.send(RawData(data=data, address=event.address)) return - connection = self.connections.get(header.destination_cid) + state = self.connections.get(header.destination_cid) + if state is not None: + connection = state.connection + else: + connection = None if ( - connection is None + state is None and len(event.data) >= 1200 and header.packet_type == PACKET_TYPE_INITIAL and not self.context.terminated.is_set() @@ -80,12 +97,18 @@ async def handle(self, event: Event) -> None: configuration=self.quic_config, original_destination_connection_id=header.destination_cid, ) - self.connections[header.destination_cid] = connection - self.connections[connection.host_cid] = connection + # This partial() needs python >= 3.8 + state = ConnectionState(connection) + timer = self.task_group.create_timer(partial(self._timeout, state)) + state.timer = timer + state.add_cid(header.destination_cid) + self.connections[header.destination_cid] = state + state.add_cid(connection.host_cid) + self.connections[connection.host_cid] = state if connection is not None: connection.receive_datagram(event.data, event.address, now=self.context.time()) - await self._handle_events(connection, event.address) + await self._wake_up_timer(state) elif isinstance(event, Closed): pass @@ -94,14 +117,18 @@ async def send_all(self, connection: QuicConnection) -> None: await self.send(RawData(data=data, address=address)) async def _handle_events( - self, connection: QuicConnection, client: Optional[Tuple[str, int]] = None + self, state: ConnectionState, client: Optional[Tuple[str, int]] = None ) -> None: + connection = state.connection event = connection.next_event() while event is not None: if isinstance(event, ConnectionTerminated): - pass + await state.timer.stop() + for cid in state.cids: + del self.connections[cid] + state.cids = set() elif isinstance(event, ProtocolNegotiated): - self.http_connections[connection] = H3Protocol( + state.h3_protocol = H3Protocol( self.app, self.config, self.context, @@ -109,27 +136,31 @@ async def _handle_events( client, self.server, connection, - partial(self.send_all, connection), + partial(self._wake_up_timer, state), ) elif isinstance(event, ConnectionIdIssued): - self.connections[event.connection_id] = connection + state.add_cid(event.connection_id) + self.connections[event.connection_id] = state elif isinstance(event, ConnectionIdRetired): + state.remove_cid(event.connection_id) del self.connections[event.connection_id] - if connection in self.http_connections: - await self.http_connections[connection].handle(event) + elif state.h3_protocol is not None: + await state.h3_protocol.handle(event) event = connection.next_event() + async def _wake_up_timer(self, state: ConnectionState) -> None: + # When new output is send, or new input is received, we + # fire the timer right away so we update our state. + await state.timer.schedule(0.0) + + async def _timeout(self, state: ConnectionState) -> None: + connection = state.connection + now = self.context.time() + when = connection.get_timer() + if when is not None and now > when: + connection.handle_timer(now) + await self._handle_events(state, None) await self.send_all(connection) - - timer = connection.get_timer() - if timer is not None: - self.task_group.spawn(self._handle_timer, timer, connection) - - async def _handle_timer(self, timer: float, connection: QuicConnection) -> None: - wait = max(0, timer - self.context.time()) - await self.context.sleep(wait) - if connection._close_at is not None: - connection.handle_timer(now=self.context.time()) - await self._handle_events(connection, None) + await state.timer.schedule(connection.get_timer()) diff --git a/src/hypercorn/trio/task_group.py b/src/hypercorn/trio/task_group.py index 044ff852..5611da2d 100644 --- a/src/hypercorn/trio/task_group.py +++ b/src/hypercorn/trio/task_group.py @@ -7,7 +7,7 @@ import trio from ..config import Config -from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope +from ..typing import AppWrapper, ASGIReceiveCallable, ASGIReceiveEvent, ASGISendEvent, Scope, Timer if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup @@ -39,6 +39,40 @@ async def _handle( await send(None) +LONG_SLEEP = 86400.0 + +class TrioTimer(Timer): + def __init__(self, action: Callable) -> None: + self._action = action + self._done = False + self._wake_up = trio.Condition() + self._when: Optional[float] = None + + async def schedule(self, when: Optional[float]) -> None: + self._when = when + async with self._wake_up: + self._wake_up.notify() + + async def stop(self) -> None: + self._done = True + async with self._wake_up: + self._wake_up.notify() + + async def run(self) -> None: + while not self._done: + if self._when is not None and trio.current_time() >= self._when: + self._when = None + await self._action() + if self._when is not None: + timeout = max(self._when - trio.current_time(), 0.0) + else: + timeout = LONG_SLEEP + if not self._done: + with trio.move_on_after(timeout): + async with self._wake_up: + await self._wake_up.wait() + + class TaskGroup: def __init__(self) -> None: self._nursery: Optional[trio._core._run.Nursery] = None @@ -67,6 +101,11 @@ async def spawn_app( def spawn(self, func: Callable, *args: Any) -> None: self._nursery.start_soon(func, *args) + def create_timer(self, action: Callable) -> Timer: + timer = TrioTimer(action) + self._nursery.start_soon(timer.run) + return timer + async def __aenter__(self) -> TaskGroup: self._nursery_manager = trio.open_nursery() self._nursery = await self._nursery_manager.__aenter__() diff --git a/src/hypercorn/trio/worker_context.py b/src/hypercorn/trio/worker_context.py index c09c4fb6..41ab3f2e 100644 --- a/src/hypercorn/trio/worker_context.py +++ b/src/hypercorn/trio/worker_context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional, Type, Union +from typing import Awaitable, Optional, Type, Union import trio diff --git a/src/hypercorn/typing.py b/src/hypercorn/typing.py index 2ebb711d..527e1194 100644 --- a/src/hypercorn/typing.py +++ b/src/hypercorn/typing.py @@ -288,6 +288,20 @@ def is_set(self) -> bool: ... +class Timer: + def __init__(self, action: Callable) -> None: + ... + + async def schedule(self, when: float) -> None: + ... + + async def stop(self) -> None: + ... + + async def run(self) -> None: + ... + + class WorkerContext(Protocol): event_class: Type[Event] terminate: Event @@ -318,6 +332,9 @@ async def spawn_app( def spawn(self, func: Callable, *args: Any) -> None: ... + def create_timer(self, action: Callable) -> Timer: + ... + async def __aenter__(self) -> TaskGroup: ... From fe49b91585a38261d4dfbeba38bd66c33bdaafa5 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 3 Mar 2024 10:03:20 -0800 Subject: [PATCH 2/3] remove unneeded Awaitable import --- src/hypercorn/trio/worker_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hypercorn/trio/worker_context.py b/src/hypercorn/trio/worker_context.py index 41ab3f2e..c09c4fb6 100644 --- a/src/hypercorn/trio/worker_context.py +++ b/src/hypercorn/trio/worker_context.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Awaitable, Optional, Type, Union +from typing import Optional, Type, Union import trio From 1d9ea917ca005e4dad734ad0e5ab2113f2026427 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 3 Mar 2024 15:08:14 -0800 Subject: [PATCH 3/3] Add retry and session resumption. --- src/hypercorn/config.py | 3 ++ src/hypercorn/protocol/quic.py | 67 +++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/src/hypercorn/config.py b/src/hypercorn/config.py index f00c7d5e..12c19e92 100644 --- a/src/hypercorn/config.py +++ b/src/hypercorn/config.py @@ -33,6 +33,7 @@ BYTES = 1 OCTETS = 1 SECONDS = 1.0 +DEFAULT_QUIC_MAX_SAVED_SESSIONS = 100 FilePath = Union[AnyStr, os.PathLike] SocketKind = Union[int, socket.SocketKind] @@ -95,6 +96,8 @@ class Config: max_requests: Optional[int] = None max_requests_jitter: int = 0 pid_path: Optional[str] = None + quic_retry: bool = True + quic_max_saved_sessions: int = DEFAULT_QUIC_MAX_SAVED_SESSIONS server_names: List[str] = [] shutdown_timeout = 60 * SECONDS ssl_handshake_timeout = 60 * SECONDS diff --git a/src/hypercorn/protocol/quic.py b/src/hypercorn/protocol/quic.py index a4908848..de25ff6a 100644 --- a/src/hypercorn/protocol/quic.py +++ b/src/hypercorn/protocol/quic.py @@ -1,6 +1,7 @@ from __future__ import annotations from functools import partial +from secrets import token_bytes from typing import Awaitable, Callable, Dict, Optional, Set, Tuple from aioquic.buffer import Buffer @@ -15,9 +16,12 @@ ) from aioquic.quic.packet import ( encode_quic_version_negotiation, + encode_quic_retry, PACKET_TYPE_INITIAL, pull_quic_header, ) +from aioquic.quic.retry import QuicRetryTokenHandler +from aioquic.tls import SessionTicket from .h3 import H3Protocol from ..config import Config @@ -59,6 +63,12 @@ def __init__( self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False) self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile) + self.retry: Optional[QuicRetryTokenHandler] + if config.quic_retry: + self.retry = QuicRetryTokenHandler() + else: + self.retry = None + self.session_tickets: Dict[bytes, bytes] = {} @property def idle(self) -> bool: @@ -93,11 +103,49 @@ async def handle(self, event: Event) -> None: and header.packet_type == PACKET_TYPE_INITIAL and not self.context.terminated.is_set() ): + cid = header.destination_cid + retry_cid = None + if self.retry is not None: + if not header.token: + if header.version is None: + return + source_cid = token_bytes(8) + wire = encode_quic_retry( + version=header.version, + source_cid=source_cid, + destination_cid=header.source_cid, + original_destination_cid=header.destination_cid, + retry_token=self.retry.create_token( + event.address, header.destination_cid, source_cid + ), + ) + await self.send(RawData(data=wire, address=event.address)) + return + else: + try: + (cid, retry_cid) = self.retry.validate_token( + event.address, header.token + ) + if self.connections.get(cid) is not None: + # duplicate! + return + except ValueError: + return + fetcher: Optional[Callable] + handler: Optional[Callable] + if self.config.quic_max_saved_sessions > 0: + fetcher = self._get_session_ticket + handler = self._store_session_ticket + else: + fetcher = None + handler = None connection = QuicConnection( configuration=self.quic_config, - original_destination_connection_id=header.destination_cid, + original_destination_connection_id=cid, + retry_source_connection_id=retry_cid, + session_ticket_fetcher=fetcher, + session_ticket_handler=handler, ) - # This partial() needs python >= 3.8 state = ConnectionState(connection) timer = self.task_group.create_timer(partial(self._timeout, state)) state.timer = timer @@ -164,3 +212,18 @@ async def _timeout(self, state: ConnectionState) -> None: await self._handle_events(state, None) await self.send_all(connection) await state.timer.schedule(connection.get_timer()) + + def _get_session_ticket(self, ticket: bytes) -> None: + try: + self.session_tickets.pop(ticket) + except KeyError: + return None + + def _store_session_ticket(self, session_ticket: SessionTicket) -> None: + self.session_tickets[session_ticket.ticket] = session_ticket + # Implement a simple FIFO remembering the self.config.quic_max_saved_sessions + # most recent sessions. + while len(self.session_tickets) > self.config.quic_max_saved_sessions: + # Grab the first key + key = next(iter(self.session_tickets.keys())) + del self.session_tickets[key]