Skip to content

Commit 8c6fd07

Browse files
committed
Introduce async callbacks for set_select_certificate_callback
1 parent 6fcd97f commit 8c6fd07

File tree

5 files changed

+208
-4
lines changed

5 files changed

+208
-4
lines changed

boring/src/ssl/connector.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ impl ConnectConfiguration {
189189
self.verify_hostname = verify_hostname;
190190
}
191191

192+
pub fn ssl_mut(&mut self) -> &mut SslRef {
193+
&mut self.ssl
194+
}
195+
192196
/// Initiates a client-side TLS session on a stream.
193197
///
194198
/// The domain is used for SNI and hostname verification if enabled.
@@ -324,8 +328,12 @@ impl SslAcceptor {
324328
where
325329
S: Read + Write,
326330
{
327-
let ssl = Ssl::new(&self.0)?;
328-
ssl.accept(stream)
331+
self.new_session()?.accept(stream)
332+
}
333+
334+
/// Creates a new TLS session, ready to accept a stream.
335+
pub fn new_session(&self) -> Result<Ssl, ErrorStack> {
336+
Ssl::new(&self.0)
329337
}
330338

331339
/// Consumes the `SslAcceptor`, returning the inner raw `SslContext`.

boring/src/ssl/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t);
478478
impl SelectCertError {
479479
/// A fatal error occured and the handshake should be terminated.
480480
pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error);
481+
482+
/// The operation could not be completed and should be retried later.
483+
pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry);
481484
}
482485

483486
/// Extension types, to be used with `ClientHello::get_extension`.
@@ -3136,6 +3139,11 @@ impl<S> MidHandshakeSslStream<S> {
31363139
self.stream.ssl()
31373140
}
31383141

3142+
/// Returns a mutable reference to the `Ssl` of the stream.
3143+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3144+
self.stream.ssl_mut()
3145+
}
3146+
31393147
/// Returns the underlying error which interrupted this handshake.
31403148
pub fn error(&self) -> &Error {
31413149
&self.error
@@ -3390,6 +3398,11 @@ impl<S> SslStream<S> {
33903398
pub fn ssl(&self) -> &SslRef {
33913399
&self.ssl
33923400
}
3401+
3402+
/// Returns a mutable reference to the `Ssl` object associated with this stream.
3403+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3404+
&mut self.ssl
3405+
}
33933406
}
33943407

33953408
impl<S: Read + Write> Read for SslStream<S> {

tokio-boring/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pq-experimental = ["boring/pq-experimental"]
3131
[dependencies]
3232
boring = { workspace = true }
3333
boring-sys = { workspace = true }
34+
once_cell = { workspace = true }
3435
tokio = { workspace = true }
3536

3637
[dev-dependencies]
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
use boring::ex_data::Index;
2+
use boring::ssl::{self, ClientHello, Ssl, SslContextBuilder};
3+
use once_cell::sync::Lazy;
4+
use std::future::Future;
5+
use std::pin::Pin;
6+
use std::task::{ready, Context, Poll};
7+
8+
type BoxSelectCertFuture = Pin<
9+
Box<
10+
dyn Future<Output = Result<BoxSelectCertFinish, AsyncSelectCertError>>
11+
+ Send
12+
+ Sync
13+
+ 'static,
14+
>,
15+
>;
16+
17+
type BoxSelectCertFinish =
18+
Box<dyn FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static>;
19+
20+
pub(crate) static TASK_CONTEXT_INDEX: Lazy<Index<Ssl, usize>> =
21+
Lazy::new(|| Ssl::new_ex_index().unwrap());
22+
pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, BoxSelectCertFuture>> =
23+
Lazy::new(|| Ssl::new_ex_index().unwrap());
24+
25+
/// Extensions to [`SslContextBuilder`].
26+
///
27+
/// This trait provides additional methods to use async callbacks with boring.
28+
pub trait SslContextBuilderExt: private::Sealed {
29+
/// Sets a callback that is called before most [`ClientHello`] processing
30+
/// and before the decision whether to resume a session is made. The
31+
/// callback may inspect the [`ClientHello`] and configure the connection.
32+
///
33+
/// This method uses a function that returns a future whose output is
34+
/// itself a closure that will be passed [`ClientHello`] to configure
35+
/// the connection based on the computations done in the future.
36+
///
37+
/// See [`SslContextBuilder::set_select_certificate_callback`] for the sync
38+
/// setter of this callback.
39+
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
40+
where
41+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
42+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
43+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static;
44+
45+
/// Sets a callback that is called before most [`ClientHello`] processing
46+
/// and before the decision whether to resume a session is made. The
47+
/// callback may inspect the [`ClientHello`] and configure the connection.
48+
///
49+
/// This method uses a polling function.
50+
///
51+
/// See [`SslContextBuilder::set_select_certificate_callback`] for the sync
52+
/// setter of this callback.
53+
fn set_polling_select_certificate_callback<F>(
54+
&mut self,
55+
callback: impl Fn(ClientHello<'_>, &mut Context<'_>) -> Poll<Result<(), AsyncSelectCertError>>
56+
+ Send
57+
+ Sync
58+
+ 'static,
59+
);
60+
}
61+
62+
impl SslContextBuilderExt for SslContextBuilder {
63+
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
64+
where
65+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
66+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
67+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static,
68+
{
69+
self.set_select_certificate_callback(async_select_certificate_callback(callback))
70+
}
71+
72+
fn set_polling_select_certificate_callback<F>(
73+
&mut self,
74+
callback: impl Fn(ClientHello<'_>, &mut Context<'_>) -> Poll<Result<(), AsyncSelectCertError>>
75+
+ Send
76+
+ Sync
77+
+ 'static,
78+
) {
79+
self.set_select_certificate_callback(polling_select_certificate_callback(callback));
80+
}
81+
}
82+
83+
/// A fatal error to be returned from select certificate callbacks.
84+
pub struct AsyncSelectCertError;
85+
86+
fn async_select_certificate_callback<Init, Fut, Finish>(
87+
callback: Init,
88+
) -> impl Fn(ClientHello<'_>) -> Result<(), ssl::SelectCertError> + Send + Sync + 'static
89+
where
90+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
91+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
92+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static,
93+
{
94+
polling_select_certificate_callback(move |mut client_hello, cx| {
95+
let fut_result = match client_hello
96+
.ssl_mut()
97+
.ex_data_mut(*SELECT_CERT_FUTURE_INDEX)
98+
{
99+
Some(fut) => ready!(fut.as_mut().poll(cx)),
100+
None => {
101+
let fut = callback(&mut client_hello)?;
102+
let mut box_fut =
103+
Box::pin(async move { Ok(Box::new(fut.await?) as BoxSelectCertFinish) })
104+
as BoxSelectCertFuture;
105+
106+
match box_fut.as_mut().poll(cx) {
107+
Poll::Ready(fut_result) => fut_result,
108+
Poll::Pending => {
109+
client_hello
110+
.ssl_mut()
111+
.set_ex_data(*SELECT_CERT_FUTURE_INDEX, box_fut);
112+
113+
return Poll::Pending;
114+
}
115+
}
116+
}
117+
};
118+
119+
// NOTE(nox): For memory usage concerns, maybe we should implement
120+
// a way to remove the stored future from the `Ssl` value here?
121+
122+
Poll::Ready(fut_result?(client_hello))
123+
})
124+
}
125+
126+
fn polling_select_certificate_callback(
127+
callback: impl Fn(ClientHello<'_>, &mut Context<'_>) -> Poll<Result<(), AsyncSelectCertError>>
128+
+ Send
129+
+ Sync
130+
+ 'static,
131+
) -> impl Fn(ClientHello<'_>) -> Result<(), ssl::SelectCertError> + Send + Sync + 'static {
132+
move |client_hello| {
133+
let cx = unsafe {
134+
&mut *(*client_hello
135+
.ssl()
136+
.ex_data(*TASK_CONTEXT_INDEX)
137+
.expect("task context should be set") as *mut Context<'_>)
138+
};
139+
140+
match callback(client_hello, cx) {
141+
Poll::Ready(Ok(())) => Ok(()),
142+
Poll::Ready(Err(AsyncSelectCertError)) => Err(ssl::SelectCertError::ERROR),
143+
Poll::Pending => Err(ssl::SelectCertError::RETRY),
144+
}
145+
}
146+
}
147+
148+
mod private {
149+
pub trait Sealed {}
150+
}
151+
152+
impl private::Sealed for SslContextBuilder {}

tokio-boring/src/lib.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,22 @@ use std::pin::Pin;
2626
use std::task::{Context, Poll};
2727
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2828

29+
mod async_callbacks;
30+
31+
use self::async_callbacks::TASK_CONTEXT_INDEX;
32+
pub use self::async_callbacks::{AsyncSelectCertError, SslContextBuilderExt};
33+
2934
/// Asynchronously performs a client-side TLS handshake over the provided stream.
3035
pub async fn connect<S>(
31-
config: ConnectConfiguration,
36+
mut config: ConnectConfiguration,
3237
domain: &str,
3338
stream: S,
3439
) -> Result<SslStream<S>, HandshakeError<S>>
3540
where
3641
S: AsyncRead + AsyncWrite + Unpin,
3742
{
43+
config.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);
44+
3845
handshake(|s| config.connect(domain, s), stream).await
3946
}
4047

@@ -43,7 +50,13 @@ pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>
4350
where
4451
S: AsyncRead + AsyncWrite + Unpin,
4552
{
46-
handshake(|s| acceptor.accept(s), stream).await
53+
let mut ssl = acceptor
54+
.new_session()
55+
.map_err(|e| HandshakeError(e.into()))?;
56+
57+
ssl.set_ex_data(*TASK_CONTEXT_INDEX, 0);
58+
59+
handshake(|s| ssl.accept(s), stream).await
4760
}
4861

4962
async fn handshake<F, S>(f: F, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
@@ -163,6 +176,11 @@ impl<S> SslStream<S> {
163176
self.0.ssl()
164177
}
165178

179+
/// Returns a mutable reference to the `Ssl` object associated with this stream.
180+
pub fn ssl_mut(&mut self) -> &mut SslRef {
181+
self.0.ssl_mut()
182+
}
183+
166184
/// Returns a shared reference to the underlying stream.
167185
pub fn get_ref(&self) -> &S {
168186
&self.0.get_ref().stream
@@ -367,13 +385,18 @@ where
367385
stream: inner.stream,
368386
context: ctx as *mut _ as usize,
369387
};
388+
370389
match (inner.f)(stream) {
371390
Ok(mut s) => {
372391
s.get_mut().context = 0;
392+
s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);
393+
373394
Poll::Ready(Ok(StartedHandshake::Done(SslStream(s))))
374395
}
375396
Err(ssl::HandshakeError::WouldBlock(mut s)) => {
376397
s.get_mut().context = 0;
398+
s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);
399+
377400
Poll::Ready(Ok(StartedHandshake::Mid(s)))
378401
}
379402
Err(e) => Poll::Ready(Err(HandshakeError(e))),
@@ -396,13 +419,20 @@ where
396419
let mut s = self.0.take().expect("future polled after completion");
397420

398421
s.get_mut().context = ctx as *mut _ as usize;
422+
s.ssl_mut()
423+
.set_ex_data(*TASK_CONTEXT_INDEX, ctx as *mut _ as usize);
424+
399425
match s.handshake() {
400426
Ok(mut s) => {
401427
s.get_mut().context = 0;
428+
s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);
429+
402430
Poll::Ready(Ok(SslStream(s)))
403431
}
404432
Err(ssl::HandshakeError::WouldBlock(mut s)) => {
405433
s.get_mut().context = 0;
434+
s.ssl_mut().set_ex_data(*TASK_CONTEXT_INDEX, 0);
435+
406436
self.0 = Some(s);
407437
Poll::Pending
408438
}

0 commit comments

Comments
 (0)