Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 24 additions & 51 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ moka = { version = "0.8.1", features = ["future"] }
tw_chain = "1.1.3"
keccak_prime = "0.1.0"
protobuf = "2.6.0"
raft = { git = "https://github.com/ABlockOfficial/raft-rs", branch = "0.5.1" }
raft = { git = "https://github.com/AIBlockOfficial/raft-rs", branch = "0.5.1" }
rand = "0.7.3"
ring = "0.16.20"
rocksdb = "0.21.0"
Expand All @@ -33,7 +33,7 @@ serde = { version = "1.0.104", features = ["derive"] }
sha3 = "0.9.1"
serde_json = "1.0.61"
tokio = { version = "1.7.1", features = ["full"] }
tokio-rustls = "0.23.0"
tokio-rustls = "0.24.0"
tokio-util = { version = "0.6.7", features = ["full"] }
tokio-stream = "0.1.6"
tracing = "0.1.40"
Expand All @@ -43,6 +43,11 @@ warp = { version = "0.3.1", features = ["tls"] }
url = "2.4.1"
trust-dns-resolver = "0.23.2"
rustls-pemfile = "2.0.0"
rustls = "0.21.11"
shlex = "1.3.0"
h2 = "0.3.26"
mio = "0.8.11"
rustls-pki-types = "1.8.0"

[features]
mock = []
Expand Down
11 changes: 0 additions & 11 deletions src/comms_handler/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::net::SocketAddr;
use std::{error::Error, fmt, io};
use tokio::sync::mpsc;
use tokio_rustls::rustls::Error as TLSError;
use tokio_rustls::webpki;

#[derive(Debug)]
pub enum CommsError {
Expand Down Expand Up @@ -32,8 +31,6 @@ pub enum CommsError {
Serialization(bincode::Error),
/// MPSC channel error.
ChannelSendError(mpsc::error::SendError<Event>),
/// Webpki error
WebpkiError(webpki::Error),
}

#[derive(Debug)]
Expand All @@ -57,7 +54,6 @@ impl fmt::Display for CommsError {
Self::PeerIncompatible(info) => write!(f, "Peer incompatible: {info:?}"),
Self::Serialization(err) => write!(f, "Serialization error: {err}"),
Self::ChannelSendError(err) => write!(f, "MPSC channel send error: {err}"),
Self::WebpkiError(err) => write!(f, "Webpki error: {err}"),
}
}
}
Expand All @@ -77,7 +73,6 @@ impl Error for CommsError {
Self::PeerIncompatible(_) => None,
Self::Serialization(err) => Some(err),
Self::ChannelSendError(err) => Some(err),
Self::WebpkiError(err) => Some(err),
}
}
}
Expand Down Expand Up @@ -105,9 +100,3 @@ impl From<TLSError> for CommsError {
Self::TlsError(other)
}
}

impl From<webpki::Error> for CommsError {
fn from(other: webpki::Error) -> Self {
Self::WebpkiError(other)
}
}
28 changes: 16 additions & 12 deletions src/comms_handler/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@
//! [netbuffersize]: https://stackoverflow.com/a/7865130/168853

use super::tcp_tls::{
verify_is_valid_for_dns_names, TcpTlsConnector, TcpTlsListner, TcpTlsStream, TlsCertificate,
// verify_is_valid_for_dns_names,
TcpTlsConnector,
TcpTlsListner,
TcpTlsStream,
TlsCertificate,
};
use super::{CommsError, Event, Result, TcpTlsConfig};
use crate::comms_handler::error::PeerInfo;
Expand Down Expand Up @@ -973,7 +977,7 @@ impl Node {
&self,
peer_out_addr: SocketAddr,
mut peer_in_addr: SocketAddr,
peer_cert: &Option<TlsCertificate>,
_peer_cert: &Option<TlsCertificate>,
mut send_tx: ResultBytesSender,
network_version: u32,
peer_type: NodeType,
Expand Down Expand Up @@ -1024,16 +1028,16 @@ impl Node {
}

// We only do DNS validation on mempool and storage nodes
if self.node_type == NodeType::Mempool || self.node_type == NodeType::Storage {
if let Some(peer_cert) = peer_cert {
let connector = self.tcp_tls_connector.read().await;
let peer_name = connector.socket_name_mapping(peer_in_addr);
// We don't need strict DNS name validation for miner or user nodes
if peer_type != NodeType::Miner && peer_type != NodeType::User {
verify_is_valid_for_dns_names(peer_cert, std::iter::once(peer_name.as_str()))?;
}
}
}
// if self.node_type == NodeType::Mempool || self.node_type == NodeType::Storage {
// if let Some(peer_cert) = peer_cert {
// let connector = self.tcp_tls_connector.read().await;
// let peer_name = connector.socket_name_mapping(peer_in_addr);
// // We don't need strict DNS name validation for miner or user nodes
// if peer_type != NodeType::Miner && peer_type != NodeType::User {
// verify_is_valid_for_dns_names(peer_cert, std::iter::once(peer_name.as_str()))?;
// }
// }
// }

peer.network_version = Some(network_version);
peer.peer_type = Some(peer_type);
Expand Down
30 changes: 15 additions & 15 deletions src/comms_handler/tcp_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use tokio_rustls::rustls::client::ServerName;
use tokio_rustls::rustls::{
Certificate, ClientConfig, CommonState, PrivateKey, RootCertStore, ServerConfig,
};
use tokio_rustls::webpki::{DnsNameRef, EndEntityCert};
// use tokio_rustls::webpki::{DnsNameRef, EndEntityCert};
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tokio_stream::Stream;

Expand Down Expand Up @@ -311,7 +311,7 @@ fn new_client_config(config: &TcpTlsConfig) -> Result<ClientConfig> {
.with_safe_default_protocol_versions()
.unwrap()
.with_root_certificates(root_store)
.with_single_cert(certs, keys.remove(0))?;
.with_client_auth_cert(certs, keys.remove(0))?;

Ok(client_config)
}
Expand Down Expand Up @@ -387,19 +387,19 @@ impl AsyncWrite for TcpTlsStream {
}
}

/// verify the dna name is valid for the certificae
pub fn verify_is_valid_for_dns_names<'a>(
cert: &TlsCertificate,
tls_names: impl Iterator<Item = &'a str>,
) -> Result<()> {
let domains: std::result::Result<Vec<_>, _> =
tls_names.map(DnsNameRef::try_from_ascii_str).collect();
let domains = domains.map_err(|_| CommsError::ConfigError("invalid dnsname"))?;

let cert = EndEntityCert::try_from(cert.0.as_slice()).unwrap();
cert.verify_is_valid_for_at_least_one_dns_name(domains.iter().copied())?;
Ok(())
}
// /// verify the dna name is valid for the certificae
// pub fn verify_is_valid_for_dns_names<'a>(
// cert: &TlsCertificate,
// tls_names: impl Iterator<Item = &'a str>,
// ) -> Result<()> {
// let domains: std::result::Result<Vec<_>, _> =
// tls_names.map(DnsNameRef::try_from_ascii_str).collect();
// let domains = domains.map_err(|_| CommsError::ConfigError("invalid dnsname"))?;

// let cert = EndEntityCert::try_from(cert.0.as_slice()).unwrap();
// cert.verify_is_valid_for_at_least_one_dns_name(domains.iter().copied())?;
// Ok(())
// }

/// Retrieves the certificate from a TLS session connection. In later versions of rustls, this is
/// a method on `CommonState`
Expand Down