Skip to content

Commit 9b0f750

Browse files
authored
Merge pull request #896 from unniznd/fix_expose_timeout_muxer_multistream
fix: Added timeout paramter into muxer multistream
2 parents 81cc2f0 + ecdb770 commit 9b0f750

File tree

6 files changed

+157
-7
lines changed

6 files changed

+157
-7
lines changed

libp2p/host/basic_host.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,6 @@ async def new_stream(
213213
self,
214214
peer_id: ID,
215215
protocol_ids: Sequence[TProtocol],
216-
negotitate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
217216
) -> INetStream:
218217
"""
219218
:param peer_id: peer_id that host is connecting
@@ -227,7 +226,7 @@ async def new_stream(
227226
selected_protocol = await self.multiselect_client.select_one_of(
228227
list(protocol_ids),
229228
MultiselectCommunicator(net_stream),
230-
negotitate_timeout,
229+
self.negotiate_timeout,
231230
)
232231
except MultiselectClientError as error:
233232
logger.debug("fail to open a stream to peer %s, error=%s", peer_id, error)

libp2p/stream_muxer/muxer_multistream.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
MultiselectError,
2222
)
2323
from libp2p.protocol_muxer.multiselect import (
24+
DEFAULT_NEGOTIATE_TIMEOUT,
2425
Multiselect,
2526
)
2627
from libp2p.protocol_muxer.multiselect_client import (
@@ -46,11 +47,17 @@ class MuxerMultistream:
4647
transports: "OrderedDict[TProtocol, TMuxerClass]"
4748
multiselect: Multiselect
4849
multiselect_client: MultiselectClient
50+
negotiate_timeout: int
4951

50-
def __init__(self, muxer_transports_by_protocol: TMuxerOptions) -> None:
52+
def __init__(
53+
self,
54+
muxer_transports_by_protocol: TMuxerOptions,
55+
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
56+
) -> None:
5157
self.transports = OrderedDict()
5258
self.multiselect = Multiselect()
5359
self.multistream_client = MultiselectClient()
60+
self.negotiate_timeout = negotiate_timeout
5461
for protocol, transport in muxer_transports_by_protocol.items():
5562
self.add_transport(protocol, transport)
5663

@@ -80,10 +87,12 @@ async def select_transport(self, conn: IRawConnection) -> TMuxerClass:
8087
communicator = MultiselectCommunicator(conn)
8188
if conn.is_initiator:
8289
protocol = await self.multiselect_client.select_one_of(
83-
tuple(self.transports.keys()), communicator
90+
tuple(self.transports.keys()), communicator, self.negotiate_timeout
8491
)
8592
else:
86-
protocol, _ = await self.multiselect.negotiate(communicator)
93+
protocol, _ = await self.multiselect.negotiate(
94+
communicator, self.negotiate_timeout
95+
)
8796
if protocol is None:
8897
raise MultiselectError(
8998
"Fail to negotiate a stream muxer protocol: no protocol selected"
@@ -93,7 +102,7 @@ async def select_transport(self, conn: IRawConnection) -> TMuxerClass:
93102
async def new_conn(self, conn: ISecureConn, peer_id: ID) -> IMuxedConn:
94103
communicator = MultiselectCommunicator(conn)
95104
protocol = await self.multistream_client.select_one_of(
96-
tuple(self.transports.keys()), communicator
105+
tuple(self.transports.keys()), communicator, self.negotiate_timeout
97106
)
98107
transport_class = self.transports[protocol]
99108
if protocol == PROTOCOL_ID:

libp2p/transport/upgrader.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
MultiselectClientError,
1515
MultiselectError,
1616
)
17+
from libp2p.protocol_muxer.multiselect import (
18+
DEFAULT_NEGOTIATE_TIMEOUT,
19+
)
1720
from libp2p.security.exceptions import (
1821
HandshakeFailure,
1922
)
@@ -37,9 +40,12 @@ def __init__(
3740
self,
3841
secure_transports_by_protocol: TSecurityOptions,
3942
muxer_transports_by_protocol: TMuxerOptions,
43+
negotiate_timeout: int = DEFAULT_NEGOTIATE_TIMEOUT,
4044
):
4145
self.security_multistream = SecurityMultistream(secure_transports_by_protocol)
42-
self.muxer_multistream = MuxerMultistream(muxer_transports_by_protocol)
46+
self.muxer_multistream = MuxerMultistream(
47+
muxer_transports_by_protocol, negotiate_timeout
48+
)
4349

4450
async def upgrade_security(
4551
self,

newsfragments/896.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Exposed timeout method in muxer multistream and updated all the usage. Added testcases to verify that timeout value is passed correctly
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from unittest.mock import (
2+
AsyncMock,
3+
MagicMock,
4+
)
5+
6+
import pytest
7+
8+
from libp2p.custom_types import (
9+
TMuxerClass,
10+
TProtocol,
11+
)
12+
from libp2p.peer.id import (
13+
ID,
14+
)
15+
from libp2p.protocol_muxer.exceptions import (
16+
MultiselectError,
17+
)
18+
from libp2p.stream_muxer.muxer_multistream import (
19+
MuxerMultistream,
20+
)
21+
22+
23+
@pytest.mark.trio
24+
async def test_muxer_timeout_configuration():
25+
"""Test that muxer respects timeout configuration."""
26+
muxer = MuxerMultistream({}, negotiate_timeout=1)
27+
assert muxer.negotiate_timeout == 1
28+
29+
30+
@pytest.mark.trio
31+
async def test_select_transport_passes_timeout_to_multiselect():
32+
"""Test that timeout is passed to multiselect client in select_transport."""
33+
# Mock dependencies
34+
mock_conn = MagicMock()
35+
mock_conn.is_initiator = False
36+
37+
# Mock MultiselectClient
38+
muxer = MuxerMultistream({}, negotiate_timeout=10)
39+
muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None))
40+
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
41+
42+
# Call select_transport
43+
await muxer.select_transport(mock_conn)
44+
45+
# Verify that select_one_of was called with the correct timeout
46+
args, _ = muxer.multiselect.negotiate.call_args
47+
assert args[1] == 10
48+
49+
50+
@pytest.mark.trio
51+
async def test_new_conn_passes_timeout_to_multistream_client():
52+
"""Test that timeout is passed to multistream client in new_conn."""
53+
# Mock dependencies
54+
mock_conn = MagicMock()
55+
mock_conn.is_initiator = True
56+
mock_peer_id = ID(b"test_peer")
57+
mock_communicator = MagicMock()
58+
59+
# Mock MultistreamClient and transports
60+
muxer = MuxerMultistream({}, negotiate_timeout=30)
61+
muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol")
62+
muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock())
63+
64+
# Call new_conn
65+
await muxer.new_conn(mock_conn, mock_peer_id)
66+
67+
# Verify that select_one_of was called with the correct timeout
68+
muxer.multistream_client.select_one_of(
69+
tuple(muxer.transports.keys()), mock_communicator, 30
70+
)
71+
72+
73+
@pytest.mark.trio
74+
async def test_select_transport_no_protocol_selected():
75+
"""
76+
Test that select_transport raises MultiselectError when no protocol is selected.
77+
"""
78+
# Mock dependencies
79+
mock_conn = MagicMock()
80+
mock_conn.is_initiator = False
81+
82+
# Mock Multiselect to return None
83+
muxer = MuxerMultistream({}, negotiate_timeout=30)
84+
muxer.multiselect.negotiate = AsyncMock(return_value=(None, None))
85+
86+
# Expect MultiselectError to be raised
87+
with pytest.raises(MultiselectError, match="no protocol selected"):
88+
await muxer.select_transport(mock_conn)
89+
90+
91+
@pytest.mark.trio
92+
async def test_add_transport_updates_precedence():
93+
"""Test that adding a transport updates protocol precedence."""
94+
# Mock transport classes
95+
mock_transport1 = MagicMock(spec=TMuxerClass)
96+
mock_transport2 = MagicMock(spec=TMuxerClass)
97+
98+
# Initialize muxer and add transports
99+
muxer = MuxerMultistream({}, negotiate_timeout=30)
100+
muxer.add_transport(TProtocol("proto1"), mock_transport1)
101+
muxer.add_transport(TProtocol("proto2"), mock_transport2)
102+
103+
# Verify transport order
104+
assert list(muxer.transports.keys()) == ["proto1", "proto2"]
105+
106+
# Re-add proto1 to check if it moves to the end
107+
muxer.add_transport(TProtocol("proto1"), mock_transport1)
108+
assert list(muxer.transports.keys()) == ["proto2", "proto1"]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
from libp2p.custom_types import (
4+
TMuxerOptions,
5+
TSecurityOptions,
6+
)
7+
from libp2p.transport.upgrader import (
8+
TransportUpgrader,
9+
)
10+
11+
12+
@pytest.mark.trio
13+
async def test_transport_upgrader_security_and_muxer_initialization():
14+
"""Test TransportUpgrader initializes security and muxer multistreams correctly."""
15+
secure_transports: TSecurityOptions = {}
16+
muxer_transports: TMuxerOptions = {}
17+
negotiate_timeout = 15
18+
19+
upgrader = TransportUpgrader(
20+
secure_transports, muxer_transports, negotiate_timeout=negotiate_timeout
21+
)
22+
23+
# Verify security multistream initialization
24+
assert upgrader.security_multistream.transports == secure_transports
25+
# Verify muxer multistream initialization and timeout
26+
assert upgrader.muxer_multistream.transports == muxer_transports
27+
assert upgrader.muxer_multistream.negotiate_timeout == negotiate_timeout

0 commit comments

Comments
 (0)