Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 229 additions & 36 deletions tokio/src/fs/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,61 @@ struct Inner {
#[derive(Debug)]
enum State {
Idle(Option<Buf>),
Busy(JoinHandle<(Operation, Buf)>),
Busy(JoinHandleInner<(Operation, Buf)>),
}

#[derive(Debug)]
enum JoinHandleInner<T> {
Blocking(JoinHandle<T>),
#[cfg(all(
tokio_unstable,
feature = "io-uring",
feature = "rt",
feature = "fs",
target_os = "linux"
))]
Async(BoxedOp<T>),
}

cfg_io_uring! {
struct BoxedOp<T>(Pin<Box<dyn Future<Output = T> + Send + Sync + 'static>>);

impl<T> std::fmt::Debug for BoxedOp<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// format of BoxedFuture(T::type_name())
f.debug_tuple("BoxedFuture")
.field(&std::any::type_name::<T>())
.finish()
}
}

impl<T> Future for BoxedOp<T> {
type Output = T;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.as_mut().poll(cx)
}
}
}

impl Future for JoinHandleInner<(Operation, Buf)> {
type Output = io::Result<(Operation, Buf)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.get_mut() {
JoinHandleInner::Blocking(ref mut jh) => Pin::new(jh)
.poll(cx)
.map_err(|_| io::Error::new(io::ErrorKind::Other, "background task failed")),
#[cfg(all(
tokio_unstable,
feature = "io-uring",
feature = "rt",
feature = "fs",
target_os = "linux"
))]
JoinHandleInner::Async(ref mut jh) => Pin::new(jh).poll(cx).map(Ok),
}
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -399,7 +453,7 @@ impl File {

let std = self.std.clone();

inner.state = State::Busy(spawn_blocking(move || {
inner.state = State::Busy(JoinHandleInner::Blocking(spawn_blocking(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| std.set_len(size))
} else {
Expand All @@ -409,7 +463,7 @@ impl File {

// Return the result as a seek
(Operation::Seek(res), buf)
}));
})));

let (op, buf) = match inner.state {
State::Idle(_) => unreachable!(),
Expand Down Expand Up @@ -613,13 +667,14 @@ impl AsyncRead for File {
let std = me.std.clone();

let max_buf_size = cmp::min(dst.remaining(), me.max_buf_size);
inner.state = State::Busy(spawn_blocking(move || {
// SAFETY: the `Read` implementation of `std` does not
// read from the buffer it is borrowing and correctly
// reports the length of the data written into the buffer.
let res = unsafe { buf.read_from(&mut &*std, max_buf_size) };
(Operation::Read(res), buf)
}));
inner.state =
State::Busy(JoinHandleInner::Blocking(spawn_blocking(move || {
// SAFETY: the `Read` implementation of `std` does not
// read from the buffer it is borrowing and correctly
// reports the length of the data written into the buffer.
let res = unsafe { buf.read_from(&mut &*std, max_buf_size) };
(Operation::Read(res), buf)
})));
}
State::Busy(ref mut rx) => {
let (op, mut buf) = ready!(Pin::new(rx).poll(cx))?;
Expand Down Expand Up @@ -685,10 +740,10 @@ impl AsyncSeek for File {

let std = me.std.clone();

inner.state = State::Busy(spawn_blocking(move || {
inner.state = State::Busy(JoinHandleInner::Blocking(spawn_blocking(move || {
let res = (&*std).seek(pos);
(Operation::Seek(res), buf)
}));
})));
Ok(())
}
}
Expand Down Expand Up @@ -753,20 +808,90 @@ impl AsyncWrite for File {
let n = buf.copy_from(src, me.max_buf_size);
let std = me.std.clone();

let blocking_task_join_handle = spawn_mandatory_blocking(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
buf.write_to(&mut &*std)
};
#[allow(unused_mut)]
let mut data = Some((std, buf));

let mut task_join_handle = None;

#[cfg(all(
tokio_unstable,
feature = "io-uring",
feature = "rt",
feature = "fs",
target_os = "linux"
))]
{
use crate::runtime::Handle;

// Handle not present in some tests?
if let Ok(handle) = Handle::try_current() {
if handle.inner.driver().io().check_and_init()? {
task_join_handle = {
use crate::{io::uring::utils::ArcFd, runtime::driver::op::Op};

let (std, mut buf) = data.take().unwrap();
if let Some(seek) = seek {
// we do std seek before a write, so we can always use u64::MAX (current cursor) for the file offset
// seeking only modifies kernel metadata and does not block, so we can do it here
(&*std).seek(seek).map_err(|e| {
io::Error::new(
e.kind(),
format!("failed to seek before write: {e}"),
)
})?;
}

let mut fd: ArcFd = std;
let handle = BoxedOp(Box::pin(async move {
loop {
let op = Op::write_at(fd, buf, u64::MAX);
let (r, _buf, _fd) = op.await;
buf = _buf;
fd = _fd;
match r {
Ok(0) => {
break (
Operation::Write(Err(
io::ErrorKind::WriteZero.into(),
)),
buf,
);
}
Ok(_) if buf.is_empty() => {
break (Operation::Write(Ok(())), buf);
}
Ok(_) => continue, // more to write
Err(e) => break (Operation::Write(Err(e)), buf),
}
}
}));

Some(JoinHandleInner::Async(handle))
};
}
}
}

(Operation::Write(res), buf)
})
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "background task failed")
})?;
if let Some((std, mut buf)) = data {
task_join_handle = {
let handle = spawn_mandatory_blocking(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
buf.write_to(&mut &*std)
};

(Operation::Write(res), buf)
})
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "background task failed")
})?;

Some(JoinHandleInner::Blocking(handle))
};
}

inner.state = State::Busy(blocking_task_join_handle);
inner.state = State::Busy(task_join_handle.unwrap());

return Poll::Ready(Ok(n));
}
Expand Down Expand Up @@ -824,20 +949,88 @@ impl AsyncWrite for File {
let n = buf.copy_from_bufs(bufs, me.max_buf_size);
let std = me.std.clone();

let blocking_task_join_handle = spawn_mandatory_blocking(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
buf.write_to(&mut &*std)
};
#[allow(unused_mut)]
let mut data = Some((std, buf));

let mut task_join_handle = None;

#[cfg(all(
tokio_unstable,
feature = "io-uring",
feature = "rt",
feature = "fs",
target_os = "linux"
))]
{
use crate::runtime::Handle;

// Handle not present in some tests?
if let Ok(handle) = Handle::try_current() {
if handle.inner.driver().io().check_and_init()? {
task_join_handle = {
use crate::{io::uring::utils::ArcFd, runtime::driver::op::Op};

let (std, mut buf) = data.take().unwrap();
if let Some(seek) = seek {
// we do std seek before a write, so we can always use u64::MAX (current cursor) for the file offset
// seeking only modifies kernel metadata and does not block, so we can do it here
(&*std).seek(seek).map_err(|e| {
io::Error::new(
e.kind(),
format!("failed to seek before write: {e}"),
)
})?;
}

let mut fd: ArcFd = std;
let handle = BoxedOp(Box::pin(async move {
loop {
let op = Op::write_at(fd, buf, u64::MAX);
let (r, _buf, _fd) = op.await;
buf = _buf;
fd = _fd;
match r {
Ok(0) => {
break (
Operation::Write(Err(
io::ErrorKind::WriteZero.into(),
)),
buf,
);
}
Ok(_) if buf.is_empty() => {
break (Operation::Write(Ok(())), buf);
}
Ok(_) => continue, // more to write
Err(e) => break (Operation::Write(Err(e)), buf),
}
}
}));

Some(JoinHandleInner::Async(handle))
};
}
}
}

(Operation::Write(res), buf)
})
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "background task failed")
})?;
if let Some((std, mut buf)) = data {
task_join_handle = Some(JoinHandleInner::Blocking(
spawn_mandatory_blocking(move || {
let res = if let Some(seek) = seek {
(&*std).seek(seek).and_then(|_| buf.write_to(&mut &*std))
} else {
buf.write_to(&mut &*std)
};

(Operation::Write(res), buf)
})
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "background task failed")
})?,
));
}

inner.state = State::Busy(blocking_task_join_handle);
inner.state = State::Busy(task_join_handle.unwrap());

return Poll::Ready(Ok(n));
}
Expand Down
Loading