diff --git a/src/lib.rs b/src/lib.rs index 05c1b51..9fdad25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -130,6 +130,7 @@ use std::future::Future; use std::io; use std::pin::Pin; +use std::sync::Mutex; use std::task::{Context, Poll}; use std::thread; @@ -453,19 +454,43 @@ impl tokio::io::AsyncSeek for Compat { } } -fn get_runtime_handle() -> tokio::runtime::Handle { +/// Return a handle to the current runtime, or the fallback runtime, if any. +pub fn get_runtime_handle() -> tokio::runtime::Handle { tokio::runtime::Handle::try_current().unwrap_or_else(|_| TOKIO1.handle().clone()) } +/// Provide a custom tokio runtime builder for the fallback runtime. +/// +/// If this is set *before* the first use of the compatibility adapter, the fallback runtime will +/// be created with the function provided in this closure. This has no effect if a fallback runtime +/// has already been created. +pub fn set_runtime_builder( + builder: Box tokio::runtime::Builder + Send + Sync + 'static>, +) { + let mut guard = TOKIO1_RUNTIME_BUILDER.lock().unwrap(); + *guard = Some(builder); +} + +#[allow(clippy::type_complexity)] +static TOKIO1_RUNTIME_BUILDER: Lazy< + Mutex tokio::runtime::Builder + Send + Sync + 'static>>>, +> = Lazy::new(|| Mutex::new(None)); + static TOKIO1: Lazy = Lazy::new(|| { + // Keep the runtime alive. thread::Builder::new() .name("async-compat/tokio-1".into()) .spawn(|| TOKIO1.block_on(Pending)) .unwrap(); - tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("cannot start tokio-1 runtime") + + if let Some(builder) = TOKIO1_RUNTIME_BUILDER.lock().unwrap().take() { + builder().build().expect("cannot start tokio-1 runtime") + } else { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("cannot start tokio-1 runtime") + } }); struct Pending; @@ -480,8 +505,10 @@ impl Future for Pending { #[cfg(test)] mod tests { + use std::sync::{Arc, Mutex}; + use super::Lazy; - use crate::{CompatExt, TOKIO1}; + use crate::{set_runtime_builder, CompatExt, TOKIO1}; #[test] fn fallback_runtime_is_created_if_and_only_if_outside_tokio_context() { @@ -504,6 +531,41 @@ mod tests { assert!(Lazy::get(&TOKIO1).is_some()); } + #[test] + fn fallback_runtime_is_created_with_custom_builder() { + let custom_called = Arc::new(Mutex::new(false)); + + let custom_called_clone = custom_called.clone(); + + set_runtime_builder(Box::new(move || { + *custom_called_clone.lock().unwrap() = true; + let mut builder = tokio::runtime::Builder::new_multi_thread(); + builder.enable_all(); + builder + })); + + // Use compat inside of a tokio context. + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(use_tokio().compat()); + + // We didn't need to create the fallback runtime, because we used compat + // inside of an existing tokio context. + assert!(Lazy::get(&TOKIO1).is_none()); + + // Use compat outside of a tokio context. + futures::executor::block_on(use_tokio().compat()); + + // We must have created the fallback runtime, because we used compat + // outside of a tokio context. + assert!(Lazy::get(&TOKIO1).is_some()); + + // And we've used the custom runtime builder for this. + assert!(*custom_called.lock().unwrap()); + } + async fn use_tokio() { tokio::time::sleep(std::time::Duration::from_micros(1)).await }