diff --git a/Cargo.toml b/Cargo.toml index 042b14e2..105ad6c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ libc = "0.2" tempfile = "3.1.0" [target.'cfg(target_os = "windows")'.dependencies] -schannel = "0.1.16" +schannel = "0.1.17" [target.'cfg(not(any(target_os = "windows", target_os = "macos", target_os = "ios")))'.dependencies] log = "0.4.5" @@ -36,4 +36,4 @@ openssl-src = { version = "300.0.3", optional = true } [dev-dependencies] tempfile = "3.0" -test-cert-gen = "0.1" +test-cert-gen = "0.7" diff --git a/README.md b/README.md index 22bd1fde..fe5d04b8 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ native-tls = "0.2" An example client looks like: -```rust +```rust,ignore extern crate native_tls; use native_tls::TlsConnector; @@ -46,7 +46,7 @@ fn main() { To accept connections as a server from remote clients: -```rust,no_run +```rust,ignore extern crate native_tls; use native_tls::{Identity, TlsAcceptor, TlsStream}; diff --git a/examples/simple-server-pkcs8.rs b/examples/simple-server-pkcs8.rs new file mode 100644 index 00000000..df9c95da --- /dev/null +++ b/examples/simple-server-pkcs8.rs @@ -0,0 +1,45 @@ +extern crate native_tls; + +use native_tls::{Identity, TlsAcceptor, TlsStream}; +use std::fs::File; +use std::io::{Read, Write}; +use std::net::{TcpListener, TcpStream}; +use std::sync::Arc; +use std::thread; + +fn main() { + let mut cert_file = File::open("test/cert.pem").unwrap(); + let mut certs = vec![]; + cert_file.read_to_end(&mut certs).unwrap(); + let mut key_file = File::open("test/key.pem").unwrap(); + let mut key = vec![]; + key_file.read_to_end(&mut key).unwrap(); + let pkcs8 = Identity::from_pkcs8(&certs, &key).unwrap(); + + let acceptor = TlsAcceptor::new(pkcs8).unwrap(); + let acceptor = Arc::new(acceptor); + + let listener = TcpListener::bind("0.0.0.0:8443").unwrap(); + + fn handle_client(mut stream: TlsStream) { + let mut buf = [0; 1024]; + let read = stream.read(&mut buf).unwrap(); + let received = std::str::from_utf8(&buf[0..read]).unwrap(); + stream + .write_all(format!("received '{}'", received).as_bytes()) + .unwrap(); + } + + for stream in listener.incoming() { + match stream { + Ok(stream) => { + let acceptor = acceptor.clone(); + thread::spawn(move || { + let stream = acceptor.accept(stream).unwrap(); + handle_client(stream); + }); + } + Err(_e) => { /* connection failed */ } + } + } +} diff --git a/src/imp/openssl.rs b/src/imp/openssl.rs index ef254a3e..389caa5e 100644 --- a/src/imp/openssl.rs +++ b/src/imp/openssl.rs @@ -5,7 +5,7 @@ use self::openssl::error::ErrorStack; use self::openssl::hash::MessageDigest; use self::openssl::nid::Nid; use self::openssl::pkcs12::Pkcs12; -use self::openssl::pkey::PKey; +use self::openssl::pkey::{PKey, Private}; use self::openssl::ssl::{ self, MidHandshakeSslStream, SslAcceptor, SslConnector, SslContextBuilder, SslMethod, SslVerifyMode, @@ -16,7 +16,6 @@ use std::fmt; use std::io; use std::sync::Once; -use self::openssl::pkey::Private; use {Protocol, TlsAcceptorBuilder, TlsConnectorBuilder}; #[cfg(have_min_max_version)] @@ -117,6 +116,8 @@ fn load_android_root_certs(connector: &mut SslContextBuilder) -> Result<(), Erro pub enum Error { Normal(ErrorStack), Ssl(ssl::Error, X509VerifyResult), + EmptyChain, + NotPkcs8, } impl error::Error for Error { @@ -124,6 +125,8 @@ impl error::Error for Error { match *self { Error::Normal(ref e) => error::Error::source(e), Error::Ssl(ref e, _) => error::Error::source(e), + Error::EmptyChain => None, + Error::NotPkcs8 => None, } } } @@ -134,6 +137,11 @@ impl fmt::Display for Error { Error::Normal(ref e) => fmt::Display::fmt(e, fmt), Error::Ssl(ref e, X509VerifyResult::OK) => fmt::Display::fmt(e, fmt), Error::Ssl(ref e, v) => write!(fmt, "{} ({})", e, v), + Error::EmptyChain => write!( + fmt, + "at least one certificate must be provided to create an identity" + ), + Error::NotPkcs8 => write!(fmt, "expected PKCS#8 PEM"), } } } @@ -158,9 +166,24 @@ impl Identity { Ok(Identity { pkey: parsed.pkey, cert: parsed.cert, - chain: parsed.chain.into_iter().flatten().collect(), + // > The stack is the reverse of what you might expect due to the way + // > PKCS12_parse is implemented, so we need to load it backwards. + // > https://github.com/sfackler/rust-native-tls/commit/05fb5e583be589ab63d9f83d986d095639f8ec44 + chain: parsed.chain.into_iter().flatten().rev().collect(), }) } + + pub fn from_pkcs8(buf: &[u8], key: &[u8]) -> Result { + if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") { + return Err(Error::NotPkcs8); + } + + let pkey = PKey::private_key_from_pem(key)?; + let mut cert_chain = X509::stack_from_pem(buf)?.into_iter(); + let cert = cert_chain.next().ok_or(Error::EmptyChain)?; + let chain = cert_chain.collect(); + Ok(Identity { pkey, cert, chain }) + } } #[derive(Clone)] @@ -258,7 +281,10 @@ impl TlsConnector { if let Some(ref identity) = builder.identity { connector.set_certificate(&identity.0.cert)?; connector.set_private_key(&identity.0.pkey)?; - for cert in identity.0.chain.iter().rev() { + for cert in identity.0.chain.iter() { + // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html + // specifies that "When sending a certificate chain, extra chain certificates are + // sent in order following the end entity certificate." connector.add_extra_chain_cert(cert.to_owned())?; } } @@ -342,7 +368,10 @@ impl TlsAcceptor { let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; acceptor.set_private_key(&builder.identity.0.pkey)?; acceptor.set_certificate(&builder.identity.0.cert)?; - for cert in builder.identity.0.chain.iter().rev() { + for cert in builder.identity.0.chain.iter() { + // https://www.openssl.org/docs/manmaster/man3/SSL_CTX_add_extra_chain_cert.html + // specifies that "When sending a certificate chain, extra chain certificates are + // sent in order following the end entity certificate." acceptor.add_extra_chain_cert(cert.to_owned())?; } supported_protocols(builder.min_protocol, builder.max_protocol, &mut acceptor)?; diff --git a/src/imp/schannel.rs b/src/imp/schannel.rs index 58ec0636..62e5042f 100644 --- a/src/imp/schannel.rs +++ b/src/imp/schannel.rs @@ -1,7 +1,8 @@ extern crate schannel; -use self::schannel::cert_context::{CertContext, HashAlgorithm}; +use self::schannel::cert_context::{CertContext, HashAlgorithm, KeySpec}; use self::schannel::cert_store::{CertAdd, CertStore, Memory, PfxImportOptions}; +use self::schannel::crypt_prov::{AcquireOptions, ProviderType}; use self::schannel::schannel_cred::{Direction, Protocol, SchannelCred}; use self::schannel::tls_stream; use std::error; @@ -93,6 +94,59 @@ impl Identity { Ok(Identity { cert: identity }) } + + pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Result { + if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "not a PKCS#8 key").into()); + } + + let mut store = Memory::new()?.into_store(); + let mut cert_iter = pem::PemBlock::new(pem).into_iter(); + let leaf = cert_iter.next().ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "at least one certificate must be provided to create an identity", + ) + })?; + let cert = CertContext::from_pem(std::str::from_utf8(leaf).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "leaf cert contains invalid utf8", + ) + })?)?; + + let name = gen_container_name(); + let mut options = AcquireOptions::new(); + options.container(&name); + let type_ = ProviderType::rsa_full(); + + let mut container = match options.acquire(type_) { + Ok(container) => container, + Err(_) => options.new_keyset(true).acquire(type_)?, + }; + container.import().import_pkcs8_pem(&key)?; + + cert.set_key_prov_info() + .container(&name) + .type_(type_) + .keep_open(true) + .key_spec(KeySpec::key_exchange()) + .set()?; + let mut context = store.add_cert(&cert, CertAdd::Always)?; + + for int_cert in cert_iter { + let certificate = Certificate::from_pem(int_cert)?; + context = store.add_cert(&certificate.0, CertAdd::Always)?; + } + Ok(Identity { cert: context }) + } +} + +// The name of the container must be unique to have multiple active keys. +fn gen_container_name() -> String { + use std::sync::atomic::{AtomicUsize, Ordering}; + static COUNTER: AtomicUsize = AtomicUsize::new(0); + format!("native-tls-{}", COUNTER.fetch_add(1, Ordering::Relaxed)) } #[derive(Clone)] @@ -384,3 +438,125 @@ impl io::Write for TlsStream { self.0.flush() } } + +mod pem { + /// Split data by PEM guard lines + pub struct PemBlock<'a> { + pem_block: &'a str, + cur_end: usize, + } + + impl<'a> PemBlock<'a> { + pub fn new(data: &'a [u8]) -> PemBlock<'a> { + let s = ::std::str::from_utf8(data).unwrap(); + PemBlock { + pem_block: s, + cur_end: s.find("-----BEGIN").unwrap_or(s.len()), + } + } + } + + impl<'a> Iterator for PemBlock<'a> { + type Item = &'a [u8]; + fn next(&mut self) -> Option { + let last = self.pem_block.len(); + if self.cur_end >= last { + return None; + } + let begin = self.cur_end; + let pos = self.pem_block[begin + 1..].find("-----BEGIN"); + self.cur_end = match pos { + Some(end) => end + begin + 1, + None => last, + }; + return Some(&self.pem_block[begin..self.cur_end].as_bytes()); + } + } + + #[test] + fn test_split() { + // Split three certs, CRLF line terminators. + assert_eq!( + PemBlock::new( + b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n\ + -----BEGIN SECOND-----\r\n-----END SECOND\r\n\ + -----BEGIN THIRD-----\r\n-----END THIRD\r\n" + ) + .collect::>(), + vec![ + b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n" as &[u8], + b"-----BEGIN SECOND-----\r\n-----END SECOND\r\n", + b"-----BEGIN THIRD-----\r\n-----END THIRD\r\n" + ] + ); + // Split three certs, CRLF line terminators except at EOF. + assert_eq!( + PemBlock::new( + b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n\ + -----BEGIN SECOND-----\r\n-----END SECOND-----\r\n\ + -----BEGIN THIRD-----\r\n-----END THIRD-----" + ) + .collect::>(), + vec![ + b"-----BEGIN FIRST-----\r\n-----END FIRST-----\r\n" as &[u8], + b"-----BEGIN SECOND-----\r\n-----END SECOND-----\r\n", + b"-----BEGIN THIRD-----\r\n-----END THIRD-----" + ] + ); + // Split two certs, LF line terminators. + assert_eq!( + PemBlock::new( + b"-----BEGIN FIRST-----\n-----END FIRST-----\n\ + -----BEGIN SECOND-----\n-----END SECOND\n" + ) + .collect::>(), + vec![ + b"-----BEGIN FIRST-----\n-----END FIRST-----\n" as &[u8], + b"-----BEGIN SECOND-----\n-----END SECOND\n" + ] + ); + // Split two certs, CR line terminators. + assert_eq!( + PemBlock::new( + b"-----BEGIN FIRST-----\r-----END FIRST-----\r\ + -----BEGIN SECOND-----\r-----END SECOND\r" + ) + .collect::>(), + vec![ + b"-----BEGIN FIRST-----\r-----END FIRST-----\r" as &[u8], + b"-----BEGIN SECOND-----\r-----END SECOND\r" + ] + ); + // Split two certs, LF line terminators except at EOF. + assert_eq!( + PemBlock::new( + b"-----BEGIN FIRST-----\n-----END FIRST-----\n\ + -----BEGIN SECOND-----\n-----END SECOND" + ) + .collect::>(), + vec![ + b"-----BEGIN FIRST-----\n-----END FIRST-----\n" as &[u8], + b"-----BEGIN SECOND-----\n-----END SECOND" + ] + ); + // Split a single cert, LF line terminators. + assert_eq!( + PemBlock::new(b"-----BEGIN FIRST-----\n-----END FIRST-----\n").collect::>(), + vec![b"-----BEGIN FIRST-----\n-----END FIRST-----\n" as &[u8]] + ); + // Split a single cert, LF line terminators except at EOF. + assert_eq!( + PemBlock::new(b"-----BEGIN FIRST-----\n-----END FIRST-----").collect::>(), + vec![b"-----BEGIN FIRST-----\n-----END FIRST-----" as &[u8]] + ); + // (Don't) split garbage. + assert_eq!( + PemBlock::new(b"junk").collect::>(), + Vec::<&[u8]>::new() + ); + assert_eq!( + PemBlock::new(b"junk-----BEGIN garbage").collect::>(), + vec![b"-----BEGIN garbage" as &[u8]] + ); + } +} diff --git a/src/imp/security_framework.rs b/src/imp/security_framework.rs index c14b2fb0..5a89dfa8 100644 --- a/src/imp/security_framework.rs +++ b/src/imp/security_framework.rs @@ -7,6 +7,7 @@ use self::security_framework::base; use self::security_framework::certificate::SecCertificate; use self::security_framework::identity::SecIdentity; use self::security_framework::import_export::{ImportedIdentity, Pkcs12ImportOptions}; +use self::security_framework::random::SecRandom; use self::security_framework::secure_transport::{ self, ClientBuilder, SslConnectionType, SslContext, SslProtocol, SslProtocolSide, }; @@ -24,6 +25,8 @@ use self::security_framework::os::macos::certificate::{PropertyType, SecCertific #[cfg(not(target_os = "ios"))] use self::security_framework::os::macos::certificate_oids::CertificateOid; #[cfg(not(target_os = "ios"))] +use self::security_framework::os::macos::identity::SecIdentityExt; +#[cfg(not(target_os = "ios"))] use self::security_framework::os::macos::import_export::{ ImportOptions, Pkcs12ImportOptionsExt, SecItems, }; @@ -82,6 +85,41 @@ pub struct Identity { } impl Identity { + pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Result { + if !key.starts_with(b"-----BEGIN PRIVATE KEY-----") { + return Err(Error(base::Error::from(errSecParam))); + } + + let dir = TempDir::new().map_err(|_| Error(base::Error::from(errSecIO)))?; + let keychain = keychain::CreateOptions::new() + .password(&random_password()?) + .create(dir.path().join("identity.keychain"))?; + + let mut items = SecItems::default(); + + ImportOptions::new() + .filename("key.pem") + .items(&mut items) + .keychain(&keychain) + .import(&key)?; + + ImportOptions::new() + .filename("chain.pem") + .items(&mut items) + .keychain(&keychain) + .import(&pem)?; + + let cert = items + .certificates + .get(0) + .ok_or_else(|| Error(base::Error::from(errSecParam)))?; + let ident = SecIdentity::with_certificate(&[keychain], cert)?; + Ok(Identity { + identity: ident, + chain: items.certificates, + }) + } + pub fn from_pkcs12(buf: &[u8], pass: &str) -> Result { let mut imports = Identity::import_options(buf, pass)?; let import = imports.pop().unwrap(); @@ -143,6 +181,19 @@ impl Identity { } } +fn random_password() -> Result { + use std::fmt::Write; + let mut bytes = [0_u8; 10]; + SecRandom::default() + .copy_bytes(&mut bytes) + .map_err(|_| Error(base::Error::from(errSecIO)))?; + let mut s = String::with_capacity(2 * bytes.len()); + for byte in bytes { + write!(s, "{:02X}", byte).map_err(|_| Error(base::Error::from(errSecIO)))?; + } + Ok(s) +} + #[derive(Clone)] pub struct Certificate(SecCertificate); diff --git a/src/lib.rs b/src/lib.rs index 171e8c7e..14dabb7b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,7 @@ //! * TLS/SSL client communication //! * TLS/SSL server communication //! * PKCS#12 encoded identities +//! * X.509/PKCS#8 encoded identities //! * Secure-by-default for client and server //! * Includes hostname verification for clients //! * Supports asynchronous I/O for both the server and the client @@ -177,6 +178,18 @@ impl Identity { let identity = imp::Identity::from_pkcs12(der, password)?; Ok(Identity(identity)) } + + /// Parses a chain of PEM encoded X509 certificates, with the leaf certificate first. + /// `key` is a PEM encoded PKCS #8 formatted private key for the leaf certificate. + /// + /// The certificate chain should contain any intermediate cerficates that should be sent to + /// clients to allow them to build a chain to a trusted root. + /// + /// A certificate chain here means a series of PEM encoded certificates concatenated together. + pub fn from_pkcs8(pem: &[u8], key: &[u8]) -> Result { + let identity = imp::Identity::from_pkcs8(pem, key)?; + Ok(Identity(identity)) + } } /// An X509 certificate. diff --git a/src/test.rs b/src/test.rs index 940ff21f..d29f0d26 100644 --- a/src/test.rs +++ b/src/test.rs @@ -59,8 +59,8 @@ fn server_no_root_certs() { let keys = test_cert_gen::keys(); let identity = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); let builder = p!(TlsAcceptor::new(identity)); @@ -78,7 +78,7 @@ fn server_no_root_certs() { p!(socket.write_all(b"world")); }); - let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap(); + let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap(); let socket = p!(TcpStream::connect(("localhost", port))); let builder = p!(TlsConnector::builder() @@ -100,8 +100,8 @@ fn server() { let keys = test_cert_gen::keys(); let identity = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); let builder = p!(TlsAcceptor::new(identity)); @@ -119,7 +119,7 @@ fn server() { p!(socket.write_all(b"world")); }); - let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap(); + let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap(); let socket = p!(TcpStream::connect(("localhost", port))); let builder = p!(TlsConnector::builder() @@ -141,7 +141,7 @@ fn certificate_from_pem() { let keys = test_cert_gen::keys(); let der_path = dir.path().join("cert.der"); - fs::write(&der_path, &keys.client.cert_der).unwrap(); + fs::write(&der_path, &keys.client.ca.get_der()).unwrap(); let output = Command::new("openssl") .arg("x509") .arg("-in") @@ -155,7 +155,7 @@ fn certificate_from_pem() { assert!(output.status.success()); let cert = Certificate::from_pem(&output.stdout).unwrap(); - assert_eq!(cert.to_der().unwrap(), keys.client.cert_der); + assert_eq!(cert.to_der().unwrap(), keys.client.ca.get_der()); } #[test] @@ -163,8 +163,8 @@ fn peer_certificate() { let keys = test_cert_gen::keys(); let identity = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); let builder = p!(TlsAcceptor::new(identity)); @@ -177,7 +177,7 @@ fn peer_certificate() { assert!(socket.peer_certificate().unwrap().is_none()); }); - let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap(); + let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap(); let socket = p!(TcpStream::connect(("localhost", port))); let builder = p!(TlsConnector::builder() @@ -186,7 +186,10 @@ fn peer_certificate() { let socket = p!(builder.connect("localhost", socket)); let cert = socket.peer_certificate().unwrap().unwrap(); - assert_eq!(cert.to_der().unwrap(), keys.client.cert_der); + assert_eq!( + cert.to_der().unwrap(), + keys.server.cert_and_key.cert.get_der() + ); p!(j.join()); } @@ -196,8 +199,8 @@ fn server_tls11_only() { let keys = test_cert_gen::keys(); let identity = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); let builder = p!(TlsAcceptor::builder(identity) .min_protocol_version(Some(Protocol::Tlsv12)) @@ -218,7 +221,7 @@ fn server_tls11_only() { p!(socket.write_all(b"world")); }); - let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap(); + let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap(); let socket = p!(TcpStream::connect(("localhost", port))); let builder = p!(TlsConnector::builder() @@ -241,8 +244,8 @@ fn server_no_shared_protocol() { let keys = test_cert_gen::keys(); let identity = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); let builder = p!(TlsAcceptor::builder(identity) .min_protocol_version(Some(Protocol::Tlsv12)) @@ -256,7 +259,7 @@ fn server_no_shared_protocol() { assert!(builder.accept(socket).is_err()); }); - let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap(); + let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap(); let socket = p!(TcpStream::connect(("localhost", port))); let builder = p!(TlsConnector::builder() @@ -274,8 +277,8 @@ fn server_untrusted() { let keys = test_cert_gen::keys(); let identity = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); let builder = p!(TlsAcceptor::new(identity)); @@ -301,8 +304,8 @@ fn server_untrusted_unverified() { let keys = test_cert_gen::keys(); let identity = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); let builder = p!(TlsAcceptor::new(identity)); @@ -339,13 +342,28 @@ fn import_same_identity_multiple_times() { let keys = test_cert_gen::keys(); let _ = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); let _ = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); + + let cert = keys.server.cert_and_key.cert.to_pem().into_bytes(); + let key = rsa_to_pkcs8(&keys.server.cert_and_key.key.to_pem_incorrect()).into_bytes(); + let _ = p!(Identity::from_pkcs8(&cert, &key)); + let _ = p!(Identity::from_pkcs8(&cert, &key)); +} + +#[test] +fn from_pkcs8_rejects_rsa_key() { + let keys = test_cert_gen::keys(); + let cert = keys.server.cert_and_key.cert.to_pem().into_bytes(); + let rsa_key = keys.server.cert_and_key.key.to_pem_incorrect(); + assert!(Identity::from_pkcs8(&cert, rsa_key.as_bytes()).is_err()); + let pkcs8_key = rsa_to_pkcs8(&rsa_key); + assert!(Identity::from_pkcs8(&cert, pkcs8_key.as_bytes()).is_ok()); } #[test] @@ -353,8 +371,8 @@ fn shutdown() { let keys = test_cert_gen::keys(); let identity = p!(Identity::from_pkcs12( - &keys.server.pkcs12, - &keys.server.pkcs12_password + &keys.server.cert_and_key_pkcs12.pkcs12.0, + &keys.server.cert_and_key_pkcs12.password )); let builder = p!(TlsAcceptor::new(identity)); @@ -373,7 +391,7 @@ fn shutdown() { p!(socket.shutdown()); }); - let root_ca = Certificate::from_der(&keys.client.cert_der).unwrap(); + let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap(); let socket = p!(TcpStream::connect(("localhost", port))); let builder = p!(TlsConnector::builder() @@ -416,3 +434,140 @@ fn alpn_google_none() { let alpn = p!(socket.negotiated_alpn()); assert_eq!(alpn, None); } + +#[test] +fn server_pkcs8() { + let keys = test_cert_gen::keys(); + let cert = keys.server.cert_and_key.cert.to_pem().into_bytes(); + let key = rsa_to_pkcs8(&keys.server.cert_and_key.key.to_pem_incorrect()).into_bytes(); + + let ident = Identity::from_pkcs8(&cert, &key).unwrap(); + let ident2 = ident.clone(); + let builder = p!(TlsAcceptor::new(ident)); + + let listener = p!(TcpListener::bind("0.0.0.0:0")); + let port = p!(listener.local_addr()).port(); + + let j = thread::spawn(move || { + let socket = p!(listener.accept()).0; + let mut socket = p!(builder.accept(socket)); + + let mut buf = [0; 5]; + p!(socket.read_exact(&mut buf)); + assert_eq!(&buf, b"hello"); + + p!(socket.write_all(b"world")); + }); + + let root_ca = Certificate::from_der(keys.client.ca.get_der()).unwrap(); + + let socket = p!(TcpStream::connect(("localhost", port))); + let mut builder = TlsConnector::builder(); + // FIXME + // This checks that we can successfully add a certificate on the client side. + // Unfortunately, we can not request client certificates through the API of this library, + // otherwise we could check in the server thread that + // socket.peer_certificate().unwrap().is_some() + builder.identity(ident2); + + builder.add_root_certificate(root_ca); + let builder = p!(builder.build()); + let mut socket = p!(builder.connect("localhost", socket)); + + p!(socket.write_all(b"hello")); + let mut buf = vec![]; + p!(socket.read_to_end(&mut buf)); + assert_eq!(buf, b"world"); + + p!(j.join()); +} + +#[test] +fn two_servers() { + let keys1 = test_cert_gen::gen_keys(); + let cert = keys1.server.cert_and_key.cert.to_pem().into_bytes(); + let key = rsa_to_pkcs8(&keys1.server.cert_and_key.key.to_pem_incorrect()).into_bytes(); + let identity = p!(Identity::from_pkcs8(&cert, &key)); + let builder = TlsAcceptor::builder(identity); + let builder = p!(builder.build()); + + let listener = p!(TcpListener::bind("0.0.0.0:0")); + let port = p!(listener.local_addr()).port(); + + let j = thread::spawn(move || { + let socket = p!(listener.accept()).0; + let mut socket = p!(builder.accept(socket)); + + let mut buf = [0; 5]; + p!(socket.read_exact(&mut buf)); + assert_eq!(&buf, b"hello"); + + p!(socket.write_all(b"world")); + }); + + let keys2 = test_cert_gen::gen_keys(); + let cert = keys2.server.cert_and_key.cert.to_pem().into_bytes(); + let key = rsa_to_pkcs8(&keys2.server.cert_and_key.key.to_pem_incorrect()).into_bytes(); + let identity = p!(Identity::from_pkcs8(&cert, &key)); + let builder = TlsAcceptor::builder(identity); + let builder = p!(builder.build()); + + let listener = p!(TcpListener::bind("0.0.0.0:0")); + let port2 = p!(listener.local_addr()).port(); + + let j2 = thread::spawn(move || { + let socket = p!(listener.accept()).0; + let mut socket = p!(builder.accept(socket)); + + let mut buf = [0; 5]; + p!(socket.read_exact(&mut buf)); + assert_eq!(&buf, b"hello"); + + p!(socket.write_all(b"world")); + }); + + let root_ca = Certificate::from_der(keys1.client.ca.get_der()).unwrap(); + + let socket = p!(TcpStream::connect(("localhost", port))); + let mut builder = TlsConnector::builder(); + builder.add_root_certificate(root_ca); + let builder = p!(builder.build()); + let mut socket = p!(builder.connect("localhost", socket)); + + p!(socket.write_all(b"hello")); + let mut buf = vec![]; + p!(socket.read_to_end(&mut buf)); + assert_eq!(buf, b"world"); + + let root_ca = Certificate::from_der(keys2.client.ca.get_der()).unwrap(); + + let socket = p!(TcpStream::connect(("localhost", port2))); + let mut builder = TlsConnector::builder(); + builder.add_root_certificate(root_ca); + let builder = p!(builder.build()); + let mut socket = p!(builder.connect("localhost", socket)); + + p!(socket.write_all(b"hello")); + let mut buf = vec![]; + p!(socket.read_to_end(&mut buf)); + assert_eq!(buf, b"world"); + + p!(j.join()); + p!(j2.join()); +} + +fn rsa_to_pkcs8(pem: &str) -> String { + let mut child = Command::new("openssl") + .arg("pkcs8") + .arg("-topk8") + .arg("-nocrypt") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .spawn() + .unwrap(); + { + let child_stdin = child.stdin.as_mut().unwrap(); + child_stdin.write_all(pem.as_bytes()).unwrap(); + } + String::from_utf8(child.wait_with_output().unwrap().stdout).unwrap() +}