Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tokio/src/process/unix/orphan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ pub(crate) mod test {
#[cfg_attr(miri, ignore)] // Miri does not support epoll.
#[test]
fn does_not_register_signal_if_queue_empty() {
let (io_driver, io_handle) = IoDriver::new(1024).unwrap();
let (io_driver, io_handle) = IoDriver::new(1024, 1).unwrap();
let signal_driver = SignalDriver::new(io_driver, &io_handle).unwrap();
let handle = signal_driver.handle();

Expand Down
4 changes: 2 additions & 2 deletions tokio/src/runtime/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1463,7 +1463,7 @@ impl Builder {
use crate::runtime::scheduler;
use crate::runtime::Config;

let (driver, driver_handle) = driver::Driver::new(self.get_cfg())?;
let (driver, driver_handle) = driver::Driver::new(self.get_cfg(), 1)?;

// Blocking pool
let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads);
Expand Down Expand Up @@ -1642,7 +1642,7 @@ cfg_rt_multi_thread! {

let worker_threads = self.worker_threads.unwrap_or_else(num_cpus);

let (driver, driver_handle) = driver::Driver::new(self.get_cfg())?;
let (driver, driver_handle) = driver::Driver::new(self.get_cfg(), worker_threads)?;

// Create the blocking pool
let blocking_pool =
Expand Down
11 changes: 6 additions & 5 deletions tokio/src/runtime/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ pub(crate) struct Cfg {
}

impl Driver {
pub(crate) fn new(cfg: Cfg) -> io::Result<(Self, Handle)> {
let (io_stack, io_handle, signal_handle) = create_io_stack(cfg.enable_io, cfg.nevents)?;
pub(crate) fn new(cfg: Cfg, num_workers: usize) -> io::Result<(Self, Handle)> {
let (io_stack, io_handle, signal_handle) =
create_io_stack(cfg.enable_io, cfg.nevents, num_workers)?;

let clock = create_clock(cfg.enable_pause_time, cfg.start_paused);

Expand Down Expand Up @@ -136,12 +137,12 @@ cfg_io_driver_impl_or_uring! {
Disabled(UnparkThread),
}

fn create_io_stack(enabled: bool, nevents: usize) -> io::Result<(IoStack, IoHandle, SignalHandle)> {
fn create_io_stack(enabled: bool, nevents: usize, num_workers: usize) -> io::Result<(IoStack, IoHandle, SignalHandle)> {
#[cfg(loom)]
assert!(!enabled);

let ret = if enabled {
let (io_driver, io_handle) = crate::runtime::io::Driver::new(nevents)?;
let (io_driver, io_handle) = crate::runtime::io::Driver::new(nevents, num_workers)?;

let (signal_driver, signal_handle) = create_signal_driver(io_driver, &io_handle)?;
let process_driver = create_process_driver(signal_driver);
Expand Down Expand Up @@ -202,7 +203,7 @@ cfg_not_io_driver_impl_or_uring! {
#[derive(Debug)]
pub(crate) struct IoStack(ParkThread);

fn create_io_stack(_enabled: bool, _nevents: usize) -> io::Result<(IoStack, IoHandle, SignalHandle)> {
fn create_io_stack(_enabled: bool, _nevents: usize, _num_worker: usize) -> io::Result<(IoStack, IoHandle, SignalHandle)> {
let park_thread = ParkThread::new();
let unpark_thread = park_thread.unpark();
Ok((IoStack(park_thread), unpark_thread, Default::default()))
Expand Down
8 changes: 6 additions & 2 deletions tokio/src/runtime/driver/op.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::runtime::Handle;
use crate::runtime::OpId;
use io_uring::cqueue;
use io_uring::squeue::Entry;
use std::future::Future;
Expand Down Expand Up @@ -36,6 +37,8 @@ pub(crate) struct Op<T: Send + 'static> {
state: State,
// Per operation data.
data: Option<T>,

pub(crate) shard_id: usize,
}

impl<T: Send + 'static> Op<T> {
Expand All @@ -48,6 +51,7 @@ impl<T: Send + 'static> Op<T> {
Self {
data: Some(data),
state: State::Initialize(Some(entry)),
shard_id: OpId::next().as_u64() as usize,
}
}
pub(crate) fn take_data(&mut self) -> Option<T> {
Expand Down Expand Up @@ -109,13 +113,13 @@ impl<T: Completable + Unpin + Send> Future for Op<T> {
let entry = entry_opt.take().expect("Entry must be present");
let waker = cx.waker().clone();
// SAFETY: entry is valid for the entire duration of the operation
let idx = unsafe { driver.register_op(entry, waker)? };
let idx = unsafe { driver.register_op(entry, waker, this.shard_id)? };
this.state = State::Polled(idx);
Poll::Pending
}

State::Polled(idx) => {
let mut ctx = driver.get_uring().lock();
let mut ctx = driver.get_uring(this.shard_id).lock();
let lifecycle = ctx.ops.get_mut(*idx).expect("Lifecycle must be present");

match mem::replace(lifecycle, Lifecycle::Submitted) {
Expand Down
59 changes: 51 additions & 8 deletions tokio/src/runtime/io/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ cfg_signal_internal_and_unix! {
}
cfg_tokio_unstable_uring! {
mod uring;
use uring::UringContext;
use uring::{UringContext};
}

use crate::io::interest::Interest;
Expand Down Expand Up @@ -56,7 +56,7 @@ pub(crate) struct Handle {
feature = "fs",
target_os = "linux",
))]
pub(crate) uring_context: Mutex<UringContext>,
pub(crate) uring_context: Box<[Mutex<UringContext>]>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -92,7 +92,32 @@ pub(super) enum Tick {
const TOKEN_WAKEUP: mio::Token = mio::Token(0);
const TOKEN_SIGNAL: mio::Token = mio::Token(1);
cfg_tokio_unstable_uring! {
pub(crate) const TOKEN_URING: mio::Token = mio::Token(2);
// Since `ScheduledIo` is at least `repr(align(16))`, the low 4 bits of a raw
// `ScheduledIo` pointer are always zero, so every token value below 16 can be
// used for internal purposes. Two of them are taken by `TOKEN_WAKEUP` and
// `TOKEN_SIGNAL`, which leaves 14 (= 15 - 2 + 1) values usable as shard ids
// for uring operations.
//
// 0b00000            => TOKEN_WAKEUP
// 0b00001            => TOKEN_SIGNAL
// 0b00010 ~ 0b01111  => uring tokens (one per shard)
// 0b10000 ~          => raw ScheduledIo pointers

const URING_TOKEN_START: usize = 0b10;
const URING_TOKEN_END: usize = 0b1111;
// Both ends of the uring token range are inclusive, hence the `+ 1`.
const MAX_SHARD_SIZE: usize = URING_TOKEN_END - URING_TOKEN_START + 1;

/// Returns `true` when `token` lies in the inclusive range
/// `URING_TOKEN_START..=URING_TOKEN_END`.
pub(super) fn is_uring_token(token: mio::Token) -> bool {
    (URING_TOKEN_START..=URING_TOKEN_END).contains(&token.0)
}

/// Maps a uring token back to the index of its shard.
///
/// Debug builds assert that `token` is a uring token. The subtraction
/// saturates at zero rather than underflowing.
pub(super) fn get_shard_id(token: mio::Token) -> usize {
    debug_assert!(is_uring_token(token), "token {token:?} is not a uring token");
    let mio::Token(raw) = token;
    raw.saturating_sub(URING_TOKEN_START)
}

/// Builds the `mio::Token` for shard index `n` by offsetting it with
/// `URING_TOKEN_START` (saturating on overflow).
///
/// Debug builds assert that the result is a uring token.
pub(super) fn as_uring_token(n: usize) -> mio::Token {
    let raw = n.saturating_add(URING_TOKEN_START);
    let token = mio::Token(raw);
    debug_assert!(is_uring_token(token), "token {token:?} is not a uring token");
    token
}
}

fn _assert_kinds() {
Expand All @@ -106,7 +131,10 @@ fn _assert_kinds() {
impl Driver {
/// Creates a new event loop, returning any error that happened during the
/// creation.
pub(crate) fn new(nevents: usize) -> io::Result<(Driver, Handle)> {
pub(crate) fn new(
nevents: usize,
#[allow(unused)] num_workers: usize,
) -> io::Result<(Driver, Handle)> {
let poll = mio::Poll::new()?;
#[cfg(not(target_os = "wasi"))]
let waker = mio::Waker::new(poll.registry(), TOKEN_WAKEUP)?;
Expand All @@ -120,6 +148,14 @@ impl Driver {

let (registrations, synced) = RegistrationSet::new();

#[cfg(all(
tokio_unstable_uring,
feature = "rt",
feature = "fs",
target_os = "linux",
))]
let num_workers = num_workers.min(MAX_SHARD_SIZE);

let handle = Handle {
registry,
registrations,
Expand All @@ -133,7 +169,10 @@ impl Driver {
feature = "fs",
target_os = "linux",
))]
uring_context: Mutex::new(UringContext::new()),
uring_context: (0..num_workers)
.map(|_| Mutex::new(UringContext::new()))
.collect::<Vec<_>>()
.into_boxed_slice(),
};

#[cfg(all(
Expand All @@ -143,7 +182,9 @@ impl Driver {
target_os = "linux",
))]
{
handle.add_uring_source(Interest::READABLE)?;
for shard_id in 0..num_workers {
handle.add_uring_source(shard_id, Interest::READABLE)?;
}
}

Ok((driver, handle))
Expand Down Expand Up @@ -207,8 +248,10 @@ impl Driver {
feature = "fs",
target_os = "linux",
))]
TOKEN_URING => {
let mut guard = handle.get_uring().lock();
token if is_uring_token(token) => {
let shard_id = get_shard_id(token);

let mut guard = handle.get_uring(shard_id).lock();
let ctx = &mut *guard;
ctx.dispatch_completions();
}
Expand Down
26 changes: 17 additions & 9 deletions tokio/src/runtime/io/driver/uring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use slab::Slab;
use crate::runtime::driver::op::{Lifecycle, Op};
use crate::{io::Interest, loom::sync::Mutex};

use super::{Handle, TOKEN_URING};
use super::{as_uring_token, Handle};

use std::os::fd::AsRawFd;
use std::{io, mem, task::Waker};
Expand Down Expand Up @@ -123,24 +123,32 @@ impl Drop for UringContext {

impl Handle {
#[allow(dead_code)]
pub(crate) fn add_uring_source(&self, interest: Interest) -> io::Result<()> {
pub(crate) fn add_uring_source(&self, shard_id: usize, interest: Interest) -> io::Result<()> {
// setup for io_uring
let uringfd = self.get_uring().lock().uring.as_raw_fd();
let mut guard = self.get_uring(shard_id).lock();
let ctx = &mut *guard;
let uringfd = ctx.uring.as_raw_fd();
let mut source = SourceFd(&uringfd);
self.registry
.register(&mut source, TOKEN_URING, interest.to_mio())
.register(&mut source, as_uring_token(shard_id), interest.to_mio())
}

pub(crate) fn get_uring(&self) -> &Mutex<UringContext> {
&self.uring_context
pub(crate) fn get_uring(&self, shard_id: usize) -> &Mutex<UringContext> {
let shard_id = shard_id % self.uring_context.len();
&self.uring_context[shard_id]
}

/// # Safety
///
/// Callers must ensure that parameters of the entry (such as buffer) are valid and will
/// be valid for the entire duration of the operation, otherwise it may cause memory problems.
pub(crate) unsafe fn register_op(&self, entry: Entry, waker: Waker) -> io::Result<usize> {
let mut guard = self.get_uring().lock();
pub(crate) unsafe fn register_op(
&self,
entry: Entry,
waker: Waker,
shard_id: usize,
) -> io::Result<usize> {
let mut guard = self.get_uring(shard_id).lock();
let ctx = &mut *guard;
let index = ctx.ops.insert(Lifecycle::Waiting(waker));
let entry = entry.user_data(index as u64);
Expand All @@ -167,7 +175,7 @@ impl Handle {
}

pub(crate) fn cancel_op<T: Send + 'static>(&self, op: &mut Op<T>, index: usize) {
let mut guard = self.get_uring().lock();
let mut guard = self.get_uring(op.shard_id).lock();
let ctx = &mut *guard;
let ops = &mut ctx.ops;
let Some(lifecycle) = ops.get_mut(index) else {
Expand Down
5 changes: 5 additions & 0 deletions tokio/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,8 @@ cfg_rt! {
/// After thread starts / before thread stops
type Callback = std::sync::Arc<dyn Fn() + Send + Sync>;
}

cfg_tokio_unstable_uring! {
pub(crate) mod op_id;
pub(crate) use op_id::OpId;
}
32 changes: 32 additions & 0 deletions tokio/src/runtime/op_id.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// TODO: Put together with other id related utils.

use std::num::NonZeroU64;

/// Non-zero 64-bit identifier produced by [`OpId::next`].
#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub(crate) struct OpId(NonZeroU64);

impl OpId {
pub(crate) fn next() -> Self {
use crate::loom::sync::atomic::Ordering::Relaxed;
use crate::loom::sync::atomic::StaticAtomicU64;

#[cfg(all(test, loom))]
crate::loom::lazy_static! {
static ref NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(1);
}

#[cfg(not(all(test, loom)))]
static NEXT_ID: StaticAtomicU64 = StaticAtomicU64::new(1);

loop {
let id = NEXT_ID.fetch_add(1, Relaxed);
if let Some(id) = NonZeroU64::new(id) {
return Self(id);
}
}
}

pub(crate) fn as_u64(&self) -> u64 {
self.0.get()
}
}
2 changes: 1 addition & 1 deletion tokio/tests/fs_uring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn current_rt() -> Box<dyn Fn() -> Runtime> {

#[test]
fn all_tests() {
let rt_conbination = vec![current_rt(), multi_rt(1), multi_rt(8)];
let rt_conbination = vec![current_rt(), multi_rt(1), multi_rt(8), multi_rt(16)];

for rt in rt_conbination {
shutdown_runtime_while_performing_io_uring_ops(rt());
Expand Down
Loading