|
| 1 | +from unittest.mock import ( |
| 2 | + AsyncMock, |
| 3 | + MagicMock, |
| 4 | +) |
| 5 | + |
| 6 | +import pytest |
| 7 | + |
| 8 | +from libp2p.custom_types import ( |
| 9 | + TMuxerClass, |
| 10 | + TProtocol, |
| 11 | +) |
| 12 | +from libp2p.peer.id import ( |
| 13 | + ID, |
| 14 | +) |
| 15 | +from libp2p.protocol_muxer.exceptions import ( |
| 16 | + MultiselectError, |
| 17 | +) |
| 18 | +from libp2p.stream_muxer.muxer_multistream import ( |
| 19 | + MuxerMultistream, |
| 20 | +) |
| 21 | + |
| 22 | + |
| 23 | +@pytest.mark.trio |
| 24 | +async def test_muxer_timeout_configuration(): |
| 25 | + """Test that muxer respects timeout configuration.""" |
| 26 | + muxer = MuxerMultistream({}, negotiate_timeout=1) |
| 27 | + assert muxer.negotiate_timeout == 1 |
| 28 | + |
| 29 | + |
| 30 | +@pytest.mark.trio |
| 31 | +async def test_select_transport_passes_timeout_to_multiselect(): |
| 32 | + """Test that timeout is passed to multiselect client in select_transport.""" |
| 33 | + # Mock dependencies |
| 34 | + mock_conn = MagicMock() |
| 35 | + mock_conn.is_initiator = False |
| 36 | + |
| 37 | + # Mock MultiselectClient |
| 38 | + muxer = MuxerMultistream({}, negotiate_timeout=10) |
| 39 | + muxer.multiselect.negotiate = AsyncMock(return_value=("mock_protocol", None)) |
| 40 | + muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock()) |
| 41 | + |
| 42 | + # Call select_transport |
| 43 | + await muxer.select_transport(mock_conn) |
| 44 | + |
| 45 | + # Verify that select_one_of was called with the correct timeout |
| 46 | + args, _ = muxer.multiselect.negotiate.call_args |
| 47 | + assert args[1] == 10 |
| 48 | + |
| 49 | + |
| 50 | +@pytest.mark.trio |
| 51 | +async def test_new_conn_passes_timeout_to_multistream_client(): |
| 52 | + """Test that timeout is passed to multistream client in new_conn.""" |
| 53 | + # Mock dependencies |
| 54 | + mock_conn = MagicMock() |
| 55 | + mock_conn.is_initiator = True |
| 56 | + mock_peer_id = ID(b"test_peer") |
| 57 | + mock_communicator = MagicMock() |
| 58 | + |
| 59 | + # Mock MultistreamClient and transports |
| 60 | + muxer = MuxerMultistream({}, negotiate_timeout=30) |
| 61 | + muxer.multistream_client.select_one_of = AsyncMock(return_value="mock_protocol") |
| 62 | + muxer.transports[TProtocol("mock_protocol")] = MagicMock(return_value=MagicMock()) |
| 63 | + |
| 64 | + # Call new_conn |
| 65 | + await muxer.new_conn(mock_conn, mock_peer_id) |
| 66 | + |
| 67 | + # Verify that select_one_of was called with the correct timeout |
| 68 | + muxer.multistream_client.select_one_of( |
| 69 | + tuple(muxer.transports.keys()), mock_communicator, 30 |
| 70 | + ) |
| 71 | + |
| 72 | + |
| 73 | +@pytest.mark.trio |
| 74 | +async def test_select_transport_no_protocol_selected(): |
| 75 | + """ |
| 76 | + Test that select_transport raises MultiselectError when no protocol is selected. |
| 77 | + """ |
| 78 | + # Mock dependencies |
| 79 | + mock_conn = MagicMock() |
| 80 | + mock_conn.is_initiator = False |
| 81 | + |
| 82 | + # Mock Multiselect to return None |
| 83 | + muxer = MuxerMultistream({}, negotiate_timeout=30) |
| 84 | + muxer.multiselect.negotiate = AsyncMock(return_value=(None, None)) |
| 85 | + |
| 86 | + # Expect MultiselectError to be raised |
| 87 | + with pytest.raises(MultiselectError, match="no protocol selected"): |
| 88 | + await muxer.select_transport(mock_conn) |
| 89 | + |
| 90 | + |
| 91 | +@pytest.mark.trio |
| 92 | +async def test_add_transport_updates_precedence(): |
| 93 | + """Test that adding a transport updates protocol precedence.""" |
| 94 | + # Mock transport classes |
| 95 | + mock_transport1 = MagicMock(spec=TMuxerClass) |
| 96 | + mock_transport2 = MagicMock(spec=TMuxerClass) |
| 97 | + |
| 98 | + # Initialize muxer and add transports |
| 99 | + muxer = MuxerMultistream({}, negotiate_timeout=30) |
| 100 | + muxer.add_transport(TProtocol("proto1"), mock_transport1) |
| 101 | + muxer.add_transport(TProtocol("proto2"), mock_transport2) |
| 102 | + |
| 103 | + # Verify transport order |
| 104 | + assert list(muxer.transports.keys()) == ["proto1", "proto2"] |
| 105 | + |
| 106 | + # Re-add proto1 to check if it moves to the end |
| 107 | + muxer.add_transport(TProtocol("proto1"), mock_transport1) |
| 108 | + assert list(muxer.transports.keys()) == ["proto2", "proto1"] |
0 commit comments