diff --git a/libp2p/relay/circuit_v2/dcutr.py b/libp2p/relay/circuit_v2/dcutr.py index 644ea75f0..1328ac491 100644 --- a/libp2p/relay/circuit_v2/dcutr.py +++ b/libp2p/relay/circuit_v2/dcutr.py @@ -524,7 +524,9 @@ async def _have_direct_connection(self, peer_id: ID) -> bool: # Handle both single connection and list of connections connections: list[INetConn] = ( - [conn_or_conns] if not isinstance(conn_or_conns, list) else conn_or_conns + list(conn_or_conns) + if not isinstance(conn_or_conns, list) + else conn_or_conns ) # Check if any connection is direct (not relayed) diff --git a/libp2p/relay/circuit_v2/transport.py b/libp2p/relay/circuit_v2/transport.py index 3632615a4..23454b890 100644 --- a/libp2p/relay/circuit_v2/transport.py +++ b/libp2p/relay/circuit_v2/transport.py @@ -92,6 +92,7 @@ def __init__( stream_timeout=config.timeouts.discovery_stream_timeout, peer_protocol_timeout=config.timeouts.peer_protocol_timeout, ) + self.relay_counter = 0 # for round robin load balancing async def dial( self, @@ -221,11 +222,25 @@ async def _select_relay(self, peer_info: PeerInfo) -> ID | None: # Get a relay from the list of discovered relays relays = self.discovery.get_relays() if relays: - # TODO: Implement more sophisticated relay selection - # For now, just return the first available relay - return relays[0] - - # Wait and try discovery + # Prioritize relays with active reservations + relays_with_reservations = [] + other_relays = [] + + for relay_id in relays: + relay_info = self.discovery.get_relay_info(relay_id) + if relay_info and relay_info.has_reservation: + relays_with_reservations.append(relay_id) + else: + other_relays.append(relay_id) + + # Return first available relay with reservation, or fallback to others + self.relay_counter += 1 + if relays_with_reservations: + return relays_with_reservations[ + (self.relay_counter - 1) % len(relays_with_reservations) + ] + elif other_relays: + return other_relays[(self.relay_counter - 1) % len(other_relays)] await trio.sleep(1) attempts += 1 diff --git a/newsfragments/735.internal.rst b/newsfragments/735.internal.rst new file mode 100644 index 000000000..d773fb770 --- /dev/null +++ b/newsfragments/735.internal.rst @@ -0,0 +1 @@ +Improve relay selection by load balancing and reservation priority. diff --git a/tests/core/relay/test_circuit_v2_transport.py b/tests/core/relay/test_circuit_v2_transport.py index 8498dba40..dc0273818 100644 --- a/tests/core/relay/test_circuit_v2_transport.py +++ b/tests/core/relay/test_circuit_v2_transport.py @@ -11,6 +11,7 @@ StreamEOF, StreamReset, ) +from libp2p.peer.peerinfo import PeerInfo from libp2p.relay.circuit_v2.config import ( RelayConfig, ) @@ -344,3 +345,78 @@ async def test_circuit_v2_transport_relay_limits(): # Test successful - transports were initialized with the correct limits logger.info("Transport limit test successful") + + +@pytest.mark.trio +async def test_circuit_v2_transport_relay_selection(): + """Test relay round robin load balancing and reservation priority""" + async with HostFactory.create_batch_and_listen(5) as hosts: + client1_host, relay_host1, relay_host2, relay_host3, target_host = hosts + + # Setup relay with strict limits + limits = RelayLimits( + duration=DEFAULT_RELAY_LIMITS.duration, + data=DEFAULT_RELAY_LIMITS.data, + max_circuit_conns=DEFAULT_RELAY_LIMITS.max_circuit_conns, + max_reservations=DEFAULT_RELAY_LIMITS.max_reservations, + ) + + # Register test handler on target + test_protocol = "/test/echo/1.0.0" + target_host.set_stream_handler(TProtocol(test_protocol), echo_stream_handler) + target_host_info = PeerInfo(target_host.get_id(), target_host.get_addrs()) + client_config = RelayConfig() + + # Client setup + client1_protocol = CircuitV2Protocol(client1_host, limits, allow_hop=False) + client1_discovery = RelayDiscovery( + host=client1_host, + auto_reserve=False, + discovery_interval=client_config.discovery_interval, + max_relays=client_config.max_relays, + ) + client1_transport = CircuitV2Transport( + client1_host, client1_protocol, client_config + ) + client1_transport.discovery = client1_discovery + # Add relay to discovery + relay_id1 = relay_host1.get_id() + relay_id2 = relay_host2.get_id() + relay_id3 = relay_host3.get_id() + + # Connect all peers + try: + with trio.fail_after(CONNECT_TIMEOUT): + # Connect clients to relay + await connect(client1_host, relay_host1) + await connect(client1_host, relay_host2) + await connect(client1_host, relay_host3) + + await client1_discovery._add_relay(relay_id1) + await client1_discovery._add_relay(relay_id2) + await client1_discovery._add_relay(relay_id3) + + logger.info("All connections established") + except Exception as e: + logger.error("Failed to connect peers: %s", str(e)) + raise + + selected_relay = await client1_transport._select_relay(target_host_info) + # Without reservation preferance + # Round robin, so 1st time must be relay1 + assert selected_relay is not None and selected_relay is relay_id1 + + selected_relay = await client1_transport._select_relay(target_host_info) + # Round robin, so 2nd time must be relay2 + assert selected_relay is not None and selected_relay is relay_id2 + + # Mock reservation with relay1 to prioritize over relay2 + relay_info3 = client1_discovery.get_relay_info(relay_id3) + if relay_info3: + relay_info3.has_reservation = True + + selected_relay = await client1_transport._select_relay(target_host_info) + # With reservation preferance, relay2 must be chosen for target_peer. + assert selected_relay is not None and selected_relay is relay_host3.get_id() + + logger.info("Relay selection successful")