Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions tokio/src/runtime/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ impl Builder {
#[cfg_attr(docsrs, doc(cfg(tokio_unstable)))]
pub fn build_local(&mut self, options: LocalOptions) -> io::Result<LocalRuntime> {
match &self.kind {
Kind::CurrentThread => self.build_current_thread_local_runtime(),
Kind::CurrentThread => self.build_current_thread_local_runtime(options),
#[cfg(feature = "rt-multi-thread")]
Kind::MultiThread => panic!("multi_thread is not supported for LocalRuntime"),
}
Expand Down Expand Up @@ -1439,11 +1439,16 @@ impl Builder {
}

#[cfg(tokio_unstable)]
fn build_current_thread_local_runtime(&mut self) -> io::Result<LocalRuntime> {
fn build_current_thread_local_runtime(
&mut self,
opts: LocalOptions,
) -> io::Result<LocalRuntime> {
use crate::runtime::local_runtime::LocalRuntimeScheduler;

let tid = std::thread::current().id();

self.before_park = opts.before_park;
self.after_unpark = opts.after_unpark;
let (scheduler, handle, blocking_pool) =
self.build_current_thread_runtime_components(Some(tid))?;

Expand Down
149 changes: 145 additions & 4 deletions tokio/src/runtime/local_runtime/options.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,159 @@
use std::marker::PhantomData;

use crate::runtime::Callback;

/// [`LocalRuntime`]-only config options
///
/// Currently, there are no such options, but in the future, things like `!Send + !Sync` hooks may
/// be added.
///
/// Use `LocalOptions::default()` to create the default set of options. This type is used with
/// [`Builder::build_local`].
///
/// When using [`Builder::build_local`], this overrides any pre-configured options set on the
/// [`Builder`].
///
/// [`Builder::build_local`]: crate::runtime::Builder::build_local
/// [`LocalRuntime`]: crate::runtime::LocalRuntime
#[derive(Default, Debug)]
/// [`Builder`]: crate::runtime::Builder
#[derive(Default)]
#[non_exhaustive]
#[allow(missing_debug_implementations)]
pub struct LocalOptions {
/// Marker used to make this !Send and !Sync.
_phantom: PhantomData<*mut u8>,

/// To run before the local runtime is parked.
pub(crate) before_park: Option<Callback>,

/// To run before the local runtime is spawned.
pub(crate) after_unpark: Option<Callback>,
}

impl std::fmt::Debug for LocalOptions {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LocalOptions")
.field("before_park", &self.before_park.as_ref().map(|_| "..."))
.field("after_unpark", &self.after_unpark.as_ref().map(|_| "..."))
.finish()
}
}

impl LocalOptions {
/// Executes function `f` just before the local runtime is parked (goes idle).
/// `f` is called within the Tokio context, so functions like [`tokio::spawn`](crate::spawn)
/// can be called, and may result in this thread being unparked immediately.
///
/// This can be used to start work only when the executor is idle, or for bookkeeping
/// and monitoring purposes.
///
/// This differs from the [`Builder::on_thread_park`] method in that it accepts a non Send + Sync
/// closure.
///
/// Note: There can only be one park callback for a runtime; calling this function
/// more than once replaces the last callback defined, rather than adding to it.
///
/// # Examples
///
/// ```
/// # use tokio::runtime::{Builder, LocalOptions};
/// # pub fn main() {
/// let (tx, rx) = std::sync::mpsc::channel();
/// let mut opts = LocalOptions::default();
/// opts.on_thread_park(move || match rx.recv() {
/// Ok(x) => println!("Received from channel: {}", x),
/// Err(e) => println!("Error receiving from channel: {}", e),
/// });
///
/// let runtime = Builder::new_current_thread()
/// .enable_time()
/// .build_local(opts)
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn_local(async move {
/// tx.send(42).unwrap();
/// });
/// tokio::time::sleep(std::time::Duration::from_millis(1)).await;
/// })
/// # }
/// ```
///
/// [`Builder`]: crate::runtime::Builder
/// [`Builder::on_thread_park`]: crate::runtime::Builder::on_thread_park
pub fn on_thread_park<F>(&mut self, f: F) -> &mut Self
where
F: Fn() + 'static,
{
self.before_park = Some(std::sync::Arc::new(to_send_sync(f)));
self
}

/// Executes function `f` just after the local runtime unparks (starts executing tasks).
///
/// This is intended for bookkeeping and monitoring use cases; note that work
/// in this callback will increase latencies when the application has allowed one or
/// more runtime threads to go idle.
///
/// This differs from the [`Builder::on_thread_unpark`] method in that it accepts a non Send + Sync
/// closure.
///
/// Note: There can only be one unpark callback for a runtime; calling this function
/// more than once replaces the last callback defined, rather than adding to it.
///
/// # Examples
///
/// ```
/// # use tokio::runtime::{Builder, LocalOptions};
/// # pub fn main() {
/// let (tx, rx) = std::sync::mpsc::channel();
/// let mut opts = LocalOptions::default();
/// opts.on_thread_unpark(move || match rx.recv() {
/// Ok(x) => println!("Received from channel: {}", x),
/// Err(e) => println!("Error receiving from channel: {}", e),
/// });
///
/// let runtime = Builder::new_current_thread()
/// .enable_time()
/// .build_local(opts)
/// .unwrap();
///
/// runtime.block_on(async {
/// tokio::task::spawn_local(async move {
/// tx.send(42).unwrap();
/// });
/// tokio::time::sleep(std::time::Duration::from_millis(1)).await;
/// })
/// # }
/// ```
///
/// [`Builder`]: crate::runtime::Builder
/// [`Builder::on_thread_unpark`]: crate::runtime::Builder::on_thread_unpark
pub fn on_thread_unpark<F>(&mut self, f: F) -> &mut Self
where
F: Fn() + 'static,
{
self.after_unpark = Some(std::sync::Arc::new(to_send_sync(f)));
self
}
}

// A wrapper type to allow non-Send + Sync closures to be used in a Send + Sync context.
// This is specifically used for executing callbacks when using a `LocalRuntime`.
struct UnsafeSendSync<T>(T);

// SAFETY: This type is only used in a context where it is guaranteed that the closure will not be
// sent across threads.
unsafe impl<T> Send for UnsafeSendSync<T> {}
unsafe impl<T> Sync for UnsafeSendSync<T> {}

impl<T: Fn()> UnsafeSendSync<T> {
fn call(&self) {
(self.0)()
}
}

fn to_send_sync<F>(f: F) -> impl Fn() + Send + Sync
where
F: Fn(),
{
let f = UnsafeSendSync(f);
move || f.call()
}
50 changes: 42 additions & 8 deletions tokio/tests/rt_local.rs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a test to make sure callbacks will not be executed when using the Handle::block_on.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaiu, we can only test this for on_thread_unpark. we can't test this for on_thread_park because if the task parks itself inside Handle::block_on, there's nothing that can call unpark() for the task.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tokio::task::spawn_local;

#[test]
fn test_spawn_local_in_runtime() {
let rt = rt();
let rt = rt(LocalOptions::default());

let res = rt.block_on(async move {
let (tx, rx) = tokio::sync::oneshot::channel();
Expand All @@ -22,9 +22,43 @@ fn test_spawn_local_in_runtime() {
assert_eq!(res, 5);
}

#[test]
fn test_on_thread_park_unpark_in_runtime() {
let mut opts = LocalOptions::default();

// the refcell makes the below callbacks `!Send + !Sync`
let on_park_called = std::rc::Rc::new(std::cell::RefCell::new(false));
let on_park_cc = on_park_called.clone();
opts.on_thread_park(move || {
*on_park_cc.borrow_mut() = true;
});

let on_unpark_called = std::rc::Rc::new(std::cell::RefCell::new(false));
let on_unpark_cc = on_unpark_called.clone();
opts.on_thread_unpark(move || {
*on_unpark_cc.borrow_mut() = true;
});
let rt = rt(opts);

rt.block_on(async move {
let (tx, rx) = tokio::sync::oneshot::channel();

spawn_local(async {
tokio::task::yield_now().await;
tx.send(5).unwrap();
});

// this ensures on_thread_park is called
rx.await.unwrap()
});

assert!(*on_park_called.borrow());
assert!(*on_unpark_called.borrow());
}

#[test]
fn test_spawn_from_handle() {
let rt = rt();
let rt = rt(LocalOptions::default());

let (tx, rx) = tokio::sync::oneshot::channel();

Expand All @@ -40,7 +74,7 @@ fn test_spawn_from_handle() {

#[test]
fn test_spawn_local_on_runtime_object() {
let rt = rt();
let rt = rt(LocalOptions::default());

let (tx, rx) = tokio::sync::oneshot::channel();

Expand All @@ -56,7 +90,7 @@ fn test_spawn_local_on_runtime_object() {

#[test]
fn test_spawn_local_from_guard() {
let rt = rt();
let rt = rt(LocalOptions::default());

let (tx, rx) = tokio::sync::oneshot::channel();

Expand All @@ -78,7 +112,7 @@ fn test_spawn_from_guard_other_thread() {
let (tx, rx) = std::sync::mpsc::channel();

std::thread::spawn(move || {
let rt = rt();
let rt = rt(LocalOptions::default());
let handle = rt.handle().clone();

tx.send(handle).unwrap();
Expand All @@ -98,7 +132,7 @@ fn test_spawn_local_from_guard_other_thread() {
let (tx, rx) = std::sync::mpsc::channel();

std::thread::spawn(move || {
let rt = rt();
let rt = rt(LocalOptions::default());
let handle = rt.handle().clone();

tx.send(handle).unwrap();
Expand All @@ -111,9 +145,9 @@ fn test_spawn_local_from_guard_other_thread() {
spawn_local(async {});
}

fn rt() -> tokio::runtime::LocalRuntime {
fn rt(opts: LocalOptions) -> tokio::runtime::LocalRuntime {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build_local(LocalOptions::default())
.build_local(opts)
.unwrap()
}
Loading