Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 90 additions & 11 deletions libp2p/stream_muxer/mplex/mplex_stream.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import time
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may use trio.current_time() instead of time.time().
time.time() returns wall-clock time, which can jump backward or forward if the system clock changes (e.g., due to NTP).
trio.current_time() is monotonic and stable, making it the safer choice for deadlines and timeouts in async code.

from collections.abc import Awaitable, Callable
from types import (
TracebackType,
Expand Down Expand Up @@ -35,6 +36,12 @@
)


class MplexStreamTimeout(Exception):
"""Raised when a stream operation exceeds its deadline."""

pass


class MplexStream(IMuxedStream):
"""
reference: https://github.com/libp2p/go-mplex/blob/master/stream.go
Expand All @@ -46,8 +53,8 @@ class MplexStream(IMuxedStream):
# class of IMuxedConn. Ignoring this type assignment should not pose
# any risk.
muxed_conn: "Mplex" # type: ignore[assignment]
read_deadline: int | None
write_deadline: int | None
read_deadline: float | None
write_deadline: float | None

rw_lock: ReadWriteLock
close_lock: trio.Lock
Expand Down Expand Up @@ -91,6 +98,30 @@ def __init__(
def is_initiator(self) -> bool:
return self.stream_id.is_initiator

def _check_read_deadline(self) -> None:
"""Check if read deadline has expired and raise timeout if needed."""
if self.read_deadline is not None and time.time() > self.read_deadline:
raise MplexStreamTimeout("Read operation exceeded deadline")

def _check_write_deadline(self) -> None:
"""Check if write deadline has expired and raise timeout if needed."""
if self.write_deadline is not None and time.time() > self.write_deadline:
raise MplexStreamTimeout("Write operation exceeded deadline")

def _get_read_timeout(self) -> float | None:
"""Calculate remaining time until read deadline."""
if self.read_deadline is None:
return None
remaining = self.read_deadline - time.time()
return max(0.0, remaining) if remaining > 0 else 0

def _get_write_timeout(self) -> float | None:
"""Calculate remaining time until write deadline."""
if self.write_deadline is None:
return None
remaining = self.write_deadline - time.time()
return max(0.0, remaining) if remaining > 0 else 0

async def _read_until_eof(self) -> bytes:
async for data in self.incoming_data_channel:
self._buf.extend(data)
Expand Down Expand Up @@ -137,6 +168,8 @@ async def _do_read(self, n: int | None = None) -> bytes:
:param n: number of bytes to read
:return: bytes actually read
"""
# check deadline before starting
self._check_read_deadline()
async with self.rw_lock.read_lock():
if n is not None and n < 0:
raise ValueError(
Expand All @@ -146,9 +179,12 @@ async def _do_read(self, n: int | None = None) -> bytes:
if self.event_reset.is_set():
raise MplexStreamReset
if n is None:
return await self._read_until_eof()
return await self._read_until_eof_with_timeout()
# check deadline again before potentially blocking operation
self._check_read_deadline()
if len(self._buf) == 0:
data: bytes
timeout = self._get_read_timeout()
# Peek whether there is data available. If yes, we just read until
# there is no data, then return.
try:
Expand All @@ -160,8 +196,20 @@ async def _do_read(self, n: int | None = None) -> bytes:
# We know `receive` will be blocked here. Wait for data here with
# `receive` and catch all kinds of errors here.
try:
data = await self.incoming_data_channel.receive()
if timeout is not None and timeout <= 0:
raise MplexStreamTimeout(
"Read deadline exceeded while waiting for data"
)

if timeout is not None:
with trio.fail_after(timeout):
data = await self.incoming_data_channel.receive()
else:
data = await self.incoming_data_channel.receive()

self._buf.extend(data)
except trio.TooSlowError:
raise MplexStreamTimeout("Read operation timed out")
except trio.EndOfChannel:
if self.event_reset.is_set():
raise MplexStreamReset
Expand All @@ -181,6 +229,25 @@ async def _do_read(self, n: int | None = None) -> bytes:
self._buf = self._buf[len(payload) :]
return bytes(payload)

async def _read_until_eof_with_timeout(self) -> bytes:
"""Read until EOF with timeout support."""
timeout = self._get_read_timeout()

try:
if timeout is not None:
with trio.fail_after(timeout):
async for data in self.incoming_data_channel:
self._buf.extend(data)
else:
async for data in self.incoming_data_channel:
self._buf.extend(data)
except trio.TooSlowError:
raise MplexStreamTimeout("Read until EOF operation timed out")

payload = self._buf
self._buf = self._buf[len(payload) :]
return bytes(payload)

async def write(self, data: bytes) -> None:
"""
Write to stream.
Expand All @@ -204,9 +271,20 @@ async def _do_write(self, data: bytes) -> None:

:param data: bytes to write
"""
# Check deadline before starting
self._check_write_deadline()
async with self.rw_lock.write_lock():
if self.event_local_closed.is_set():
raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}")
raise MplexStreamClosed(
f"cannot write to closed stream: data={data!r}"
)

# Check deadline again after acquiring lock
timeout = self._get_write_timeout()
if timeout is not None and timeout <= 0:
raise MplexStreamTimeout("Write deadline exceeded")

# Determine appropriate header flag for a message operation
flag = self._get_header_flag("message")
await self.muxed_conn.send_message(flag, data, self.stream_id)

Expand Down Expand Up @@ -345,10 +423,9 @@ def set_deadline(self, ttl: int) -> bool:
:param ttl: timeout in seconds for read and write operations
:return: True if successful, False if ttl is invalid (negative)
"""
if not self._validate_ttl(ttl):
return False
self.read_deadline = ttl
self.write_deadline = ttl
deadline = time.time() + ttl
self.read_deadline = deadline
self.write_deadline = deadline
return True

def set_read_deadline(self, ttl: int) -> bool:
Expand All @@ -362,7 +439,8 @@ def set_read_deadline(self, ttl: int) -> bool:
:param ttl: timeout in seconds for read operations
:return: True if successful, False if ttl is invalid (negative)
"""
return self._set_deadline_with_validation(ttl, "read_deadline")
self.read_deadline = time.time() + ttl
return True

def set_write_deadline(self, ttl: int) -> bool:
"""
Expand All @@ -375,7 +453,8 @@ def set_write_deadline(self, ttl: int) -> bool:
:param ttl: timeout in seconds for write operations
:return: True if successful, False if ttl is invalid (negative)
"""
return self._set_deadline_with_validation(ttl, "write_deadline")
self.write_deadline = ttl + time.time()
return True

def get_remote_address(self) -> tuple[str, int] | None:
"""Delegate to the parent Mplex connection."""
Expand Down
Loading