Skip to content

Commit 0305ec6

Browse files
committed
Introduce async callbacks
We introduce tokio_boring::SslContextBuilderExt, with 2 methods: * set_async_select_certificate_callback * set_async_private_key_method
1 parent 887f6fd commit 0305ec6

File tree

9 files changed

+586
-11
lines changed

9 files changed

+586
-11
lines changed

boring/src/ssl/mod.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t);
482482
impl SelectCertError {
483483
/// A fatal error occured and the handshake should be terminated.
484484
pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error);
485+
486+
/// The operation could not be completed and should be retried later.
487+
pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry);
485488
}
486489

487490
/// Extension types, to be used with `ClientHello::get_extension`.
@@ -2486,7 +2489,7 @@ impl SslRef {
24862489
}
24872490

24882491
#[cfg(feature = "kx-safe-default")]
2489-
fn client_set_default_curves_list(&mut self) -> Result<(), ErrorStack> {
2492+
fn client_set_default_curves_list(&mut self) {
24902493
let curves = if cfg!(feature = "kx-client-pq-preferred") {
24912494
if cfg!(feature = "kx-client-nist-required") {
24922495
"P256Kyber768Draft00:P-256:P-384"
@@ -2508,11 +2511,13 @@ impl SslRef {
25082511
};
25092512

25102513
self.set_curves_list(curves)
2514+
.expect("invalid default client curves list")
25112515
}
25122516

25132517
#[cfg(feature = "kx-safe-default")]
2514-
fn server_set_default_curves_list(&mut self) -> Result<(), ErrorStack> {
2518+
fn server_set_default_curves_list(&mut self) {
25152519
self.set_curves_list("X25519Kyber768Draft00:P256Kyber768Draft00:X25519:P-256:P-384")
2520+
.expect("invalid default server curves list")
25162521
}
25172522

25182523
/// Like [`SslContextBuilder::set_verify`].
@@ -3260,6 +3265,11 @@ impl<S> MidHandshakeSslStream<S> {
32603265
self.stream.ssl()
32613266
}
32623267

3268+
/// Returns a mutable reference to the `Ssl` of the stream.
3269+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3270+
self.stream.ssl_mut()
3271+
}
3272+
32633273
/// Returns the underlying error which interrupted this handshake.
32643274
pub fn error(&self) -> &Error {
32653275
&self.error
@@ -3514,6 +3524,11 @@ impl<S> SslStream<S> {
35143524
pub fn ssl(&self) -> &SslRef {
35153525
&self.ssl
35163526
}
3527+
3528+
/// Returns a mutable reference to the `Ssl` object associated with this stream.
3529+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3530+
&mut self.ssl
3531+
}
35173532
}
35183533

35193534
impl<S: Read + Write> Read for SslStream<S> {
@@ -3599,7 +3614,7 @@ where
35993614
self.set_connect_state();
36003615

36013616
#[cfg(feature = "kx-safe-default")]
3602-
stream.ssl.client_set_default_curves_list()?;
3617+
self.inner.client_set_default_curves_list();
36033618

36043619
MidHandshakeSslStream {
36053620
stream: self.inner,
@@ -3630,7 +3645,7 @@ where
36303645
self.set_accept_state();
36313646

36323647
#[cfg(feature = "kx-safe-default")]
3633-
stream.ssl.server_set_default_curves_list()?;
3648+
self.inner.server_set_default_curves_list();
36343649

36353650
MidHandshakeSslStream {
36363651
stream: self.inner,

boring/src/ssl/test/mod.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,8 +1120,7 @@ fn client_set_default_curves_list() {
11201120
let ssl_ctx = SslContextBuilder::new(SslMethod::tls()).unwrap().build();
11211121
let mut ssl = Ssl::new(&ssl_ctx).unwrap();
11221122

1123-
ssl.client_set_default_curves_list()
1124-
.expect("Failed to set curves list. Is Kyber768 missing in boringSSL?")
1123+
ssl.client_set_default_curves_list();
11251124
}
11261125

11271126
#[cfg(feature = "kx-safe-default")]
@@ -1130,6 +1129,5 @@ fn server_set_default_curves_list() {
11301129
let ssl_ctx = SslContextBuilder::new(SslMethod::tls()).unwrap().build();
11311130
let mut ssl = Ssl::new(&ssl_ctx).unwrap();
11321131

1133-
ssl.server_set_default_curves_list()
1134-
.expect("Failed to set curves list. Is Kyber768 missing in boringSSL?")
1132+
ssl.server_set_default_curves_list();
11351133
}

boring/src/ssl/test/private_key_method.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ fn test_sign_retry_complete_failure() {
189189
ErrorCode::WANT_PRIVATE_KEY_OPERATION
190190
);
191191

192-
let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err() else {
192+
let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err()
193+
else {
193194
panic!("should be WouldBlock");
194195
};
195196

tokio-boring/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ no-patches = ["boring/no-patches"]
3939
[dependencies]
4040
boring = { workspace = true }
4141
boring-sys = { workspace = true }
42+
once_cell = { workspace = true }
4243
tokio = { workspace = true }
4344

4445
[dev-dependencies]
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
use boring::ex_data::Index;
2+
use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder};
3+
use once_cell::sync::Lazy;
4+
use std::future::Future;
5+
use std::pin::Pin;
6+
use std::task::{ready, Context, Poll, Waker};
7+
8+
type BoxSelectCertFuture = ExDataFuture<Result<BoxSelectCertFinish, AsyncSelectCertError>>;
9+
10+
type BoxSelectCertFinish = Box<dyn FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError>>;
11+
12+
/// The type of futures returned by [`AsyncPrivateKeyMethod`] methods.
13+
pub type BoxPrivateKeyMethodFuture =
14+
ExDataFuture<Result<BoxPrivateKeyMethodFinish, AsyncPrivateKeyMethodError>>;
15+
16+
/// The type of callbacks returned by [`BoxPrivateKeyMethodFuture`].
17+
pub type BoxPrivateKeyMethodFinish =
18+
Box<dyn FnOnce(&mut ssl::SslRef, &mut [u8]) -> Result<usize, AsyncPrivateKeyMethodError>>;
19+
20+
type ExDataFuture<T> = Pin<Box<dyn Future<Output = T> + Send + Sync>>;
21+
22+
pub(crate) static TASK_WAKER_INDEX: Lazy<Index<Ssl, Option<Waker>>> =
23+
Lazy::new(|| Ssl::new_ex_index().unwrap());
24+
pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, BoxSelectCertFuture>> =
25+
Lazy::new(|| Ssl::new_ex_index().unwrap());
26+
pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy<
27+
Index<Ssl, BoxPrivateKeyMethodFuture>,
28+
> = Lazy::new(|| Ssl::new_ex_index().unwrap());
29+
30+
/// Extensions to [`SslContextBuilder`].
31+
///
32+
/// This trait provides additional methods to use async callbacks with boring.
33+
pub trait SslContextBuilderExt: private::Sealed {
34+
/// Sets a callback that is called before most [`ClientHello`] processing
35+
/// and before the decision whether to resume a session is made. The
36+
/// callback may inspect the [`ClientHello`] and configure the connection.
37+
///
38+
/// This method uses a function that returns a future whose output is
39+
/// itself a closure that will be passed [`ClientHello`] to configure
40+
/// the connection based on the computations done in the future.
41+
///
42+
/// See [`SslContextBuilder::set_select_certificate_callback`] for the sync
43+
/// setter of this callback.
44+
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
45+
where
46+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
47+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
48+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static;
49+
50+
/// Configures a custom private key method on the context.
51+
///
52+
/// See [`AsyncPrivateKeyMethod`] for more details.
53+
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod);
54+
}
55+
56+
impl SslContextBuilderExt for SslContextBuilder {
57+
fn set_async_select_certificate_callback<Init, Fut, Finish>(&mut self, callback: Init)
58+
where
59+
Init: Fn(&mut ClientHello<'_>) -> Result<Fut, AsyncSelectCertError> + Send + Sync + 'static,
60+
Fut: Future<Output = Result<Finish, AsyncSelectCertError>> + Send + Sync + 'static,
61+
Finish: FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError> + 'static,
62+
{
63+
self.set_select_certificate_callback(move |mut client_hello| {
64+
let fut_poll_result = with_ex_data_future(
65+
&mut client_hello,
66+
*SELECT_CERT_FUTURE_INDEX,
67+
ClientHello::ssl_mut,
68+
|client_hello| {
69+
let fut = callback(client_hello)?;
70+
71+
Ok(Box::pin(async move {
72+
Ok(Box::new(fut.await?) as BoxSelectCertFinish)
73+
}))
74+
},
75+
);
76+
77+
let fut_result = match fut_poll_result {
78+
Poll::Ready(fut_result) => fut_result,
79+
Poll::Pending => return Err(ssl::SelectCertError::RETRY),
80+
};
81+
82+
let finish = fut_result.or(Err(ssl::SelectCertError::ERROR))?;
83+
84+
finish(client_hello).or(Err(ssl::SelectCertError::ERROR))
85+
})
86+
}
87+
88+
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) {
89+
self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method)));
90+
}
91+
}
92+
93+
/// A fatal error to be returned from async select certificate callbacks.
94+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
95+
pub struct AsyncSelectCertError;
96+
97+
/// Describes async private key hooks. This is used to off-load signing
98+
/// operations to a custom, potentially asynchronous, backend. Metadata about the
99+
/// key such as the type and size are parsed out of the certificate.
100+
///
101+
/// See [`PrivateKeyMethod`] for the sync version of those hooks.
102+
///
103+
/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st
104+
pub trait AsyncPrivateKeyMethod: Send + Sync + 'static {
105+
/// Signs the message `input` using the specified signature algorithm.
106+
///
107+
/// This method uses a function that returns a future whose output is
108+
/// itself a closure that will be passed `ssl` and `output`
109+
/// to finish writing the signature.
110+
///
111+
/// See [`PrivateKeyMethod::sign`] for the sync version of this method.
112+
fn sign(
113+
&self,
114+
ssl: &mut ssl::SslRef,
115+
input: &[u8],
116+
signature_algorithm: ssl::SslSignatureAlgorithm,
117+
output: &mut [u8],
118+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;
119+
120+
/// Decrypts `input`.
121+
///
122+
/// This method uses a function that returns a future whose output is
123+
/// itself a closure that will be passed `ssl` and `output`
124+
/// to finish decrypting the input.
125+
///
126+
/// See [`PrivateKeyMethod::decrypt`] for the sync version of this method.
127+
fn decrypt(
128+
&self,
129+
ssl: &mut ssl::SslRef,
130+
input: &[u8],
131+
output: &mut [u8],
132+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;
133+
}
134+
135+
/// A fatal error to be returned from async private key methods.
136+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
137+
pub struct AsyncPrivateKeyMethodError;
138+
139+
struct AsyncPrivateKeyMethodBridge(Box<dyn AsyncPrivateKeyMethod>);
140+
141+
impl PrivateKeyMethod for AsyncPrivateKeyMethodBridge {
142+
fn sign(
143+
&self,
144+
ssl: &mut ssl::SslRef,
145+
input: &[u8],
146+
signature_algorithm: ssl::SslSignatureAlgorithm,
147+
output: &mut [u8],
148+
) -> Result<usize, ssl::PrivateKeyMethodError> {
149+
with_private_key_method(ssl, output, |ssl, output| {
150+
<dyn AsyncPrivateKeyMethod>::sign(&*self.0, ssl, input, signature_algorithm, output)
151+
})
152+
}
153+
154+
fn decrypt(
155+
&self,
156+
ssl: &mut ssl::SslRef,
157+
input: &[u8],
158+
output: &mut [u8],
159+
) -> Result<usize, ssl::PrivateKeyMethodError> {
160+
with_private_key_method(ssl, output, |ssl, output| {
161+
<dyn AsyncPrivateKeyMethod>::decrypt(&*self.0, ssl, input, output)
162+
})
163+
}
164+
165+
fn complete(
166+
&self,
167+
ssl: &mut ssl::SslRef,
168+
output: &mut [u8],
169+
) -> Result<usize, ssl::PrivateKeyMethodError> {
170+
with_private_key_method(ssl, output, |_, _| {
171+
// This should never be reached, if it does, that's a bug on boring's side,
172+
// which called `complete` without having been returned to with a pending
173+
// future from `sign` or `decrypt`.
174+
175+
if cfg!(debug_assertions) {
176+
panic!("BUG: boring called complete without a pending operation");
177+
}
178+
179+
Err(AsyncPrivateKeyMethodError)
180+
})
181+
}
182+
}
183+
184+
/// Creates and drives a private key method future.
185+
///
186+
/// This is a convenience function for the three methods of impl `PrivateKeyMethod``
187+
/// for `dyn AsyncPrivateKeyMethod`. It relies on [`with_ex_data_future`] to
188+
/// drive the future and then immediately calls the final [`BoxPrivateKeyMethodFinish`]
189+
/// when the future is ready.
190+
fn with_private_key_method(
191+
ssl: &mut ssl::SslRef,
192+
output: &mut [u8],
193+
create_fut: impl FnOnce(
194+
&mut ssl::SslRef,
195+
&mut [u8],
196+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>,
197+
) -> Result<usize, ssl::PrivateKeyMethodError> {
198+
let fut_poll_result = with_ex_data_future(
199+
ssl,
200+
*SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX,
201+
|ssl| ssl,
202+
|ssl| create_fut(ssl, output),
203+
);
204+
205+
let fut_result = match fut_poll_result {
206+
Poll::Ready(fut_result) => fut_result,
207+
Poll::Pending => return Err(ssl::PrivateKeyMethodError::RETRY),
208+
};
209+
210+
let finish = fut_result.or(Err(ssl::PrivateKeyMethodError::FAILURE))?;
211+
212+
finish(ssl, output).or(Err(ssl::PrivateKeyMethodError::FAILURE))
213+
}
214+
215+
/// Creates and drives a future stored in `ssl_handle`'s `Ssl` at ex data index `index`.
216+
///
217+
/// This function won't even bother storing the future in `index` if the future
218+
/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call.
219+
fn with_ex_data_future<H, T, E>(
220+
ssl_handle: &mut H,
221+
index: Index<ssl::Ssl, ExDataFuture<Result<T, E>>>,
222+
get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef,
223+
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<Result<T, E>>, E>,
224+
) -> Poll<Result<T, E>> {
225+
let ssl = get_ssl_mut(ssl_handle);
226+
let waker = ssl
227+
.ex_data(*TASK_WAKER_INDEX)
228+
.cloned()
229+
.flatten()
230+
.expect("task waker should be set");
231+
232+
let mut ctx = Context::from_waker(&waker);
233+
234+
match ssl.ex_data_mut(index) {
235+
Some(fut) => {
236+
let fut_result = ready!(fut.as_mut().poll(&mut ctx));
237+
238+
// NOTE(nox): For memory usage concerns, maybe we should implement
239+
// a way to remove the stored future from the `Ssl` value here?
240+
241+
Poll::Ready(fut_result)
242+
}
243+
None => {
244+
let mut fut = create_fut(ssl_handle)?;
245+
246+
match fut.as_mut().poll(&mut ctx) {
247+
Poll::Ready(fut_result) => Poll::Ready(fut_result),
248+
Poll::Pending => {
249+
get_ssl_mut(ssl_handle).set_ex_data(index, fut);
250+
251+
Poll::Pending
252+
}
253+
}
254+
}
255+
}
256+
}
257+
258+
mod private {
259+
pub trait Sealed {}
260+
}
261+
262+
impl private::Sealed for SslContextBuilder {}

tokio-boring/src/bridge.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
//! Bridge between sync IO traits and async tokio IO traits.
2-
32
use std::fmt;
43
use std::io;
54
use std::pin::Pin;
@@ -35,7 +34,7 @@ impl<S> AsyncStreamBridge<S> {
3534
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
3635
{
3736
let mut ctx =
38-
Context::from_waker(self.waker.as_ref().expect("missing task context pointer"));
37+
Context::from_waker(self.waker.as_ref().expect("BUG: missing waker in bridge"));
3938

4039
f(&mut ctx, Pin::new(&mut self.stream))
4140
}

0 commit comments

Comments
 (0)