Skip to content

Replace MemCase's unsound impl of Deref/AsMut with a borrow() method #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 27 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,12 @@ These are the main limitations you should be aware of before choosing to use
structure you will need to couple permanently the deserialized structure with
its serialized support, which is obtained by putting it in a [`MemCase`] using
the convenience methods [`Deserialize::load_mem`], [`Deserialize::load_mmap`],
and [`Deserialize::mmap`]. A [`MemCase`] will deref to its contained type, so it
can be used transparently as long as fields and methods are concerned, but if
your original type is `T` the field of the new structure will have to be of type
`MemCase<DeserType<'static, T>>`, not `T`.
and [`Deserialize::mmap`]. A [`MemCase`] provides a method that yields references
to the deserialized type associated to its contained type.

- No validation or padding cleaning is performed on zero-copy types. If you plan
to serialize data and distribute it, you must take care of these issues.

## Pros

- Almost instant deserialization with minimal allocation provided that you
Expand Down Expand Up @@ -140,7 +138,7 @@ let t: DeserType<'_, [usize; 1000]> =
assert_eq!(s, *t);

// This is a traditional deserialization instead
let t: [usize; 1000] =
let t: [usize; 1000] =
unsafe { <[usize; 1000]>::deserialize_full(
&mut std::fs::File::open(&file)?
)? };
Expand All @@ -149,25 +147,19 @@ assert_eq!(s, t);
// In this case we map the data structure into memory
//
// Note: requires the `mmap` feature.
let u: MemCase<&[usize; 1000]> =
unsafe { <[usize; 1000]>::mmap(&file, Flags::empty())? };

assert_eq!(s, **u);

// When using a MemCase, the lifetime of the derived deserialization type is 'static
let u: MemCase<DeserType<'static, [usize; 1000]>> =
let u: MemCase<[usize; 1000]> =
unsafe { <[usize; 1000]>::mmap(&file, Flags::empty())? };

assert_eq!(s, **u);
assert_eq!(s, **u.borrow());
# Ok(())
# }
```

Note how we serialize an array, but we deserialize a reference. The reference
points inside `b`, so there is no copy performed. The call to
[`deserialize_full`] creates a new array instead. The third call maps the data
structure into memory and returns a [`MemCase`] that can be used transparently
as a reference to the array; moreover, the [`MemCase`] can be passed to other
structure into memory and returns a [`MemCase`] that can be used to get
a reference to the array; moreover, the [`MemCase`] can be passed to other
functions or stored in a structure field, as it contains both the structure and
the memory-mapped region that supports it.

Expand Down Expand Up @@ -205,14 +197,14 @@ let t: DeserType<'_, Vec<usize>> =
assert_eq!(s, *t);

// This is a traditional deserialization instead
let t: Vec<usize> =
let t: Vec<usize> =
unsafe { <Vec<usize>>::load_full(&file)? };
assert_eq!(s, t);

// In this case we map the data structure into memory
let u: MemCase<DeserType<'static, Vec<usize>>> =
let u: MemCase<Vec<usize>> =
unsafe { <Vec<usize>>::mmap(&file, Flags::empty())? };
assert_eq!(s, **u);
assert_eq!(s, **u.borrow());
# Ok(())
# }
```
Expand Down Expand Up @@ -264,14 +256,14 @@ let t: DeserType<'_, Vec<Data>> =
assert_eq!(s, *t);

// This is a traditional deserialization instead
let t: Vec<Data> =
let t: Vec<Data> =
unsafe { <Vec<Data>>::load_full(&file)? };
assert_eq!(s, t);

// In this case we map the data structure into memory
let u: MemCase<DeserType<'static, Vec<Data>>> =
let u: MemCase<Vec<Data>> =
unsafe { <Vec<Data>>::mmap(&file, Flags::empty())? };
assert_eq!(s, **u);
assert_eq!(s, **u.borrow());
# Ok(())
# }
```
Expand Down Expand Up @@ -309,20 +301,21 @@ unsafe { s.store(&file) };
let b = std::fs::read(&file)?;

// The type of t will be inferred--it is shown here only for clarity
let t: MyStruct<&[isize]> =
let t: MyStruct<&[isize]> =
unsafe { <MyStruct<Vec<isize>>>::deserialize_eps(b.as_ref())? };

assert_eq!(s.id, t.id);
assert_eq!(s.data, Vec::from(t.data));

// This is a traditional deserialization instead
let t: MyStruct<Vec<isize>> =
let t: MyStruct<Vec<isize>> =
unsafe { <MyStruct<Vec<isize>>>::load_full(&file)? };
assert_eq!(s, t);

// In this case we map the data structure into memory
let u: MemCase<MyStruct<&[isize]>> =
let u: MemCase<MyStruct<Vec<isize>>> =
unsafe { <MyStruct<Vec<isize>>>::mmap(&file, Flags::empty())? };
let u: &MyStruct<&[isize]> = u.borrow();
assert_eq!(s.id, u.id);
assert_eq!(s.data, u.data.as_ref());
# Ok(())
Expand Down Expand Up @@ -371,8 +364,9 @@ let t = unsafe { MyStruct::deserialize_eps(b.as_ref())? };
assert_eq!(s.sum(), t.sum());

let t = unsafe { <MyStruct>::mmap(&file, Flags::empty())? };
let t: &MyStructParam<&[isize]> = t.borrow();

// t works transparently as a MyStructParam<&[isize]>
// t works transparently as a &MyStructParam<&[isize]>
assert_eq!(s.id, t.id);
assert_eq!(s.data, t.data.as_ref());
assert_eq!(s.sum(), t.sum());
Expand Down Expand Up @@ -407,7 +401,7 @@ unsafe { s.store(&file) };
let b = std::fs::read(&file)?;

// The type of t is unchanged
let t: MyStruct<Vec<isize>> =
let t: MyStruct<Vec<isize>> =
unsafe { <MyStruct<Vec<isize>>>::deserialize_eps(b.as_ref())? };
# Ok(())
# }
Expand Down Expand Up @@ -445,7 +439,7 @@ unsafe { s.store(&file) };
let b = std::fs::read(&file)?;

// The type of t is unchanged
let t: &MyStruct<i32> =
let t: &MyStruct<i32> =
unsafe { <MyStruct<i32>>::deserialize_eps(b.as_ref())? };
# Ok(())
# }
Expand Down Expand Up @@ -480,7 +474,7 @@ let e = Enum::B(vec![0, 1, 2, 3]);
let mut file = std::env::temp_dir();
file.push("serialized7");
unsafe { e.store(&file) };
// Deserializing using just Enum will fail, as the type parameter
// Deserializing using just Enum will fail, as the type parameter
// by default is Vec<usize>
assert!(unsafe { <Enum>::load_full(&file) }.is_err());
# Ok(())
Expand Down Expand Up @@ -513,7 +507,7 @@ let t: &[i32] = unsafe { <Vec<i32>>::deserialize_eps(b.as_ref())? };
let t: Vec<i32> = unsafe { <Vec<i32>>::deserialize_full(
&mut std::fs::File::open(&file)?
)? };
let t: MemCase<&[i32]> = unsafe { <Vec<i32>>::mmap(&file, Flags::empty())? };
let t: MemCase<Vec<i32>> = unsafe { <Vec<i32>>::mmap(&file, Flags::empty())? };

// Within a structure
#[derive(Epserde, Debug, PartialEq, Eq, Default, Clone)]
Expand All @@ -532,7 +526,7 @@ let t: Data<&[i32]> = unsafe { <Data<Vec<i32>>>::deserialize_eps(b.as_ref())? };
let t: Data<Vec<i32>> = unsafe { <Data<Vec<i32>>>::deserialize_full(
&mut std::fs::File::open(&file)?
)? };
let t: MemCase<Data<&[i32]>> = unsafe { <Data<Vec<i32>>>::mmap(&file, Flags::empty())? };
let t: MemCase<Data<Vec<i32>>> = unsafe { <Data<Vec<i32>>>::mmap(&file, Flags::empty())? };
# Ok(())
# }
```
Expand Down Expand Up @@ -565,7 +559,7 @@ let t: &[i32] = unsafe { <Vec<i32>>::deserialize_eps(b.as_ref())? };
let t: Vec<i32> = unsafe { <Vec<i32>>::deserialize_full(
&mut std::fs::File::open(&file)?
)? };
let t: MemCase<&[i32]> = unsafe { <Vec<i32>>::mmap(&file, Flags::empty())? };
let t: MemCase<Vec<i32>> = unsafe { <Vec<i32>>::mmap(&file, Flags::empty())? };

// Within a structure
#[derive(Epserde, Debug, PartialEq, Eq, Default, Clone)]
Expand All @@ -584,7 +578,7 @@ let t: Data<&[i32]> = unsafe { <Data<Vec<i32>>>::deserialize_eps(b.as_ref())? };
let t: Data<Vec<i32>> = unsafe { <Data<Vec<i32>>>::deserialize_full(
&mut std::fs::File::open(&file)?
)? };
let t: MemCase<Data<&[i32]>> = unsafe { <Data<Vec<i32>>>::mmap(&file, Flags::empty())? };
let t: MemCase<Data<Vec<i32>>> = unsafe { <Data<Vec<i32>>>::mmap(&file, Flags::empty())? };
# Ok(())
# }
```
Expand Down
12 changes: 8 additions & 4 deletions epserde-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,9 @@ pub fn epserde_derive(input: TokenStream) -> TokenStream {
}
}

// SAFETY: &'epserde_desertype Self is covariant
#[automatically_derived]
impl<#generics_deserialize> epserde::deser::DeserializeInner for #name<#concat_generics> #where_clause_des
unsafe impl<#generics_deserialize> epserde::deser::DeserializeInner for #name<#concat_generics> #where_clause_des
{
unsafe fn _deserialize_full_inner(
backend: &mut impl epserde::deser::ReadWithPos,
Expand Down Expand Up @@ -540,8 +541,9 @@ pub fn epserde_derive(input: TokenStream) -> TokenStream {
}
}

// SAFETY: #name is a struct, so it is covariant
#[automatically_derived]
impl<#generics_deserialize> epserde::deser::DeserializeInner for #name<#concat_generics> #where_clause_des {
unsafe impl<#generics_deserialize> epserde::deser::DeserializeInner for #name<#concat_generics> #where_clause_des {
unsafe fn _deserialize_full_inner(
backend: &mut impl epserde::deser::ReadWithPos,
) -> core::result::Result<Self, epserde::deser::Error> {
Expand Down Expand Up @@ -824,8 +826,9 @@ pub fn epserde_derive(input: TokenStream) -> TokenStream {
}
}

// SAFETY: &'epserde_desertype Self is covariant
#[automatically_derived]
impl<#generics_deserialize> epserde::deser::DeserializeInner for #name<#concat_generics> #where_clause_des {
unsafe impl<#generics_deserialize> epserde::deser::DeserializeInner for #name<#concat_generics> #where_clause_des {
unsafe fn _deserialize_full_inner(
backend: &mut impl epserde::deser::ReadWithPos,
) -> core::result::Result<Self, epserde::deser::Error> {
Expand Down Expand Up @@ -878,8 +881,9 @@ pub fn epserde_derive(input: TokenStream) -> TokenStream {
Ok(())
}
}
// SAFETY: #name is an enum, so it is covariant
#[automatically_derived]
impl<#generics_deserialize> epserde::deser::DeserializeInner for #name<#concat_generics> #where_clause_des {
unsafe impl<#generics_deserialize> epserde::deser::DeserializeInner for #name<#concat_generics> #where_clause_des {
unsafe fn _deserialize_full_inner(
backend: &mut impl epserde::deser::ReadWithPos,
) -> core::result::Result<Self, epserde::deser::Error> {
Expand Down
56 changes: 31 additions & 25 deletions epserde/src/deser/mem_case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
* SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later
*/

use crate::DeserializeInner;
use bitflags::bitflags;
use core::{mem::size_of, ops::Deref};
use core::{fmt, mem::size_of};
use maligned::A64;
use mem_dbg::{MemDbg, MemSize};

Expand Down Expand Up @@ -111,36 +112,41 @@ impl MemBackend {
/// wrapped type, using the no-op [`None`](`MemBackend#variant.None`) variant
/// of [`MemBackend`], so a structure can be [encased](MemCase::encase)
/// almost transparently.
#[derive(Debug, MemDbg, MemSize)]
pub struct MemCase<S>(pub(crate) S, pub(crate) MemBackend);
#[derive(MemDbg, MemSize)]
pub struct MemCase<'a, S: DeserializeInner>(
pub(crate) <S as DeserializeInner>::DeserType<'a>,
pub(crate) MemBackend,
);

impl<S> MemCase<S> {
/// Encases a data structure in a [`MemCase`] with no backend.
pub fn encase(s: S) -> MemCase<S> {
MemCase(s, MemBackend::None)
impl<'a, S: DeserializeInner> fmt::Debug for MemCase<'a, S>
where
<S as DeserializeInner>::DeserType<'a>: fmt::Debug,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("MemBackend")
.field(&self.0)
.field(&self.1)
.finish()
}
}

unsafe impl<S: Send> Send for MemCase<S> {}
unsafe impl<S: Sync> Sync for MemCase<S> {}

impl<S> Deref for MemCase<S> {
type Target = S;
#[inline(always)]
fn deref(&self) -> &Self::Target {
&self.0
impl<'a, S: DeserializeInner> MemCase<'a, S> {
/// Encases a data structure in a [`MemCase`] with no backend.
pub fn encase(s: <S as DeserializeInner>::DeserType<'a>) -> Self {
MemCase(s, MemBackend::None)
}
}

impl<S> AsRef<S> for MemCase<S> {
#[inline(always)]
fn as_ref(&self) -> &S {
&self.0
pub fn borrow<'this>(&'this self) -> &'this <S as DeserializeInner>::DeserType<'this> {
// SAFETY: 'a outlives 'this, and <S as DeserializeInner>::DeserType is required to be
// covariant (ie. it's a normal structure and not, say, a closure with 'this as argument)
unsafe {
core::mem::transmute::<
&'this <S as DeserializeInner>::DeserType<'a>,
&'this <S as DeserializeInner>::DeserType<'this>,
>(&self.0)
}
}
}

impl<S: Send + Sync> From<S> for MemCase<S> {
fn from(s: S) -> Self {
MemCase::encase(s)
}
}
unsafe impl<'a, S: DeserializeInner + Send> Send for MemCase<'a, S> {}
unsafe impl<'a, S: DeserializeInner + Sync> Sync for MemCase<'a, S> {}
28 changes: 13 additions & 15 deletions epserde/src/deser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ pub type DeserType<'a, T> = <T as DeserializeInner>::DeserType<'a>;
/// incompatible structures using the same code, or cause undefined behavior
/// by loading data with an incorrect alignment.
/// - Memory-mapped files might be modified externally.
/// - [`Self::DeserType`] must be covariant (ie. behave like a structure,
/// not a closure with a generic argument)
pub trait Deserialize: DeserializeInner {
/// Fully deserialize a structure of this type from the given backend.
///
Expand Down Expand Up @@ -99,9 +101,7 @@ pub trait Deserialize: DeserializeInner {
/// # Safety
///
/// See the [trait documentation](Deserialize).
unsafe fn load_mem<'a>(
path: impl AsRef<Path>,
) -> anyhow::Result<MemCase<<Self as DeserializeInner>::DeserType<'a>>> {
unsafe fn load_mem<'a>(path: impl AsRef<Path>) -> anyhow::Result<MemCase<'a, Self>> {
let align_to = align_of::<MemoryAlignment>();
if align_of::<Self>() > align_to {
return Err(Error::AlignmentError.into());
Expand All @@ -111,8 +111,7 @@ pub trait Deserialize: DeserializeInner {
// Round up to u128 size
let capacity = file_len + crate::pad_align_to(file_len, align_to);

let mut uninit: MaybeUninit<MemCase<<Self as DeserializeInner>::DeserType<'_>>> =
MaybeUninit::uninit();
let mut uninit: MaybeUninit<MemCase<'_, Self>> = MaybeUninit::uninit();
let ptr = uninit.as_mut_ptr();

// SAFETY: the entire vector will be filled with data read from the file,
Expand Down Expand Up @@ -170,13 +169,12 @@ pub trait Deserialize: DeserializeInner {
unsafe fn load_mmap<'a>(
path: impl AsRef<Path>,
flags: Flags,
) -> anyhow::Result<MemCase<<Self as DeserializeInner>::DeserType<'a>>> {
) -> anyhow::Result<MemCase<'a, Self>> {
let file_len = path.as_ref().metadata()?.len() as usize;
let mut file = std::fs::File::open(path)?;
let capacity = file_len + crate::pad_align_to(file_len, 16);

let mut uninit: MaybeUninit<MemCase<<Self as DeserializeInner>::DeserType<'_>>> =
MaybeUninit::uninit();
let mut uninit: MaybeUninit<MemCase<'_, Self>> = MaybeUninit::uninit();
let ptr = uninit.as_mut_ptr();

let mut mmap = mmap_rs::MmapOptions::new(capacity)?
Expand Down Expand Up @@ -217,15 +215,11 @@ pub trait Deserialize: DeserializeInner {
///
/// See the [trait documentation](Deserialize) and [mmap's `with_file`'s documentation](mmap_rs::MmapOptions::with_file).
#[cfg(feature = "mmap")]
unsafe fn mmap<'a>(
path: impl AsRef<Path>,
flags: Flags,
) -> anyhow::Result<MemCase<<Self as DeserializeInner>::DeserType<'a>>> {
unsafe fn mmap<'a>(path: impl AsRef<Path>, flags: Flags) -> anyhow::Result<MemCase<'a, Self>> {
let file_len = path.as_ref().metadata()?.len();
let file = std::fs::File::open(path)?;

let mut uninit: MaybeUninit<MemCase<<Self as DeserializeInner>::DeserType<'_>>> =
MaybeUninit::uninit();
let mut uninit: MaybeUninit<MemCase<'_, Self>> = MaybeUninit::uninit();
let ptr = uninit.as_mut_ptr();

let mmap = unsafe {
Expand Down Expand Up @@ -261,7 +255,11 @@ pub trait Deserialize: DeserializeInner {
/// the user from modifying the methods in [`Deserialize`].
///
/// The user should not implement this trait directly, but rather derive it.
pub trait DeserializeInner: Sized {
///
/// # Safety
///
/// See [`Deserialize`]
pub unsafe trait DeserializeInner: Sized {
/// The deserialization type associated with this type. It can be
/// retrieved conveniently with the alias [`DeserType`].
type DeserType<'a>;
Expand Down
2 changes: 1 addition & 1 deletion epserde/src/impls/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl<T: DeepCopy + SerializeInner, const N: usize> SerializeHelper<Deep> for [T;
}
}

impl<T: CopyType + DeserializeInner, const N: usize> DeserializeInner for [T; N]
unsafe impl<T: CopyType + DeserializeInner, const N: usize> DeserializeInner for [T; N]
where
[T; N]: DeserializeHelper<<T as CopyType>::Copy, FullType = [T; N]>,
{
Expand Down
Loading