diff --git a/pgx-macros/src/lib.rs b/pgx-macros/src/lib.rs index 4d035929fc..2a2df75299 100644 --- a/pgx-macros/src/lib.rs +++ b/pgx-macros/src/lib.rs @@ -14,8 +14,12 @@ use std::collections::HashSet; use proc_macro2::Ident; use quote::{quote, ToTokens}; +use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::{parse_macro_input, Attribute, Data, DeriveInput, Item, ItemImpl}; +use syn::{ + parse_macro_input, Attribute, Data, DeriveInput, GenericParam, Item, ItemImpl, Lifetime, + LifetimeDef, Token, +}; use operators::{impl_postgres_eq, impl_postgres_hash, impl_postgres_ord}; use pgx_sql_entity_graph::{ @@ -706,9 +710,13 @@ Optionally accepts the following attributes: * `inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the type. * `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type. +* `custom_serializer`: Define your own implementation of `pgx::datum::Serializer` trait (only for `Serialize/Deserialize`-implementing types) * `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). 
*/ -#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgx))] +#[proc_macro_derive( + PostgresType, + attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgx, custom_serializer) +)] pub fn postgres_type(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); @@ -740,7 +748,8 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result } } - if args.is_empty() { + // If no in/out parameters are defined + if args.iter().filter(|a| a != &&PostgresTypeAttribute::CustomSerializer).next().is_none() { // assume the user wants us to implement the InOutFuncs args.insert(PostgresTypeAttribute::Default); } @@ -755,6 +764,28 @@ fn impl_postgres_type(ast: DeriveInput) -> syn::Result impl #generics ::pgx::PostgresType for #name #generics { } }); + if !args.contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs) + && !args.contains(&PostgresTypeAttribute::CustomSerializer) + { + let mut lt_generics = generics.clone(); + let mut de = LifetimeDef::new(Lifetime::new("'de", generics.span())); + let bounds = generics + .params + .iter() + .filter_map(|p| match p { + GenericParam::Type(_) => None, + GenericParam::Const(_) => None, + GenericParam::Lifetime(lt) => Some(lt.clone().lifetime), + }) + .collect::<Punctuated<Lifetime, Token![+]>>(); + de.bounds = bounds; + lt_generics.params.insert(0, GenericParam::Lifetime(de)); + stream.extend(quote! 
{ + impl #generics ::pgx::datum::Serializer for #name #generics { } + impl #lt_generics ::pgx::datum::Deserializer<'de> for #name #generics { } + }); + } + // and if we don't have custom inout/funcs, we use the JsonInOutFuncs trait // which implements _in and _out #[pg_extern] functions that just return the type itself if args.contains(&PostgresTypeAttribute::Default) { @@ -931,6 +962,7 @@ enum PostgresTypeAttribute { InOutFuncs, PgVarlenaInOutFuncs, Default, + CustomSerializer, } fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet { @@ -948,6 +980,10 @@ fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet { + categorized_attributes.insert(PostgresTypeAttribute::CustomSerializer); + } + _ => { // we can just ignore attributes we don't understand } @@ -1091,8 +1127,6 @@ pub fn pg_trigger(attrs: TokenStream, input: TokenStream) -> TokenStream { fn wrapped(attrs: TokenStream, input: TokenStream) -> Result { use pgx_sql_entity_graph::{PgTrigger, PgTriggerAttribute}; use syn::parse::Parser; - use syn::punctuated::Punctuated; - use syn::Token; let attributes = Punctuated::::parse_terminated.parse(attrs)?; diff --git a/pgx-tests/src/tests/postgres_type_tests.rs b/pgx-tests/src/tests/postgres_type_tests.rs index 8f1a4ad5b7..a15525ac52 100644 --- a/pgx-tests/src/tests/postgres_type_tests.rs +++ b/pgx-tests/src/tests/postgres_type_tests.rs @@ -8,8 +8,9 @@ Use of this source code is governed by the MIT license that can be found in the */ use core::ffi::CStr; use pgx::prelude::*; -use pgx::{InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, StringInfo}; +use pgx::{Deserializer, InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, Serializer, StringInfo}; use serde::{Deserialize, Serialize}; +use std::io::Write; use std::str::FromStr; #[derive(Copy, Clone, PostgresType)] @@ -152,6 +153,26 @@ pub enum JsonEnumType { E2 { b: f32 }, } +#[derive(Serialize, Deserialize, PostgresType)] +#[custom_serializer] +pub struct CustomSerialized; + +impl Serializer for 
CustomSerialized { + fn to_writer(&self, mut writer: W) { + writer.write(&[1]).expect("can't write"); + } +} + +impl<'de> Deserializer<'de> for CustomSerialized { + fn from_slice(slice: &'de [u8]) -> Self { + if slice != &[1] { + panic!("wrong type") + } else { + CustomSerialized + } + } +} + #[cfg(any(test, feature = "pg_test"))] #[pgx::pg_schema] mod tests { @@ -159,11 +180,11 @@ mod tests { use crate as pgx_tests; use crate::tests::postgres_type_tests::{ - CustomTextFormatSerializedEnumType, CustomTextFormatSerializedType, JsonEnumType, JsonType, - VarlenaEnumType, VarlenaType, + CustomSerialized, CustomTextFormatSerializedEnumType, CustomTextFormatSerializedType, + JsonEnumType, JsonType, VarlenaEnumType, VarlenaType, }; use pgx::prelude::*; - use pgx::PgVarlena; + use pgx::{varsize_any_exhdr, PgVarlena}; #[pg_test] fn test_mytype() -> Result<(), pgx::spi::Error> { @@ -253,4 +274,26 @@ mod tests { assert!(matches!(result, JsonEnumType::E1 { a } if a == 1.0)); Ok(()) } + + #[pg_test] + fn custom_serializer() { + let datum = CustomSerialized.into_datum().unwrap(); + // Ensure we actually get our custom format, not the default CBOR + unsafe { + let input = datum.cast_mut_ptr(); + let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena); + let len = varsize_any_exhdr(varlena); + assert_eq!(len, 1); + } + } + + #[pg_test] + fn custom_serializer_end_to_end() { + let s = CustomSerialized; + let _ = Spi::get_one_with_args::( + r#"SELECT $1"#, + vec![(PgOid::Custom(CustomSerialized::type_oid()), s.into_datum())], + ) + .unwrap(); + } } diff --git a/pgx/src/datum/mod.rs b/pgx/src/datum/mod.rs index c16792deab..14654c8a35 100644 --- a/pgx/src/datum/mod.rs +++ b/pgx/src/datum/mod.rs @@ -50,6 +50,7 @@ pub use json::*; pub use numeric::{AnyNumeric, Numeric}; use once_cell::sync::Lazy; pub use range::*; +use serde::{Deserialize, Serialize}; use std::any::TypeId; pub use time_stamp::*; pub use time_stamp_with_timezone::*; @@ -57,13 +58,130 @@ pub use 
time_with_timezone::*; pub use tuples::*; pub use varlena::*; -use crate::PgBox; +use crate::{pg_sys, PgBox, PgMemoryContexts, StringInfo}; use pgx_sql_entity_graph::RustSqlMapping; /// A tagging trait to indicate a user type is also meant to be used by Postgres /// Implemented automatically by `#[derive(PostgresType)]` pub trait PostgresType {} +/// Serializing to datum +/// +/// Default implementation uses CBOR and Varlena +pub trait Serializer: Serialize { + /// Serializes the value to Datum + /// + /// Default implementation wraps the output of `Self::to_writer` into a Varlena + fn serialize(&self) -> pg_sys::Datum { + let mut serialized = StringInfo::new(); + + serialized.push_bytes(&[0u8; pg_sys::VARHDRSZ]); // reserve space for the header + self.to_writer(&mut serialized); + + let size = serialized.len() as usize; + let varlena = serialized.into_char_ptr(); + unsafe { + crate::set_varsize(varlena as *mut pg_sys::varlena, size as i32); + } + + (varlena as *const pg_sys::varlena).into() + } + + /// Serializes the value to a writer + /// + /// Default implementation serializes to CBOR + fn to_writer<W: std::io::Write>(&self, writer: W) { + serde_cbor::to_writer(writer, &self).expect("failed to encode as CBOR"); + } +} + +/// Deserializing from datum +/// +/// Default implementation uses CBOR and Varlena +pub trait Deserializer<'de>: Deserialize<'de> { + /// Deserializes datum into a value + /// + /// Default implementation assumes datum to be a varlena and uses `Self::from_slice` + /// to deserialize the actual value. 
+ fn deserialize(datum: pg_sys::Datum) -> Self { + unsafe { + let input = datum.cast_mut_ptr(); + let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena); + let len = crate::varsize_any_exhdr(varlena); + let data = crate::vardata_any(varlena); + let slice = std::slice::from_raw_parts(data as *const u8, len); + Self::from_slice(slice) + } + } + + /// Deserializes datum into a value within the given memory context. + /// Default implementation assumes datum to be a varlena and uses `Self::from_slice` + /// to deserialize the actual value. + fn deserialize_into_context( + mut memory_context: PgMemoryContexts, + datum: pg_sys::Datum, + ) -> Self { + unsafe { + memory_context.switch_to(|_| { + let input = datum.cast_mut_ptr(); + // this gets the varlena Datum copied into this memory context + let varlena = pg_sys::pg_detoast_datum_copy(input as *mut pg_sys::varlena); + <Self as Deserializer<'de>>::deserialize(varlena.into()) + }) + } + } + + /// Deserializes a value from a slice + /// + /// Default implementation deserializes from CBOR. + fn from_slice(slice: &'de [u8]) -> Self { + serde_cbor::from_slice(slice).expect("failed to decode CBOR") + } +} + +impl<T> IntoDatum for T +where + T: PostgresType + Serializer, +{ + fn into_datum(self) -> Option<pg_sys::Datum> { + Some(Serializer::serialize(&self)) + } + + fn type_oid() -> pg_sys::Oid { + crate::rust_regtypein::<T>() + } +} + +impl<'de, T> FromDatum for T +where + T: PostgresType + Deserializer<'de>, +{ + unsafe fn from_polymorphic_datum( + datum: pg_sys::Datum, + is_null: bool, + _typoid: pg_sys::Oid, + ) -> Option<Self> { + if is_null { + None + } else { + Some(<T as Deserializer>::deserialize(datum)) + } + } + + unsafe fn from_datum_in_memory_context( + memory_context: PgMemoryContexts, + datum: pg_sys::Datum, + is_null: bool, + _typoid: pg_sys::Oid, + ) -> Option<Self> { + if is_null { + None + } else { + Some(T::deserialize_into_context(memory_context, datum)) + } + } +} + /// A type which can have it's [`core::any::TypeId`]s registered for Rust to SQL mapping. 
/// /// An example use of this trait: diff --git a/pgx/src/datum/varlena.rs b/pgx/src/datum/varlena.rs index 6f5549a774..1e3f2a7eeb 100644 --- a/pgx/src/datum/varlena.rs +++ b/pgx/src/datum/varlena.rs @@ -9,15 +9,12 @@ Use of this source code is governed by the MIT license that can be found in the //! Wrapper for Postgres 'varlena' type, over Rust types of a fixed size (ie, `impl Copy`) use crate::pg_sys::{VARATT_SHORT_MAX, VARHDRSZ_SHORT}; use crate::{ - pg_sys, rust_regtypein, set_varsize, set_varsize_short, vardata_any, varsize_any, - varsize_any_exhdr, void_mut_ptr, FromDatum, IntoDatum, PgMemoryContexts, PostgresType, - StringInfo, + pg_sys, rust_regtypein, set_varsize, set_varsize_short, vardata_any, varsize_any, void_mut_ptr, + FromDatum, IntoDatum, PgMemoryContexts, }; -use pgx_pg_sys::varlena; use pgx_sql_entity_graph::metadata::{ ArgumentError, Returns, ReturnsError, SqlMapping, SqlTranslatable, }; -use serde::{Deserialize, Serialize}; use std::borrow::Cow; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; @@ -335,123 +332,6 @@ where } } -impl IntoDatum for T -where - T: PostgresType + Serialize, -{ - fn into_datum(self) -> Option { - Some(cbor_encode(&self).into()) - } - - fn type_oid() -> pg_sys::Oid { - crate::rust_regtypein::() - } -} - -impl<'de, T> FromDatum for T -where - T: PostgresType + Deserialize<'de>, -{ - unsafe fn from_polymorphic_datum( - datum: pg_sys::Datum, - is_null: bool, - _typoid: pg_sys::Oid, - ) -> Option { - if is_null { - None - } else { - cbor_decode(datum.cast_mut_ptr()) - } - } - - unsafe fn from_datum_in_memory_context( - memory_context: PgMemoryContexts, - datum: pg_sys::Datum, - is_null: bool, - _typoid: pg_sys::Oid, - ) -> Option { - if is_null { - None - } else { - cbor_decode_into_context(memory_context, datum.cast_mut_ptr()) - } - } -} - -fn cbor_encode(input: T) -> *const pg_sys::varlena -where - T: Serialize, -{ - let mut serialized = StringInfo::new(); - - serialized.push_bytes(&[0u8; 
pg_sys::VARHDRSZ]); // reserve space for the header - serde_cbor::to_writer(&mut serialized, &input).expect("failed to encode as CBOR"); - - let size = serialized.len() as usize; - let varlena = serialized.into_char_ptr(); - unsafe { - set_varsize(varlena as *mut pg_sys::varlena, size as i32); - } - - varlena as *const pg_sys::varlena -} - -pub unsafe fn cbor_decode<'de, T>(input: *mut pg_sys::varlena) -> T -where - T: Deserialize<'de>, -{ - let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena); - let len = varsize_any_exhdr(varlena); - let data = vardata_any(varlena); - let slice = std::slice::from_raw_parts(data as *const u8, len); - serde_cbor::from_slice(slice).expect("failed to decode CBOR") -} - -pub unsafe fn cbor_decode_into_context<'de, T>( - mut memory_context: PgMemoryContexts, - input: *mut pg_sys::varlena, -) -> T -where - T: Deserialize<'de>, -{ - memory_context.switch_to(|_| { - // this gets the varlena Datum copied into this memory context - let varlena = pg_sys::pg_detoast_datum_copy(input as *mut pg_sys::varlena); - cbor_decode(varlena) - }) -} - -#[allow(dead_code)] -fn json_encode(input: T) -> *const varlena -where - T: Serialize, -{ - let mut serialized = StringInfo::new(); - - serialized.push_bytes(&[0u8; pg_sys::VARHDRSZ]); // reserve space for the header - serde_json::to_writer(&mut serialized, &input).expect("failed to encode as JSON"); - - let size = serialized.len() as usize; - let varlena = serialized.into_char_ptr(); - unsafe { - set_varsize(varlena as *mut pg_sys::varlena, size as i32); - } - - varlena as *const pg_sys::varlena -} - -#[allow(dead_code)] -unsafe fn json_decode<'de, T>(input: *mut pg_sys::varlena) -> T -where - T: Deserialize<'de>, -{ - let varlena = pg_sys::pg_detoast_datum_packed(input as *mut pg_sys::varlena); - let len = varsize_any_exhdr(varlena); - let data = vardata_any(varlena); - let slice = std::slice::from_raw_parts(data as *const u8, len); - serde_json::from_slice(slice).expect("failed 
to decode JSON") -} - unsafe impl SqlTranslatable for PgVarlena where T: SqlTranslatable + Copy,