diff --git a/tokio/src/process/unix/orphan.rs b/tokio/src/process/unix/orphan.rs index a89555f5876..72c1e0c30d4 100644 --- a/tokio/src/process/unix/orphan.rs +++ b/tokio/src/process/unix/orphan.rs @@ -295,7 +295,7 @@ pub(crate) mod test { #[cfg_attr(miri, ignore)] // Miri does not support epoll. #[test] fn does_not_register_signal_if_queue_empty() { - let (io_driver, io_handle) = IoDriver::new(1024).unwrap(); + let (io_driver, io_handle) = IoDriver::new(1024, 1).unwrap(); let signal_driver = SignalDriver::new(io_driver, &io_handle).unwrap(); let handle = signal_driver.handle(); diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 68a16772abc..6f64ee03bd7 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -1463,7 +1463,7 @@ impl Builder { use crate::runtime::scheduler; use crate::runtime::Config; - let (driver, driver_handle) = driver::Driver::new(self.get_cfg())?; + let (driver, driver_handle) = driver::Driver::new(self.get_cfg(), 1)?; // Blocking pool let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads); @@ -1642,7 +1642,7 @@ cfg_rt_multi_thread! { let worker_threads = self.worker_threads.unwrap_or_else(num_cpus); - let (driver, driver_handle) = driver::Driver::new(self.get_cfg())?; + let (driver, driver_handle) = driver::Driver::new(self.get_cfg(), worker_threads)?; // Create the blocking pool let blocking_pool = diff --git a/tokio/src/runtime/driver.rs b/tokio/src/runtime/driver.rs index 7f1fe4cf4c5..ab1f6fb8bf9 100644 --- a/tokio/src/runtime/driver.rs +++ b/tokio/src/runtime/driver.rs @@ -43,8 +43,9 @@ pub(crate) struct Cfg { } impl Driver { - pub(crate) fn new(cfg: Cfg) -> io::Result<(Self, Handle)> { - let (io_stack, io_handle, signal_handle) = create_io_stack(cfg.enable_io, cfg.nevents)?; + pub(crate) fn new(cfg: Cfg, num_workers: usize) -> io::Result<(Self, Handle)> { + let (io_stack, io_handle, signal_handle) = + create_io_stack(cfg.enable_io, cfg.nevents, num_workers)?; let clock = create_clock(cfg.enable_pause_time, cfg.start_paused); @@ -136,12 +137,12 @@ cfg_io_driver_impl_or_uring! { Disabled(UnparkThread), } - fn create_io_stack(enabled: bool, nevents: usize) -> io::Result<(IoStack, IoHandle, SignalHandle)> { + fn create_io_stack(enabled: bool, nevents: usize, num_workers: usize) -> io::Result<(IoStack, IoHandle, SignalHandle)> { #[cfg(loom)] assert!(!enabled); let ret = if enabled { - let (io_driver, io_handle) = crate::runtime::io::Driver::new(nevents)?; + let (io_driver, io_handle) = crate::runtime::io::Driver::new(nevents, num_workers)?; let (signal_driver, signal_handle) = create_signal_driver(io_driver, &io_handle)?; let process_driver = create_process_driver(signal_driver); @@ -202,7 +203,7 @@ cfg_not_io_driver_impl_or_uring! { #[derive(Debug)] pub(crate) struct IoStack(ParkThread); - fn create_io_stack(_enabled: bool, _nevents: usize) -> io::Result<(IoStack, IoHandle, SignalHandle)> { + fn create_io_stack(_enabled: bool, _nevents: usize, _num_worker: usize) -> io::Result<(IoStack, IoHandle, SignalHandle)> { let park_thread = ParkThread::new(); let unpark_thread = park_thread.unpark(); Ok((IoStack(park_thread), unpark_thread, Default::default())) diff --git a/tokio/src/runtime/driver/op.rs b/tokio/src/runtime/driver/op.rs index 023a85c79fe..2bad5cfb834 100644 --- a/tokio/src/runtime/driver/op.rs +++ b/tokio/src/runtime/driver/op.rs @@ -1,4 +1,5 @@ use crate::runtime::Handle; +use crate::runtime::OpId; use io_uring::cqueue; use io_uring::squeue::Entry; use std::future::Future; @@ -36,6 +37,8 @@ pub(crate) struct Op { state: State, // Per operation data. data: Option, + + pub(crate) shard_id: usize, } impl Op { @@ -48,6 +51,7 @@ impl Op { Self { data: Some(data), state: State::Initialize(Some(entry)), + shard_id: OpId::next().as_u64() as usize, } } pub(crate) fn take_data(&mut self) -> Option { @@ -109,13 +113,13 @@ impl Future for Op { let entry = entry_opt.take().expect("Entry must be present"); let waker = cx.waker().clone(); // SAFETY: entry is valid for the entire duration of the operation - let idx = unsafe { driver.register_op(entry, waker)? }; + let idx = unsafe { driver.register_op(entry, waker, this.shard_id)? }; this.state = State::Polled(idx); Poll::Pending } State::Polled(idx) => { - let mut ctx = driver.get_uring().lock(); + let mut ctx = driver.get_uring(this.shard_id).lock(); let lifecycle = ctx.ops.get_mut(*idx).expect("Lifecycle must be present"); match mem::replace(lifecycle, Lifecycle::Submitted) { diff --git a/tokio/src/runtime/io/driver.rs b/tokio/src/runtime/io/driver.rs index 76ced77a90d..7fa23e8e7e3 100644 --- a/tokio/src/runtime/io/driver.rs +++ b/tokio/src/runtime/io/driver.rs @@ -4,7 +4,7 @@ cfg_signal_internal_and_unix! { } cfg_tokio_unstable_uring! { mod uring; - use uring::UringContext; + use uring::{UringContext}; } use crate::io::interest::Interest; @@ -56,7 +56,7 @@ pub(crate) struct Handle { feature = "fs", target_os = "linux", ))] - pub(crate) uring_context: Mutex, + pub(crate) uring_context: Box<[Mutex]>, } #[derive(Debug)] @@ -92,7 +92,32 @@ pub(super) enum Tick { const TOKEN_WAKEUP: mio::Token = mio::Token(0); const TOKEN_SIGNAL: mio::Token = mio::Token(1); cfg_tokio_unstable_uring! { - pub(crate) const TOKEN_URING: mio::Token = mio::Token(2); + // Since `ScheduledIo` is at least `repr(align(16))`, we can use the first 4 bits. + // This allows us to use 13 (= 15 - 2) values as shard id for uring operations. + // + // 0b00000 => TOKEN_WAKEUP + // 0b00001 => TOKEN_SIGNAL + // 0b00010 ~ 0b01111 => URING_TOKEN + // 0b10000 ~ => raw ScheduledIo pointers + + const URING_TOKEN_START: usize = 0b10; + const URING_TOKEN_END: usize = 0b1111; + const MAX_SHARD_SIZE: usize = URING_TOKEN_END - URING_TOKEN_START; + + pub(super) fn is_uring_token(token: mio::Token) -> bool { + URING_TOKEN_START <= token.0 && token.0 <= URING_TOKEN_END + } + + pub(super) fn get_shard_id(token: mio::Token) -> usize { + debug_assert!(is_uring_token(token), "token {token:?} is not a uring token"); + token.0.saturating_sub(URING_TOKEN_START) + } + + pub(super) fn as_uring_token(n: usize) -> mio::Token { + let token = mio::Token(n.saturating_add(URING_TOKEN_START)); + debug_assert!(is_uring_token(token), "token {token:?} is not a uring token"); + token + } } fn _assert_kinds() { @@ -106,7 +131,10 @@ fn _assert_kinds() { impl Driver { /// Creates a new event loop, returning any error that happened during the /// creation. - pub(crate) fn new(nevents: usize) -> io::Result<(Driver, Handle)> { + pub(crate) fn new( + nevents: usize, + #[allow(unused)] num_workers: usize, + ) -> io::Result<(Driver, Handle)> { let poll = mio::Poll::new()?; #[cfg(not(target_os = "wasi"))] let waker = mio::Waker::new(poll.registry(), TOKEN_WAKEUP)?; @@ -120,6 +148,14 @@ impl Driver { let (registrations, synced) = RegistrationSet::new(); + #[cfg(all( + tokio_unstable_uring, + feature = "rt", + feature = "fs", + target_os = "linux", + ))] + let num_workers = num_workers.min(MAX_SHARD_SIZE); + let handle = Handle { registry, registrations, @@ -133,7 +169,10 @@ impl Driver { feature = "fs", target_os = "linux", ))] - uring_context: Mutex::new(UringContext::new()), + uring_context: (0..num_workers) + .map(|_| Mutex::new(UringContext::new())) + .collect::>() + .into_boxed_slice(), }; #[cfg(all( @@ -143,7 +182,9 @@ impl Driver { target_os = "linux", ))] { - handle.add_uring_source(Interest::READABLE)?; + for shard_id in 0..num_workers { + handle.add_uring_source(shard_id, Interest::READABLE)?; + } } Ok((driver, handle)) @@ -207,8 +248,10 @@ impl Driver { feature = "fs", target_os = "linux", ))] - TOKEN_URING => { - let mut guard = handle.get_uring().lock(); + token if is_uring_token(token) => { + let shard_id = get_shard_id(token); + + let mut guard = handle.get_uring(shard_id).lock(); let ctx = &mut *guard; ctx.dispatch_completions(); } diff --git a/tokio/src/runtime/io/driver/uring.rs b/tokio/src/runtime/io/driver/uring.rs index ce122aab6d6..95efc4e65c9 100644 --- a/tokio/src/runtime/io/driver/uring.rs +++ b/tokio/src/runtime/io/driver/uring.rs @@ -5,7 +5,7 @@ use slab::Slab; use crate::runtime::driver::op::{Lifecycle, Op}; use crate::{io::Interest, loom::sync::Mutex}; -use super::{Handle, TOKEN_URING}; +use super::{as_uring_token, Handle}; use std::os::fd::AsRawFd; use std::{io, mem, task::Waker}; @@ -123,24 +123,32 @@ impl Drop for UringContext { impl Handle { #[allow(dead_code)] - pub(crate) fn add_uring_source(&self, interest: Interest) -> io::Result<()> { + pub(crate) fn add_uring_source(&self, shard_id: usize, interest: Interest) -> io::Result<()> { // setup for io_uring - let uringfd = self.get_uring().lock().uring.as_raw_fd(); + let mut guard = self.get_uring(shard_id).lock(); + let ctx = &mut *guard; + let uringfd = ctx.uring.as_raw_fd(); let mut source = SourceFd(&uringfd); self.registry - .register(&mut source, TOKEN_URING, interest.to_mio()) + .register(&mut source, as_uring_token(shard_id), interest.to_mio()) } - pub(crate) fn get_uring(&self) -> &Mutex { - &self.uring_context + pub(crate) fn get_uring(&self, shard_id: usize) -> &Mutex { + let shard_id = shard_id % self.uring_context.len(); + &self.uring_context[shard_id] } /// # Safety /// /// Callers must ensure that parameters of the entry (such as buffer) are valid and will /// be valid for the entire duration of the operation, otherwise it may cause memory problems. - pub(crate) unsafe fn register_op(&self, entry: Entry, waker: Waker) -> io::Result { - let mut guard = self.get_uring().lock(); + pub(crate) unsafe fn register_op( + &self, + entry: Entry, + waker: Waker, + shard_id: usize, + ) -> io::Result { + let mut guard = self.get_uring(shard_id).lock(); let ctx = &mut *guard; let index = ctx.ops.insert(Lifecycle::Waiting(waker)); let entry = entry.user_data(index as u64); @@ -167,7 +175,7 @@ impl Handle { } pub(crate) fn cancel_op(&self, op: &mut Op, index: usize) { - let mut guard = self.get_uring().lock(); + let mut guard = self.get_uring(op.shard_id).lock(); let ctx = &mut *guard; let ops = &mut ctx.ops; let Some(lifecycle) = ops.get_mut(index) else { diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index bd212ac09d4..de039920efc 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -421,3 +421,8 @@ cfg_rt! { /// After thread starts / before thread stops type Callback = std::sync::Arc; } + +cfg_tokio_unstable_uring! { + pub(crate) mod op_id; + pub(crate) use op_id::OpId; +} diff --git a/tokio/src/runtime/op_id.rs b/tokio/src/runtime/op_id.rs new file mode 100644 index 00000000000..c5d153ca5c1 --- /dev/null +++ b/tokio/src/runtime/op_id.rs @@ -0,0 +1,32 @@ +// TODO: Put together with other id related utils. + +use std::num::NonZeroU64; + +#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)] +pub(crate) struct OpId(NonZeroU64); + +impl OpId { + pub(crate) fn next() -> Self { + use crate::loom::sync::atomic::Ordering::Relaxed; + use crate::loom::sync::atomic::StaticAtomicU64; + + #[cfg(all(test, loom))] + crate::loom::lazy_static! { + static ref NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(1); + } + + #[cfg(not(all(test, loom)))] + static NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(1); + + loop { + let id = NEXT_ID.fetch_add(1, Relaxed); + if let Some(id) = NonZeroU64::new(id) { + return Self(id); + } + } + } + + pub(crate) fn as_u64(&self) -> u64 { + self.0.get() + } +} diff --git a/tokio/tests/fs_uring.rs b/tokio/tests/fs_uring.rs index 4fb1a4fba07..295dd839e43 100644 --- a/tokio/tests/fs_uring.rs +++ b/tokio/tests/fs_uring.rs @@ -31,7 +31,7 @@ fn current_rt() -> Box Runtime> { #[test] fn all_tests() { - let rt_conbination = vec![current_rt(), multi_rt(1), multi_rt(8)]; + let rt_conbination = vec![current_rt(), multi_rt(1), multi_rt(8), multi_rt(16)]; for rt in rt_conbination { shutdown_runtime_while_performing_io_uring_ops(rt());