Skip to content

Improved H3 for hypercorn. #201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/hypercorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
BYTES = 1
OCTETS = 1
SECONDS = 1.0
DEFAULT_QUIC_MAX_SAVED_SESSIONS = 100

FilePath = Union[AnyStr, os.PathLike]
SocketKind = Union[int, socket.SocketKind]
Expand Down Expand Up @@ -95,6 +96,8 @@ class Config:
max_requests: Optional[int] = None
max_requests_jitter: int = 0
pid_path: Optional[str] = None
quic_retry: bool = True
quic_max_saved_sessions: int = DEFAULT_QUIC_MAX_SAVED_SESSIONS
server_names: List[str] = []
shutdown_timeout = 60 * SECONDS
ssl_handshake_timeout = 60 * SECONDS
Expand Down
67 changes: 65 additions & 2 deletions src/hypercorn/protocol/quic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from functools import partial
from secrets import token_bytes
from typing import Awaitable, Callable, Dict, Optional, Set, Tuple

from aioquic.buffer import Buffer
Expand All @@ -15,9 +16,12 @@
)
from aioquic.quic.packet import (
encode_quic_version_negotiation,
encode_quic_retry,
PACKET_TYPE_INITIAL,
pull_quic_header,
)
from aioquic.quic.retry import QuicRetryTokenHandler
from aioquic.tls import SessionTicket

from .h3 import H3Protocol
from ..config import Config
Expand Down Expand Up @@ -59,6 +63,12 @@ def __init__(

self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False)
self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile)
self.retry: Optional[QuicRetryTokenHandler]
if config.quic_retry:
self.retry = QuicRetryTokenHandler()
else:
self.retry = None
self.session_tickets: Dict[bytes, bytes] = {}

@property
def idle(self) -> bool:
Expand Down Expand Up @@ -93,11 +103,49 @@ async def handle(self, event: Event) -> None:
and header.packet_type == PACKET_TYPE_INITIAL
and not self.context.terminated.is_set()
):
cid = header.destination_cid
retry_cid = None
if self.retry is not None:
if not header.token:
if header.version is None:
return
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why return here, is a missing version an indication of an error?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Retry is an option because it trades off increased latency for ensuring that you can roundtrip to the client and aren't just being used to amplify an attack (with the start of the TLS handshake) by someone spoofing UDP. I usually turn it on.

Re the missing version... if there is no header token, then this isn't the client retrying, and we have to ask them to retry, but to do that we need the version. If there isn't a version, they've sent us a versionless "short header" packet which is not a sensible thing to do, so we just drop the packet.

source_cid = token_bytes(8)
wire = encode_quic_retry(
version=header.version,
source_cid=source_cid,
destination_cid=header.source_cid,
original_destination_cid=header.destination_cid,
retry_token=self.retry.create_token(
event.address, header.destination_cid, source_cid
),
)
await self.send(RawData(data=wire, address=event.address))
return
else:
try:
(cid, retry_cid) = self.retry.validate_token(
event.address, header.token
)
if self.connections.get(cid) is not None:
# duplicate!
return
except ValueError:
return
fetcher: Optional[Callable]
handler: Optional[Callable]
if self.config.quic_max_saved_sessions > 0:
fetcher = self._get_session_ticket
handler = self._store_session_ticket
else:
fetcher = None
handler = None
connection = QuicConnection(
configuration=self.quic_config,
original_destination_connection_id=header.destination_cid,
original_destination_connection_id=cid,
retry_source_connection_id=retry_cid,
session_ticket_fetcher=fetcher,
session_ticket_handler=handler,
)
# This partial() needs python >= 3.8
state = ConnectionState(connection)
timer = self.task_group.create_timer(partial(self._timeout, state))
state.timer = timer
Expand Down Expand Up @@ -164,3 +212,18 @@ async def _timeout(self, state: ConnectionState) -> None:
await self._handle_events(state, None)
await self.send_all(connection)
await state.timer.schedule(connection.get_timer())

def _get_session_ticket(self, ticket: bytes) -> None:
try:
self.session_tickets.pop(ticket)
except KeyError:
return None

def _store_session_ticket(self, session_ticket: SessionTicket) -> None:
self.session_tickets[session_ticket.ticket] = session_ticket
# Implement a simple FIFO remembering the self.config.quic_max_saved_sessions
# most recent sessions.
while len(self.session_tickets) > self.config.quic_max_saved_sessions:
# Grab the first key
key = next(iter(self.session_tickets.keys()))
del self.session_tickets[key]