diff --git a/libp2p/stream_muxer/mplex/mplex_stream.py b/libp2p/stream_muxer/mplex/mplex_stream.py index 4ae7af699..c3976a00e 100644 --- a/libp2p/stream_muxer/mplex/mplex_stream.py +++ b/libp2p/stream_muxer/mplex/mplex_stream.py @@ -1,3 +1,4 @@ +import time from collections.abc import Awaitable, Callable from types import ( TracebackType, @@ -35,6 +36,12 @@ ) +class MplexStreamTimeout(Exception): + """Raised when a stream operation exceeds its deadline.""" + + pass + + class MplexStream(IMuxedStream): """ reference: https://github.com/libp2p/go-mplex/blob/master/stream.go @@ -46,8 +53,8 @@ class MplexStream(IMuxedStream): # class of IMuxedConn. Ignoring this type assignment should not pose # any risk. muxed_conn: "Mplex" # type: ignore[assignment] - read_deadline: int | None - write_deadline: int | None + read_deadline: float | None + write_deadline: float | None rw_lock: ReadWriteLock close_lock: trio.Lock @@ -91,6 +98,30 @@ def __init__( def is_initiator(self) -> bool: return self.stream_id.is_initiator + def _check_read_deadline(self) -> None: + """Check if read deadline has expired and raise timeout if needed.""" + if self.read_deadline is not None and time.time() > self.read_deadline: + raise MplexStreamTimeout("Read operation exceeded deadline") + + def _check_write_deadline(self) -> None: + """Check if write deadline has expired and raise timeout if needed.""" + if self.write_deadline is not None and time.time() > self.write_deadline: + raise MplexStreamTimeout("Write operation exceeded deadline") + + def _get_read_timeout(self) -> float | None: + """Calculate remaining time until read deadline.""" + if self.read_deadline is None: + return None + remaining = self.read_deadline - time.time() + return max(0.0, remaining) if remaining > 0 else 0 + + def _get_write_timeout(self) -> float | None: + """Calculate remaining time until write deadline.""" + if self.write_deadline is None: + return None + remaining = self.write_deadline - time.time() + return max(0.0, remaining) if remaining > 0 else 0 + async def _read_until_eof(self) -> bytes: async for data in self.incoming_data_channel: self._buf.extend(data) @@ -137,6 +168,8 @@ async def _do_read(self, n: int | None = None) -> bytes: :param n: number of bytes to read :return: bytes actually read """ + # check deadline before starting + self._check_read_deadline() async with self.rw_lock.read_lock(): if n is not None and n < 0: raise ValueError( @@ -146,9 +179,12 @@ async def _do_read(self, n: int | None = None) -> bytes: if self.event_reset.is_set(): raise MplexStreamReset if n is None: - return await self._read_until_eof() + return await self._read_until_eof_with_timeout() + # check deadline again before potentially blocking operation + self._check_read_deadline() if len(self._buf) == 0: data: bytes + timeout = self._get_read_timeout() # Peek whether there is data available. If yes, we just read until # there is no data, then return. try: @@ -160,8 +196,20 @@ async def _do_read(self, n: int | None = None) -> bytes: # We know `receive` will be blocked here. Wait for data here with # `receive` and catch all kinds of errors here. try: - data = await self.incoming_data_channel.receive() + if timeout is not None and timeout <= 0: + raise MplexStreamTimeout( + "Read deadline exceeded while waiting for data" + ) + + if timeout is not None: + with trio.fail_after(timeout): + data = await self.incoming_data_channel.receive() + else: + data = await self.incoming_data_channel.receive() + self._buf.extend(data) + except trio.TooSlowError: + raise MplexStreamTimeout("Read operation timed out") except trio.EndOfChannel: if self.event_reset.is_set(): raise MplexStreamReset @@ -181,6 +229,25 @@ async def _do_read(self, n: int | None = None) -> bytes: self._buf = self._buf[len(payload) :] return bytes(payload) + async def _read_until_eof_with_timeout(self) -> bytes: + """Read until EOF with timeout support.""" + timeout = self._get_read_timeout() + + try: + if timeout is not None: + with trio.fail_after(timeout): + async for data in self.incoming_data_channel: + self._buf.extend(data) + else: + async for data in self.incoming_data_channel: + self._buf.extend(data) + except trio.TooSlowError: + raise MplexStreamTimeout("Read until EOF operation timed out") + + payload = self._buf + self._buf = self._buf[len(payload) :] + return bytes(payload) + async def write(self, data: bytes) -> None: """ Write to stream. @@ -204,9 +271,20 @@ async def _do_write(self, data: bytes) -> None: :param data: bytes to write """ + # Check deadline before starting + self._check_write_deadline() async with self.rw_lock.write_lock(): if self.event_local_closed.is_set(): - raise MplexStreamClosed(f"cannot write to closed stream: data={data!r}") + raise MplexStreamClosed( + f"cannot write to closed stream: data={data!r}" + ) + + # Check deadline again after acquiring lock + timeout = self._get_write_timeout() + if timeout is not None and timeout <= 0: + raise MplexStreamTimeout("Write deadline exceeded") + + # Determine appropriate header flag for a message operation flag = self._get_header_flag("message") await self.muxed_conn.send_message(flag, data, self.stream_id) @@ -345,10 +423,9 @@ def set_deadline(self, ttl: int) -> bool: :param ttl: timeout in seconds for read and write operations :return: True if successful, False if ttl is invalid (negative) """ - if not self._validate_ttl(ttl): - return False - self.read_deadline = ttl - self.write_deadline = ttl + deadline = time.time() + ttl + self.read_deadline = deadline + self.write_deadline = deadline return True def set_read_deadline(self, ttl: int) -> bool: @@ -362,7 +439,8 @@ def set_read_deadline(self, ttl: int) -> bool: :param ttl: timeout in seconds for read operations :return: True if successful, False if ttl is invalid (negative) """ - return self._set_deadline_with_validation(ttl, "read_deadline") + self.read_deadline = time.time() + ttl + return True def set_write_deadline(self, ttl: int) -> bool: """ @@ -375,7 +453,8 @@ def set_write_deadline(self, ttl: int) -> bool: :param ttl: timeout in seconds for write operations :return: True if successful, False if ttl is invalid (negative) """ - return self._set_deadline_with_validation(ttl, "write_deadline") + self.write_deadline = ttl + time.time() + return True def get_remote_address(self) -> tuple[str, int] | None: """Delegate to the parent Mplex connection."""