diff --git a/examples/doc-examples/example_net_stream.py b/examples/doc-examples/example_net_stream.py deleted file mode 100644 index d8842beab..000000000 --- a/examples/doc-examples/example_net_stream.py +++ /dev/null @@ -1,263 +0,0 @@ -""" -Enhanced NetStream Example for py-libp2p with State Management - -This example demonstrates the new NetStream features including: -- State tracking and transitions -- Proper error handling and validation -- Resource cleanup and event notifications -- Thread-safe operations with Trio locks - -Based on the standard echo demo but enhanced to show NetStream state management. -""" - -import argparse -import random -import secrets - -import multiaddr -import trio - -from libp2p import ( - new_host, -) -from libp2p.crypto.secp256k1 import ( - create_new_key_pair, -) -from libp2p.custom_types import ( - TProtocol, -) -from libp2p.network.stream.exceptions import ( - StreamClosed, - StreamEOF, - StreamReset, -) -from libp2p.network.stream.net_stream import ( - NetStream, - StreamState, -) -from libp2p.peer.peerinfo import ( - info_from_p2p_addr, -) - -PROTOCOL_ID = TProtocol("/echo/1.0.0") - - -async def enhanced_echo_handler(stream: NetStream) -> None: - """ - Enhanced echo handler that demonstrates NetStream state management. - """ - print(f"New connection established: {stream}") - print(f"Initial stream state: {await stream.state}") - - try: - # Verify stream is in expected initial state - assert await stream.state == StreamState.OPEN - assert await stream.is_readable() - assert await stream.is_writable() - print("āœ“ Stream initialized in OPEN state") - - # Read incoming data with proper state checking - print("Waiting for client data...") - - while await stream.is_readable(): - try: - # Read data from client - data = await stream.read(1024) - if not data: - print("Received empty data, client may have closed") - break - - print(f"Received: {data.decode('utf-8').strip()}") - - # Check if we can still write before echoing - if await stream.is_writable(): - await stream.write(data) - print(f"Echoed: {data.decode('utf-8').strip()}") - else: - print("Cannot echo - stream not writable") - break - - except StreamEOF: - print("Client closed their write side (EOF)") - break - except StreamReset: - print("Stream was reset by client") - return - except StreamClosed as e: - print(f"Stream operation failed: {e}") - break - - # Demonstrate graceful closure - current_state = await stream.state - print(f"Current state before close: {current_state}") - - if current_state not in [StreamState.CLOSE_BOTH, StreamState.RESET]: - await stream.close() - print("Server closed write side") - - final_state = await stream.state - print(f"Final stream state: {final_state}") - - except Exception as e: - print(f"Handler error: {e}") - # Reset stream on unexpected errors - if await stream.state not in [StreamState.RESET, StreamState.CLOSE_BOTH]: - await stream.reset() - print("Stream reset due to error") - - -async def enhanced_client_demo(stream: NetStream) -> None: - """ - Enhanced client that demonstrates various NetStream state scenarios. - """ - print(f"Client stream established: {stream}") - print(f"Initial state: {await stream.state}") - - try: - # Verify initial state - assert await stream.state == StreamState.OPEN - print("āœ“ Client stream in OPEN state") - - # Scenario 1: Normal communication - message = b"Hello from enhanced NetStream client!\n" - - if await stream.is_writable(): - await stream.write(message) - print(f"Sent: {message.decode('utf-8').strip()}") - else: - print("Cannot write - stream not writable") - return - - # Close write side to signal EOF to server - await stream.close() - print("Client closed write side") - - # Verify state transition - state_after_close = await stream.state - print(f"State after close: {state_after_close}") - assert state_after_close == StreamState.CLOSE_WRITE - assert await stream.is_readable() # Should still be readable - assert not await stream.is_writable() # Should not be writable - - # Try to write (should fail) - try: - await stream.write(b"This should fail") - print("ERROR: Write succeeded when it should have failed!") - except StreamClosed as e: - print(f"āœ“ Expected error when writing to closed stream: {e}") - - # Read the echo response - if await stream.is_readable(): - try: - response = await stream.read() - print(f"Received echo: {response.decode('utf-8').strip()}") - except StreamEOF: - print("Server closed their write side") - except StreamReset: - print("Stream was reset") - - # Check final state - final_state = await stream.state - print(f"Final client state: {final_state}") - - except Exception as e: - print(f"Client error: {e}") - # Reset on error - await stream.reset() - print("Client reset stream due to error") - - -async def run_enhanced_demo( - port: int, destination: str, seed: int | None = None -) -> None: - """ - Run enhanced echo demo with NetStream state management. - """ - listen_addr = multiaddr.Multiaddr(f"/ip4/0.0.0.0/tcp/{port}") - - # Generate or use provided key - if seed: - random.seed(seed) - secret_number = random.getrandbits(32 * 8) - secret = secret_number.to_bytes(length=32, byteorder="big") - else: - secret = secrets.token_bytes(32) - - host = new_host(key_pair=create_new_key_pair(secret)) - - async with host.run(listen_addrs=[listen_addr]): - print(f"Host ID: {host.get_id().to_string()}") - print("=" * 60) - - if not destination: # Server mode - print("šŸ–„ļø ENHANCED ECHO SERVER MODE") - print("=" * 60) - - # type: ignore: Stream is type of NetStream - host.set_stream_handler(PROTOCOL_ID, enhanced_echo_handler) - - print( - "Run client from another console:\n" - f"python3 example_net_stream.py " - f"-d {host.get_addrs()[0]}\n" - ) - print("Waiting for connections...") - print("Press Ctrl+C to stop server") - await trio.sleep_forever() - - else: # Client mode - print("šŸ“± ENHANCED ECHO CLIENT MODE") - print("=" * 60) - - # Connect to server - maddr = multiaddr.Multiaddr(destination) - info = info_from_p2p_addr(maddr) - await host.connect(info) - print(f"Connected to server: {info.peer_id.pretty()}") - - # Create stream and run enhanced demo - stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) - if isinstance(stream, NetStream): - await enhanced_client_demo(stream) - - print("\n" + "=" * 60) - print("CLIENT DEMO COMPLETE") - - -def main() -> None: - example_maddr = ( - "/ip4/127.0.0.1/tcp/8000/p2p/QmQn4SwGkDZKkUEpBRBvTmheQycxAHJUNmVEnjA2v1qe8Q" - ) - - parser = argparse.ArgumentParser( - formatter_class=argparse.RawDescriptionHelpFormatter - ) - parser.add_argument("-p", "--port", default=0, type=int, help="source port number") - parser.add_argument( - "-d", - "--destination", - type=str, - help=f"destination multiaddr string, e.g. {example_maddr}", - ) - parser.add_argument( - "-s", - "--seed", - type=int, - help="seed for deterministic peer ID generation", - ) - parser.add_argument( - "--demo-states", action="store_true", help="run state transition demo only" - ) - - args = parser.parse_args() - - try: - trio.run(run_enhanced_demo, args.port, args.destination, args.seed) - except KeyboardInterrupt: - print("\nšŸ‘‹ Demo interrupted by user") - except Exception as e: - print(f"āŒ Demo failed: {e}") - - -if __name__ == "__main__": - main() diff --git a/libp2p/network/connection/swarm_connection.py b/libp2p/network/connection/swarm_connection.py index c1b42c582..7f6b43554 100644 --- a/libp2p/network/connection/swarm_connection.py +++ b/libp2p/network/connection/swarm_connection.py @@ -13,6 +13,7 @@ ) from libp2p.network.stream.net_stream import ( NetStream, + StreamState, ) from libp2p.stream_muxer.exceptions import ( MuxedConnUnavailable, @@ -146,7 +147,13 @@ async def _handle_muxed_stream(self, muxed_stream: IMuxedStream) -> None: self.remove_stream(net_stream) async def _add_stream(self, muxed_stream: IMuxedStream) -> NetStream: - net_stream = NetStream(muxed_stream) + # + net_stream = NetStream(muxed_stream, self) + # Set Stream state to OPEN if the event has already started. + # This is to ensure that the new streams created after connection has started + # are immediately set to OPEN state. + if self.event_started.is_set(): + net_stream.set_state(StreamState.OPEN) self.streams.add(net_stream) await self.swarm.notify_opened_stream(net_stream) return net_stream @@ -155,6 +162,10 @@ async def _notify_disconnected(self) -> None: await self.swarm.notify_disconnected(self) async def start(self) -> None: + streams_open = self.get_streams() + for stream in streams_open: + """Set the state of the stream to OPEN.""" + stream.set_state(StreamState.OPEN) await self._handle_new_streams() async def new_stream(self) -> NetStream: diff --git a/libp2p/network/stream/net_stream.py b/libp2p/network/stream/net_stream.py index b54fdda4f..c57f71345 100644 --- a/libp2p/network/stream/net_stream.py +++ b/libp2p/network/stream/net_stream.py @@ -1,8 +1,10 @@ from enum import ( Enum, + auto, +) +from typing import ( + TYPE_CHECKING, ) - -import trio from libp2p.abc import ( IMuxedStream, @@ -14,10 +16,13 @@ from libp2p.stream_muxer.exceptions import ( MuxedStreamClosed, MuxedStreamEOF, - MuxedStreamError, MuxedStreamReset, ) +if TYPE_CHECKING: + from libp2p.network.connection.swarm_connection import SwarmConn + + from .exceptions import ( StreamClosed, StreamEOF, @@ -26,100 +31,25 @@ class StreamState(Enum): - """NetStream States""" - - OPEN = "open" - CLOSE_READ = "close_read" - CLOSE_WRITE = "close_write" - CLOSE_BOTH = "close_both" - RESET = "reset" + INIT = auto() + OPEN = auto() + CLOSED = auto() + RESET = auto() + ERROR = auto() class NetStream(INetStream): - """ - Summary - _______ - A Network stream implementation. - - NetStream wraps a muxed stream and provides proper state tracking, resource cleanup, - and event notification capabilities. - - State Machine - _____________ - - .. code:: markdown - - [CREATED] → OPEN → CLOSE_READ → CLOSE_BOTH → [CLEANUP] - ↓ ↗ ↗ - CLOSE_WRITE → ← ↗ - ↓ ↗ - RESET → → → → → → → → - - State Transitions - _________________ - - OPEN → CLOSE_READ: EOF encountered during read() - - OPEN → CLOSE_WRITE: Explicit close() call - - OPEN → RESET: reset() call or critical stream error - - CLOSE_READ → CLOSE_BOTH: Explicit close() call - - CLOSE_WRITE → CLOSE_BOTH: EOF encountered during read() - - Any state → RESET: reset() call - - Terminal States (trigger cleanup) - _________________________________ - - CLOSE_BOTH: Stream fully closed, triggers resource cleanup - - RESET: Stream reset/terminated, triggers resource cleanup - - Operation Validity by State - ___________________________ - OPEN: read() āœ“ write() āœ“ close() āœ“ reset() āœ“ - CLOSE_READ: read() āœ— write() āœ“ close() āœ“ reset() āœ“ - CLOSE_WRITE: read() āœ“ write() āœ— close() āœ“ reset() āœ“ - CLOSE_BOTH: read() āœ— write() āœ— close() āœ“ reset() āœ“ - RESET: read() āœ— write() āœ— close() āœ“ reset() āœ“ - - Cleanup Process (triggered by CLOSE_BOTH or RESET) - __________________________________________________ - 1. Remove stream from SwarmConn - 2. Notify all listeners with ClosedStream event - 3. Decrement reference counter - 4. Background cleanup via nursery (if provided) - - Thread Safety - _____________ - All state operations are protected by trio.Lock() for safe concurrent access. - State checks and modifications are atomic operations. - - Example: See :file:`examples/doc-examples/example_net_stream.py` - - :param muxed_stream (IMuxedStream): The underlying muxed stream - :param nursery (Optional[trio.Nursery]): Nursery for background cleanup tasks - :raises StreamClosed: When attempting invalid operations on closed streams - :raises StreamEOF: When EOF is encountered during read operations - :raises StreamReset: When the underlying stream has been reset - """ - muxed_stream: IMuxedStream protocol_id: TProtocol | None - __stream_state: StreamState def __init__( - self, muxed_stream: IMuxedStream, nursery: trio.Nursery | None = None + self, muxed_stream: IMuxedStream, swarm_conn: "SwarmConn | None" ) -> None: - super().__init__() - self.muxed_stream = muxed_stream self.muxed_conn = muxed_stream.muxed_conn self.protocol_id = None - - # For background tasks - self._nursery = nursery - - # State management - self.__stream_state = StreamState.OPEN - self._state_lock = trio.Lock() - - # For notification handling - self._notify_lock = trio.Lock() + self._state = StreamState.INIT + self.swarm_conn = swarm_conn def get_protocol(self) -> TProtocol | None: """ @@ -134,51 +64,37 @@ def set_protocol(self, protocol_id: TProtocol) -> None: self.protocol_id = protocol_id @property - async def state(self) -> StreamState: - """Get current stream state.""" - async with self._state_lock: - return self.__stream_state + def state(self) -> StreamState: + """ + :return: current state of the stream + """ + return self._state + + def set_state(self, state: StreamState) -> None: + """ + Set the current state of the stream. + + :param state: new state of the stream + """ + self._state = state async def read(self, n: int | None = None) -> bytes: """ Read from stream. :param n: number of bytes to read - :raises StreamClosed: If `NetStream` is closed for reading - :raises StreamReset: If `NetStream` is reset - :raises StreamEOF: If trying to read after reaching end of file :return: Bytes read from the stream """ - async with self._state_lock: - if self.__stream_state in [ - StreamState.CLOSE_READ, - StreamState.CLOSE_BOTH, - ]: - raise StreamClosed("Stream is closed for reading") - - if self.__stream_state == StreamState.RESET: - raise StreamReset("Stream is reset, cannot be used to read") - try: - data = await self.muxed_stream.read(n) - return data + if self.state == StreamState.RESET: + raise StreamReset("Cannot read from stream; stream is reset") + elif self.state != StreamState.OPEN: + raise StreamClosed("Cannot read from stream; not open") + else: + return await self.muxed_stream.read(n) except MuxedStreamEOF as error: - async with self._state_lock: - if self.__stream_state == StreamState.CLOSE_WRITE: - self.__stream_state = StreamState.CLOSE_BOTH - await self._remove() - elif self.__stream_state == StreamState.OPEN: - self.__stream_state = StreamState.CLOSE_READ raise StreamEOF() from error except MuxedStreamReset as error: - async with self._state_lock: - if self.__stream_state in [ - StreamState.OPEN, - StreamState.CLOSE_READ, - StreamState.CLOSE_WRITE, - ]: - self.__stream_state = StreamState.RESET - await self._remove() raise StreamReset() from error async def write(self, data: bytes) -> None: @@ -186,123 +102,41 @@ async def write(self, data: bytes) -> None: Write to stream. :param data: bytes to write - :raises StreamClosed: If `NetStream` is closed for writing or reset - :raises StreamClosed: If `StreamError` occurred while writing """ - async with self._state_lock: - if self.__stream_state in [ - StreamState.CLOSE_WRITE, - StreamState.CLOSE_BOTH, - StreamState.RESET, - ]: - raise StreamClosed("Stream is closed for writing") - try: - await self.muxed_stream.write(data) - except (MuxedStreamClosed, MuxedStreamError) as error: - async with self._state_lock: - if self.__stream_state == StreamState.OPEN: - self.__stream_state = StreamState.CLOSE_WRITE - elif self.__stream_state == StreamState.CLOSE_READ: - self.__stream_state = StreamState.CLOSE_BOTH - await self._remove() + if self.state == StreamState.RESET: + raise StreamReset("Cannot write to stream; stream is reset") + elif self.state != StreamState.OPEN: + raise StreamClosed("Cannot write to stream; not open") + else: + await self.muxed_stream.write(data) + except MuxedStreamClosed as error: + self.set_state(StreamState.CLOSED) raise StreamClosed() from error + except MuxedStreamReset as error: + self.set_state(StreamState.RESET) + raise StreamReset() from error async def close(self) -> None: - """Close stream for writing.""" - async with self._state_lock: - if self.__stream_state in [ - StreamState.CLOSE_BOTH, - StreamState.RESET, - StreamState.CLOSE_WRITE, - ]: - return - + """Close stream.""" await self.muxed_stream.close() - - async with self._state_lock: - if self.__stream_state == StreamState.CLOSE_READ: - self.__stream_state = StreamState.CLOSE_BOTH - await self._remove() - elif self.__stream_state == StreamState.OPEN: - self.__stream_state = StreamState.CLOSE_WRITE + self.set_state(StreamState.CLOSED) + await self.remove() async def reset(self) -> None: - """Reset stream, closing both ends.""" - async with self._state_lock: - if self.__stream_state == StreamState.RESET: - return - + """Reset stream.""" await self.muxed_stream.reset() + self.set_state(StreamState.RESET) + await self.remove() - async with self._state_lock: - if self.__stream_state in [ - StreamState.OPEN, - StreamState.CLOSE_READ, - StreamState.CLOSE_WRITE, - ]: - self.__stream_state = StreamState.RESET - await self._remove() - - async def _remove(self) -> None: - """ - Remove stream from connection and notify listeners. - This is called when the stream is fully closed or reset. - """ - if hasattr(self.muxed_conn, "remove_stream"): - remove_stream = getattr(self.muxed_conn, "remove_stream") - await remove_stream(self) - - # Notify in background using Trio nursery if available - if self._nursery: - self._nursery.start_soon(self._notify_closed) - else: - await self._notify_closed() - - async def _notify_closed(self) -> None: + async def remove(self) -> None: """ - Notify all listeners that the stream has been closed. - This runs in a separate task to avoid blocking the main flow. + Remove the stream from the connection and notify swarm that stream was closed. """ - async with self._notify_lock: - if hasattr(self.muxed_conn, "swarm"): - swarm = getattr(self.muxed_conn, "swarm") - - if hasattr(swarm, "notify_all"): - await swarm.notify_all( - lambda notifiee: notifiee.closed_stream(swarm, self) - ) - - if hasattr(swarm, "refs") and hasattr(swarm.refs, "done"): - swarm.refs.done() + if self.swarm_conn is not None: + self.swarm_conn.remove_stream(self) + await self.swarm_conn.swarm.notify_closed_stream(self) def get_remote_address(self) -> tuple[str, int] | None: """Delegate to the underlying muxed stream.""" return self.muxed_stream.get_remote_address() - - async def is_closed(self) -> bool: - """Check if stream is closed.""" - current_state = await self.state - return current_state in [StreamState.CLOSE_BOTH, StreamState.RESET] - - async def is_readable(self) -> bool: - """Check if stream is readable.""" - current_state = await self.state - return current_state not in [ - StreamState.CLOSE_READ, - StreamState.CLOSE_BOTH, - StreamState.RESET, - ] - - async def is_writable(self) -> bool: - """Check if stream is writable.""" - current_state = await self.state - return current_state not in [ - StreamState.CLOSE_WRITE, - StreamState.CLOSE_BOTH, - StreamState.RESET, - ] - - def __str__(self) -> str: - """String representation of the stream.""" - return f"" diff --git a/libp2p/stream_muxer/yamux/yamux.py b/libp2p/stream_muxer/yamux/yamux.py index b2711e1a8..5b2496893 100644 --- a/libp2p/stream_muxer/yamux/yamux.py +++ b/libp2p/stream_muxer/yamux/yamux.py @@ -127,6 +127,9 @@ async def write(self, data: bytes) -> None: "Timed out waiting for window update after 5 seconds." ) + if self.reset_received: + raise MuxedStreamReset("Stream was reset") + if self.closed: raise MuxedStreamError("Stream is closed") diff --git a/newsfragments/632.feature.rst b/newsfragments/632.feature.rst new file mode 100644 index 000000000..f18fe7ca4 --- /dev/null +++ b/newsfragments/632.feature.rst @@ -0,0 +1,2 @@ +Adds the `StreamState` to the NetStream class to manage the state of network streams more effectively. +Adds the `remove` method to notify the Swarm that a stream has been removed. diff --git a/tests/core/network/test_net_stream.py b/tests/core/network/test_net_stream.py index efd64c25b..269a53235 100644 --- a/tests/core/network/test_net_stream.py +++ b/tests/core/network/test_net_stream.py @@ -58,7 +58,7 @@ async def test_net_stream_read_after_remote_closed(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.write(DATA) await stream_0.close() - await trio.sleep(0.5) + await trio.sleep(0.01) assert (await stream_1.read(MAX_READ_LEN)) == DATA with pytest.raises(StreamEOF): await stream_1.read(MAX_READ_LEN) @@ -90,7 +90,7 @@ async def test_net_stream_read_after_remote_closed_and_reset(net_stream_pair): await stream_0.close() await stream_0.reset() # Sleep to let `stream_1` receive the message. - await trio.sleep(1) + await trio.sleep(0.01) assert (await stream_1.read(MAX_READ_LEN)) == DATA @@ -107,7 +107,7 @@ async def test_net_stream_write_after_local_closed(net_stream_pair): async def test_net_stream_write_after_local_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_0.reset() - with pytest.raises(StreamClosed): + with pytest.raises(StreamReset): await stream_0.write(DATA) @@ -116,5 +116,5 @@ async def test_net_stream_write_after_remote_reset(net_stream_pair): stream_0, stream_1 = net_stream_pair await stream_1.reset() await trio.sleep(0.01) - with pytest.raises(StreamClosed): + with pytest.raises(StreamReset): await stream_0.write(DATA)