diff --git a/epserde/src/impls/deref.rs b/epserde/src/impls/deref.rs new file mode 100644 index 0000000..11351b5 --- /dev/null +++ b/epserde/src/impls/deref.rs @@ -0,0 +1,98 @@ +/* + * SPDX-FileCopyrightText: 2025 Inria + * + * SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later + */ + +/*! + +Blanket implementations for references and single-item containers + +*/ + +use crate::prelude::*; +use ser::*; + +macro_rules! impl_ser { + ($type:ty) => { + impl CopyType for $type { + type Copy = ::Copy; + } + + impl TypeHash for $type { + #[inline(always)] + fn type_hash(hasher: &mut impl core::hash::Hasher) { + ::type_hash(hasher) + } + } + + impl AlignHash for $type { + #[inline(always)] + fn align_hash(hasher: &mut impl core::hash::Hasher, offset_of: &mut usize) { + ::align_hash(hasher, offset_of) + } + } + + impl SerializeInner for $type { + type SerType = T; + const IS_ZERO_COPY: bool = ::IS_ZERO_COPY; + const ZERO_COPY_MISMATCH: bool = ::ZERO_COPY_MISMATCH; + + #[inline(always)] + unsafe fn _serialize_inner( + &self, + backend: &mut impl WriteWithNames, + ) -> ser::Result<()> { + ::_serialize_inner(self, backend) + } + } + }; +} + +macro_rules! impl_all { + ($type:ident) => { + impl_ser!($type); + + impl DeserializeInner for $type { + type DeserType<'a> = $type<::DeserType<'a>>; + + #[inline(always)] + unsafe fn _deserialize_full_inner( + backend: &mut impl ReadWithPos, + ) -> deser::Result { + ::_deserialize_full_inner(backend).map($type::new) + } + #[inline(always)] + unsafe fn _deserialize_eps_inner<'a>( + backend: &mut SliceWithPos<'a>, + ) -> deser::Result> { + ::_deserialize_eps_inner(backend).map($type::new) + } + } + }; +} + +impl_ser!(&T); +impl_ser!(&mut T); + +#[cfg(any(feature = "std", feature = "alloc"))] +mod std_impl { + use super::*; + + #[cfg(not(feature = "std"))] + mod imports { + pub use alloc::boxed::Box; + pub use alloc::rc::Rc; + pub use alloc::sync::Arc; + } + #[cfg(feature = "std")] + mod imports { + pub use std::rc::Rc; + pub use std::sync::Arc; + } + use imports::*; + + impl_all!(Box); + impl_all!(Arc); + impl_all!(Rc); +} diff --git a/epserde/src/impls/mod.rs b/epserde/src/impls/mod.rs index b57dadf..934d355 100644 --- a/epserde/src/impls/mod.rs +++ b/epserde/src/impls/mod.rs @@ -14,6 +14,7 @@ and [`DeserializeInner`](crate::deser::DeserializeInner) for standard Rust types pub mod array; pub mod boxed_slice; +pub mod deref; pub mod iter; pub mod prim; pub mod slice; diff --git a/epserde/tests/test_std.rs b/epserde/tests/test_std.rs index 41ab120..03593f8 100644 --- a/epserde/tests/test_std.rs +++ b/epserde/tests/test_std.rs @@ -5,11 +5,22 @@ */ use epserde::prelude::*; +use std::rc::Rc; +use std::sync::Arc; fn test_generic(s: T) where T: Serialize + Deserialize + PartialEq + core::fmt::Debug, for<'a> ::DeserType<'a>: PartialEq + core::fmt::Debug, +{ + test_generic_split::(s, |value| value) +} +fn test_generic_split(s: Ser, deref: impl Fn(&Ser) -> &OwnedSer) +where + Ser: Serialize, + Deser: Deserialize + PartialEq + core::fmt::Debug, + OwnedSer: core::fmt::Debug, + for<'a> ::DeserType<'a>: PartialEq + core::fmt::Debug, { { let mut v = vec![]; @@ -19,11 +30,12 @@ where schema.0.sort_by_key(|a| a.offset); cursor.set_position(0); - let full_copy = unsafe { ::deserialize_full(&mut std::io::Cursor::new(&v)).unwrap() }; - assert_eq!(s, full_copy); + let full_copy = + unsafe { ::deserialize_full(&mut std::io::Cursor::new(&v)).unwrap() }; + assert_eq!(&full_copy, deref(&s)); - let full_copy = unsafe { ::deserialize_eps(&v).unwrap() }; - assert_eq!(full_copy, s); + let full_copy = unsafe { ::deserialize_eps(&v).unwrap() }; + assert_eq!(&full_copy, deref(&s)); let _ = schema.to_csv(); let _ = schema.debug(&v); @@ -34,11 +46,12 @@ where unsafe { s.serialize(&mut cursor).unwrap() }; cursor.set_position(0); - let full_copy = unsafe { ::deserialize_full(&mut std::io::Cursor::new(&v)).unwrap() }; - assert_eq!(s, full_copy); + let full_copy = + unsafe { ::deserialize_full(&mut std::io::Cursor::new(&v)).unwrap() }; + assert_eq!(&full_copy, deref(&s)); - let full_copy = unsafe { ::deserialize_eps(&v).unwrap() }; - assert_eq!(full_copy, s); + let full_copy = unsafe { ::deserialize_eps(&v).unwrap() }; + assert_eq!(&full_copy, deref(&s)); } } @@ -50,3 +63,16 @@ fn test_range() { struct Data(std::ops::Range); test_generic(Data(0..10)); } + +#[test] +fn test_containers() { + test_generic::>(Box::new(10)); + test_generic::>(Arc::new(10)); + test_generic::>(Rc::new(10)); +} + +#[test] +fn test_references() { + test_generic_split::<&i32, i32, i32>(&10, |n| *n); + test_generic_split::<&mut i32, i32, i32>(&mut 10, |n| *n); +}