Skip to content

Commit d5f4acc

Browse files
refactor(swarm): express dial logic linearly (#3253)
Previously, the logic within `Swarm::dial` involved fairly convoluted `match` expressions. This patch refactors this function to use new utility functions introduced on `DialOpts` to handle one concern at a time. This has the advantage that we are covering slightly more cases now. Because we are parsing the `PeerId` only once at the top, checks like banning will now also act on dials that specify the `PeerId` as part of the `/p2p` protocol.
1 parent 1765ae0 commit d5f4acc

File tree

2 files changed

+157
-124
lines changed

2 files changed

+157
-124
lines changed

swarm/src/dial_opts.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
// DEALINGS IN THE SOFTWARE.
2121

2222
use libp2p_core::connection::Endpoint;
23+
use libp2p_core::multiaddr::Protocol;
24+
use libp2p_core::multihash::Multihash;
2325
use libp2p_core::{Multiaddr, PeerId};
2426
use std::num::NonZeroU8;
2527

@@ -79,6 +81,104 @@ impl DialOpts {
7981
DialOpts(Opts::WithoutPeerIdWithAddress(_)) => None,
8082
}
8183
}
84+
85+
/// Retrieves the [`PeerId`] from the [`DialOpts`] if specified or otherwise tries to parse it
86+
/// from the multihash in the `/p2p` part of the address, if present.
87+
///
88+
/// Note: A [`Multiaddr`] with something else other than a [`PeerId`] within the `/p2p` protocol is invalid as per specification.
89+
/// Unfortunately, we are not making good use of the type system here.
90+
/// Really, this function should be merged with [`DialOpts::get_peer_id`] above.
91+
/// If it weren't for the parsing error, the function signatures would be the same.
92+
///
93+
/// See <https://github.com/multiformats/rust-multiaddr/issues/73>.
94+
pub(crate) fn get_or_parse_peer_id(&self) -> Result<Option<PeerId>, Multihash> {
95+
match self {
96+
DialOpts(Opts::WithPeerId(WithPeerId { peer_id, .. })) => Ok(Some(*peer_id)),
97+
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
98+
peer_id, ..
99+
})) => Ok(Some(*peer_id)),
100+
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress {
101+
address, ..
102+
})) => {
103+
let peer_id = address
104+
.iter()
105+
.last()
106+
.and_then(|p| {
107+
if let Protocol::P2p(ma) = p {
108+
Some(PeerId::try_from(ma))
109+
} else {
110+
None
111+
}
112+
})
113+
.transpose()?;
114+
115+
Ok(peer_id)
116+
}
117+
}
118+
}
119+
120+
pub(crate) fn get_addresses(&self) -> Vec<Multiaddr> {
121+
match self {
122+
DialOpts(Opts::WithPeerId(WithPeerId { .. })) => vec![],
123+
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
124+
addresses, ..
125+
})) => addresses.clone(),
126+
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress {
127+
address, ..
128+
})) => vec![address.clone()],
129+
}
130+
}
131+
132+
pub(crate) fn extend_addresses_through_behaviour(&self) -> bool {
133+
match self {
134+
DialOpts(Opts::WithPeerId(WithPeerId { .. })) => true,
135+
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
136+
extend_addresses_through_behaviour,
137+
..
138+
})) => *extend_addresses_through_behaviour,
139+
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress { .. })) => true,
140+
}
141+
}
142+
143+
pub(crate) fn peer_condition(&self) -> PeerCondition {
144+
match self {
145+
DialOpts(
146+
Opts::WithPeerId(WithPeerId { condition, .. })
147+
| Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses { condition, .. }),
148+
) => *condition,
149+
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress { .. })) => {
150+
PeerCondition::Always
151+
}
152+
}
153+
}
154+
155+
pub(crate) fn dial_concurrency_override(&self) -> Option<NonZeroU8> {
156+
match self {
157+
DialOpts(Opts::WithPeerId(WithPeerId {
158+
dial_concurrency_factor_override,
159+
..
160+
})) => *dial_concurrency_factor_override,
161+
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
162+
dial_concurrency_factor_override,
163+
..
164+
})) => *dial_concurrency_factor_override,
165+
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress { .. })) => None,
166+
}
167+
}
168+
169+
pub(crate) fn role_override(&self) -> Endpoint {
170+
match self {
171+
DialOpts(Opts::WithPeerId(WithPeerId { role_override, .. })) => *role_override,
172+
DialOpts(Opts::WithPeerIdWithAddresses(WithPeerIdWithAddresses {
173+
role_override,
174+
..
175+
})) => *role_override,
176+
DialOpts(Opts::WithoutPeerIdWithAddress(WithoutPeerIdWithAddress {
177+
role_override,
178+
..
179+
})) => *role_override,
180+
}
181+
}
82182
}
83183

84184
impl From<Multiaddr> for DialOpts {

swarm/src/lib.rs

Lines changed: 57 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ pub use registry::{AddAddressResult, AddressRecord, AddressScore};
122122
use connection::pool::{EstablishedConnection, Pool, PoolConfig, PoolEvent};
123123
use connection::IncomingInfo;
124124
use dial_opts::{DialOpts, PeerCondition};
125-
use either::Either;
126125
use futures::{executor::ThreadPoolBuilder, prelude::*, stream::FusedStream};
127126
use libp2p_core::connection::ConnectionId;
128127
use libp2p_core::muxing::SubstreamBox;
@@ -138,7 +137,6 @@ use libp2p_core::{
138137
use registry::{AddressIntoIter, Addresses};
139138
use smallvec::SmallVec;
140139
use std::collections::{HashMap, HashSet};
141-
use std::iter;
142140
use std::num::{NonZeroU32, NonZeroU8, NonZeroUsize};
143141
use std::{
144142
convert::TryFrom,
@@ -507,139 +505,72 @@ where
507505

508506
fn dial_with_handler(
509507
&mut self,
510-
swarm_dial_opts: DialOpts,
508+
dial_opts: DialOpts,
511509
handler: <TBehaviour as NetworkBehaviour>::ConnectionHandler,
512510
) -> Result<(), DialError> {
513-
let (peer_id, addresses, dial_concurrency_factor_override, role_override) =
514-
match swarm_dial_opts.0 {
515-
// Dial a known peer.
516-
dial_opts::Opts::WithPeerId(dial_opts::WithPeerId {
517-
peer_id,
518-
condition,
519-
role_override,
520-
dial_concurrency_factor_override,
521-
})
522-
| dial_opts::Opts::WithPeerIdWithAddresses(dial_opts::WithPeerIdWithAddresses {
523-
peer_id,
524-
condition,
525-
role_override,
526-
dial_concurrency_factor_override,
527-
..
528-
}) => {
529-
// Check [`PeerCondition`] if provided.
530-
let condition_matched = match condition {
531-
PeerCondition::Disconnected => !self.is_connected(&peer_id),
532-
PeerCondition::NotDialing => !self.pool.is_dialing(peer_id),
533-
PeerCondition::Always => true,
534-
};
535-
if !condition_matched {
536-
#[allow(deprecated)]
537-
self.behaviour.inject_dial_failure(
538-
Some(peer_id),
539-
handler,
540-
&DialError::DialPeerConditionFalse(condition),
541-
);
542-
543-
return Err(DialError::DialPeerConditionFalse(condition));
544-
}
511+
let peer_id = dial_opts
512+
.get_or_parse_peer_id()
513+
.map_err(DialError::InvalidPeerId)?;
514+
let condition = dial_opts.peer_condition();
515+
516+
let should_dial = match (condition, peer_id) {
517+
(PeerCondition::Always, _) => true,
518+
(PeerCondition::Disconnected, None) => true,
519+
(PeerCondition::NotDialing, None) => true,
520+
(PeerCondition::Disconnected, Some(peer_id)) => !self.pool.is_connected(peer_id),
521+
(PeerCondition::NotDialing, Some(peer_id)) => !self.pool.is_dialing(peer_id),
522+
};
545523

546-
// Check if peer is banned.
547-
if self.banned_peers.contains(&peer_id) {
548-
let error = DialError::Banned;
549-
#[allow(deprecated)]
550-
self.behaviour
551-
.inject_dial_failure(Some(peer_id), handler, &error);
552-
return Err(error);
553-
}
524+
if !should_dial {
525+
let e = DialError::DialPeerConditionFalse(condition);
554526

555-
// Retrieve the addresses to dial.
556-
let addresses = {
557-
let mut addresses = match swarm_dial_opts.0 {
558-
dial_opts::Opts::WithPeerId(dial_opts::WithPeerId { .. }) => {
559-
self.behaviour.addresses_of_peer(&peer_id)
560-
}
561-
dial_opts::Opts::WithPeerIdWithAddresses(
562-
dial_opts::WithPeerIdWithAddresses {
563-
peer_id,
564-
mut addresses,
565-
extend_addresses_through_behaviour,
566-
..
567-
},
568-
) => {
569-
if extend_addresses_through_behaviour {
570-
addresses.extend(self.behaviour.addresses_of_peer(&peer_id))
571-
}
572-
addresses
573-
}
574-
dial_opts::Opts::WithoutPeerIdWithAddress { .. } => {
575-
unreachable!("Due to outer match.")
576-
}
577-
};
527+
#[allow(deprecated)]
528+
self.behaviour.inject_dial_failure(peer_id, handler, &e);
578529

579-
let mut unique_addresses = HashSet::new();
580-
addresses.retain(|addr| {
581-
!self.listened_addrs.values().flatten().any(|a| a == addr)
582-
&& unique_addresses.insert(addr.clone())
583-
});
530+
return Err(e);
531+
}
584532

585-
if addresses.is_empty() {
586-
let error = DialError::NoAddresses;
587-
#[allow(deprecated)]
588-
self.behaviour
589-
.inject_dial_failure(Some(peer_id), handler, &error);
590-
return Err(error);
591-
};
533+
if let Some(peer_id) = peer_id {
534+
// Check if peer is banned.
535+
if self.banned_peers.contains(&peer_id) {
536+
let error = DialError::Banned;
537+
#[allow(deprecated)]
538+
self.behaviour
539+
.inject_dial_failure(Some(peer_id), handler, &error);
540+
return Err(error);
541+
}
542+
}
592543

593-
addresses
594-
};
544+
let addresses = {
545+
let mut addresses = dial_opts.get_addresses();
595546

596-
(
597-
Some(peer_id),
598-
Either::Left(addresses.into_iter()),
599-
dial_concurrency_factor_override,
600-
role_override,
601-
)
547+
if let Some(peer_id) = peer_id {
548+
if dial_opts.extend_addresses_through_behaviour() {
549+
addresses.extend(self.behaviour.addresses_of_peer(&peer_id));
602550
}
603-
// Dial an unknown peer.
604-
dial_opts::Opts::WithoutPeerIdWithAddress(
605-
dial_opts::WithoutPeerIdWithAddress {
606-
address,
607-
role_override,
608-
},
609-
) => {
610-
// If the address ultimately encapsulates an expected peer ID, dial that peer
611-
// such that any mismatch is detected. We do not "pop off" the `P2p` protocol
612-
// from the address, because it may be used by the `Transport`, i.e. `P2p`
613-
// is a protocol component that can influence any transport, like `libp2p-dns`.
614-
let peer_id = match address
615-
.iter()
616-
.last()
617-
.and_then(|p| {
618-
if let Protocol::P2p(ma) = p {
619-
Some(PeerId::try_from(ma))
620-
} else {
621-
None
622-
}
623-
})
624-
.transpose()
625-
{
626-
Ok(peer_id) => peer_id,
627-
Err(multihash) => return Err(DialError::InvalidPeerId(multihash)),
628-
};
551+
}
629552

630-
(
631-
peer_id,
632-
Either::Right(iter::once(address)),
633-
None,
634-
role_override,
635-
)
636-
}
553+
let mut unique_addresses = HashSet::new();
554+
addresses.retain(|addr| {
555+
!self.listened_addrs.values().flatten().any(|a| a == addr)
556+
&& unique_addresses.insert(addr.clone())
557+
});
558+
559+
if addresses.is_empty() {
560+
let error = DialError::NoAddresses;
561+
#[allow(deprecated)]
562+
self.behaviour.inject_dial_failure(peer_id, handler, &error);
563+
return Err(error);
637564
};
638565

566+
addresses
567+
};
568+
639569
let dials = addresses
570+
.into_iter()
640571
.map(|a| match p2p_addr(peer_id, a) {
641572
Ok(address) => {
642-
let dial = match role_override {
573+
let dial = match dial_opts.role_override() {
643574
Endpoint::Dialer => self.transport.dial(address.clone()),
644575
Endpoint::Listener => self.transport.dial_as_listener(address.clone()),
645576
};
@@ -662,8 +593,8 @@ where
662593
dials,
663594
peer_id,
664595
handler,
665-
role_override,
666-
dial_concurrency_factor_override,
596+
dial_opts.role_override(),
597+
dial_opts.dial_concurrency_override(),
667598
) {
668599
Ok(_connection_id) => Ok(()),
669600
Err((connection_limit, handler)) => {
@@ -1088,9 +1019,9 @@ where
10881019
return Some(SwarmEvent::Behaviour(event))
10891020
}
10901021
NetworkBehaviourAction::Dial { opts, handler } => {
1091-
let peer_id = opts.get_peer_id();
1022+
let peer_id = opts.get_or_parse_peer_id();
10921023
if let Ok(()) = self.dial_with_handler(opts, handler) {
1093-
if let Some(peer_id) = peer_id {
1024+
if let Ok(Some(peer_id)) = peer_id {
10941025
return Some(SwarmEvent::Dialing(peer_id));
10951026
}
10961027
}
@@ -2516,6 +2447,8 @@ mod tests {
25162447
_ => panic!("Was expecting the listen address to be reported"),
25172448
}));
25182449

2450+
swarm.listened_addrs.clear(); // This is a hack to actually execute the dial to ourselves which would otherwise be filtered.
2451+
25192452
swarm.dial(local_address.clone()).unwrap();
25202453

25212454
let mut got_dial_err = false;

0 commit comments

Comments
 (0)