From 8cc4a635ec8968472bd694e81cc5de1f0ecb848a Mon Sep 17 00:00:00 2001 From: Eric Gillespie Date: Thu, 19 Jan 2023 16:18:59 -0600 Subject: [PATCH] alternative aggregate builder experiment --- crates/aggregate_builder/src/aggregate.rs | 379 ++++++++++++++++++++++ crates/aggregate_builder/src/combine.rs | 172 ++++++++++ crates/aggregate_builder/src/lib.rs | 46 ++- docs/safe_aggregates.md | 111 +++++++ extension/src/aggregate_utils.rs | 70 ++++ extension/src/ohlc.rs | 180 ++++------ extension/src/raw.rs | 2 + 7 files changed, 842 insertions(+), 118 deletions(-) create mode 100644 crates/aggregate_builder/src/aggregate.rs create mode 100644 crates/aggregate_builder/src/combine.rs create mode 100644 docs/safe_aggregates.md diff --git a/crates/aggregate_builder/src/aggregate.rs b/crates/aggregate_builder/src/aggregate.rs new file mode 100644 index 00000000..adb3d196 --- /dev/null +++ b/crates/aggregate_builder/src/aggregate.rs @@ -0,0 +1,379 @@ +use proc_macro::TokenStream; +use quote::quote; +use syn::parse::Parser as _; +use syn::{ + parse::{Parse, ParseStream}, + punctuated::Punctuated, + Token, +}; + +// TODO move to crate rather than duplicating +macro_rules! error { + ($span: expr, $fmt: literal, $($arg:expr),* $(,)?) => { + return Err(syn::Error::new($span, format!($fmt, $($arg),*))) + }; + ($span: expr, $msg: literal) => { + return Err(syn::Error::new($span, $msg)) + }; +} + +/// Parsed representation of the source function we generate from. +#[derive(Debug)] +pub struct SourceFunction { + ident: syn::Ident, + state_parameter: crate::AggregateArg, + extra_parameters: Vec, + return_type: syn::ReturnType, + body: syn::Block, +} +impl Parse for SourceFunction { + fn parse(input: ParseStream) -> syn::Result { + let crate::AggregateFn { + ident, + parens, + args, + ret: return_type, + body, + .. + } = input.parse()?; + let mut iter = args.iter(); + let state_parameter = iter + .next() + .ok_or_else(|| syn::Error::new(parens.span, "state parameter required"))? 
+ .clone(); + let extra_parameters = iter.map(|p| p.clone()).collect(); + Ok(Self { + ident, + state_parameter, + extra_parameters, + return_type, + body, + }) + } +} + +#[derive(Debug)] +pub struct Attributes { + name: syn::Ident, + schema: Option, + immutable: bool, + parallel: Parallel, + strict: bool, + + finalfunc: Option, + combinefunc: Option, + serialfunc: Option, + deserialfunc: Option, +} + +impl Attributes { + pub fn parse(input: TokenStream) -> syn::Result { + let mut aggregate_name = None; + let mut schema = None; + let mut immutable = false; + let mut parallel = Parallel::default(); + let mut strict = false; + let mut finalfunc = None; + let mut combinefunc = None; + let mut serialfunc = None; + let mut deserialfunc = None; + + let parser = Punctuated::::parse_terminated; + for attr in parser.parse2(input.into())?.iter_mut() { + assert!( + !attr.value.is_empty(), + "Attr::Parse should not allow empty attribute value" + ); + let name = attr.name.to_string(); + match name.as_str() { + "name" | "schema" | "immutable" | "parallel" | "strict" => { + if attr.value.len() > 1 { + error!(attr.name.span(), "{} requires simple identifier", name); + } + let value = attr.value.pop().ok_or_else(|| { + syn::Error::new( + attr.name.span(), + format!("{} requires simple identifier", name), + ) + })?; + match name.as_str() { + "name" => aggregate_name = Some(value), + "schema" => schema = Some(value), + "parallel" => { + parallel = match value.to_string().as_str() { + "restricted" => Parallel::Restricted, + "safe" => Parallel::Safe, + "unsafe" => Parallel::Unsafe, + _ => error!(value.span(), "illegal parallel"), + } + } + "immutable" | "strict" => { + let value = match value.to_string().as_str() { + "true" => true, + "false" => false, + _ => { + error!(attr.value[0].span(), "{} requires true or false", name) + } + }; + match name.as_str() { + "immutable" => immutable = value, + "strict" => strict = value, + _ => unreachable!("processing subset here"), + } + } + _ => 
unreachable!("processing subset here"), + } + } + + "finalfunc" | "combinefunc" | "serialfunc" | "deserialfunc" => { + if attr.value.len() > 2 { + error!( + attr.name.span(), + "{} requires one or two path segments only (`foo` or `foo::bar`)", name + ); + } + let func = { + let name = attr.value.pop().ok_or_else(||syn::Error::new( + attr.name.span(), + format!("{} requires one or two path segments only (`foo` or `foo::bar`)", name) + ))?; + match attr.value.pop() { + None => Func { name, schema: None }, + schema => Func { name, schema }, + } + }; + match name.as_str() { + "finalfunc" => finalfunc = Some(func), + "combinefunc" => combinefunc = Some(func), + "serialfunc" => serialfunc = Some(func), + "deserialfunc" => deserialfunc = Some(func), + _ => unreachable!("processing subset here"), + } + } + _ => error!(attr.name.span(), "unexpected"), + }; + } + let name = aggregate_name + .ok_or_else(|| syn::Error::new(proc_macro2::Span::call_site(), "name required"))?; + Ok(Self { + name, + schema, + immutable, + parallel, + strict, + finalfunc, + combinefunc, + serialfunc, + deserialfunc, + }) + } +} + +#[derive(Debug)] +pub struct Generator { + attributes: Attributes, + schema: Option, + function: SourceFunction, +} + +impl Generator { + pub(crate) fn new(attributes: Attributes, function: SourceFunction) -> syn::Result { + // TODO Default None but `schema=` attribute overrides; or just don't + // support `schema=` and instead require using pg_extern's treating + // enclosing mod as schema. Why have more than one way to do things? 
+ let schema = match &attributes.schema { + Some(schema) => Some(schema.clone()), + None => None, + }; + Ok(Self { + attributes, + schema, + function, + }) + } + + pub fn generate(self) -> proc_macro2::TokenStream { + let Self { + attributes, + schema, + function, + } = self; + + let name = attributes.name.to_string(); + + let transition_fn_name = function.ident; + + // TODO It's redundant to require us to mark every type with its sql + // type. We should do that just once and derive it here. + let mut sql_args = vec![]; + let state_signature = function.state_parameter.rust; + let mut all_arg_signatures = vec![&state_signature]; + let mut extra_arg_signatures = vec![]; + for arg in function.extra_parameters.iter() { + let super::AggregateArg { rust, sql } = arg; + sql_args.push({ + let name = match rust.pat.as_ref() { + syn::Pat::Ident(syn::PatIdent { ident, .. }) => ident, + _ => unreachable!("parsing made this name available"), + }; + format!( + "{} {}", + name, + match sql { + None => unreachable!("parsing made this sql type available"), + Some(sql) => sql.value(), + } + ) + }); + extra_arg_signatures.push(rust); + all_arg_signatures.push(rust); + } + + let ret = function.return_type; + let body = function.body; + + let (sql_schema, pg_extern_schema) = match schema.as_ref() { + None => (String::new(), None), + Some(schema) => { + let schema = schema.to_string(); + (format!("{schema}."), Some(quote!(, schema = #schema))) + } + }; + + let impl_fn_name = syn::Ident::new( + &format!("{}__impl", transition_fn_name), + proc_macro2::Span::call_site(), + ); + + let mut create = format!( + r#"CREATE AGGREGATE {}{}( + {}) +( + stype = internal, + sfunc = {}{}, +"#, + sql_schema, + name, + sql_args.join(",\n "), + sql_schema, + transition_fn_name, + ); + let final_fn_name = attributes + .finalfunc + .map(|func| fmt_agg_func(&mut create, "final", &func)); + let combine_fn_name = attributes + .combinefunc + .map(|func| fmt_agg_func(&mut create, "combine", &func)); + let 
serial_fn_name = attributes + .serialfunc + .map(|func| fmt_agg_func(&mut create, "serial", &func)); + let deserial_fn_name = attributes + .deserialfunc + .map(|func| fmt_agg_func(&mut create, "deserial", &func)); + let create = format!( + r#"{} + immutable = {}, + parallel = {}, + strict = {});"#, + create, attributes.immutable, attributes.parallel, attributes.strict + ); + + let extension_sql_name = format!("{}_extension_sql", name); + + let name = format!("{}", transition_fn_name); + let name = quote! { name = #name }; + + quote! { + // TODO type checks + + fn #transition_fn_name( + #(#all_arg_signatures,)* + ) #ret { + #body + } + + // TODO derive immutable and parallel_safe from above + #[pgx::pg_extern(#name, immutable, parallel_safe #pg_extern_schema)] + fn #impl_fn_name( + state: crate::palloc::Internal, + #(#extra_arg_signatures,)* + fcinfo: pgx::pg_sys::FunctionCallInfo, + ) -> Option { + // TODO Extract extra_arg_NAMES so we can call directly into transition_fn above rather than duplicate. + let f = |#state_signature| #body; + unsafe { crate::aggregate_utils::transition(state, fcinfo, f) } + } + + pgx::extension_sql!( + #create, + name=#extension_sql_name, + requires = [ + #impl_fn_name, + #final_fn_name + #combine_fn_name + #serial_fn_name + #deserial_fn_name + ], + ); + } + } +} + +fn fmt_agg_func(create: &mut String, funcprefix: &str, func: &Func) -> proc_macro2::TokenStream { + create.push_str(&format!(" {}func = ", funcprefix)); + if let Some(schema) = func.schema.as_ref() { + create.push_str(&format!("{}.", schema)); + } + create.push_str(&format!("{},\n", func.name)); + let name = &func.name; + quote! 
{ #name, } +} + +#[derive(Debug)] +enum Parallel { + Unsafe, + Restricted, + Safe, +} +impl Default for Parallel { + fn default() -> Self { + Self::Unsafe + } +} +impl std::fmt::Display for Parallel { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str(match self { + Self::Unsafe => "unsafe", + Self::Restricted => "restricted", + Self::Safe => "safe", + }) + } +} + +#[derive(Debug)] +struct Attr { + name: syn::Ident, + value: Vec, +} +impl Parse for Attr { + fn parse(input: ParseStream) -> syn::Result { + let name = input.parse()?; + let _: Token![=] = input.parse()?; + let path: syn::Path = input.parse()?; + let value; + match path.segments.iter().collect::>().as_slice() { + [syn::PathSegment { ident, .. }] => value = vec![ident.clone()], + [schema, ident] => { + value = vec![schema.ident.clone(), ident.ident.clone()]; + } + what => todo!("hmm got {:?}", what), + } + Ok(Self { name, value }) + } +} + +#[derive(Debug)] +struct Func { + name: syn::Ident, + schema: Option, +} diff --git a/crates/aggregate_builder/src/combine.rs b/crates/aggregate_builder/src/combine.rs new file mode 100644 index 00000000..a8ab576c --- /dev/null +++ b/crates/aggregate_builder/src/combine.rs @@ -0,0 +1,172 @@ +use quote::quote; +use syn::parse::{Parse, ParseStream}; +use syn::spanned::Spanned as _; + +// TODO move to crate rather than duplicating +macro_rules! error { + ($span: expr, $fmt: literal, $($arg:expr),* $(,)?) => { + return Err(syn::Error::new($span, format!($fmt, $($arg),*))) + }; + ($span: expr, $msg: literal) => { + return Err(syn::Error::new($span, $msg)) + }; +} + +/// Parsed representation of the source function we generate from. 
+#[derive(Debug)] +pub struct SourceFunction { + ident: syn::Ident, + parameters: Vec, + return_type: syn::ReturnType, + body: syn::Block, +} +impl Parse for SourceFunction { + fn parse(input: ParseStream) -> syn::Result { + let crate::AggregateFn { + ident, + parens, + args, + ret: return_type, + body, + .. + } = input.parse()?; + + if args.len() != 2 { + error!( + parens.span, + "combine function must take exactly two parameters of type `Option<&T>`" + ) + } + let state_type = get_state_type(&args[0])?; + let state_type2 = get_state_type(&args[1])?; + if state_type2 != state_type { + error!( + args[1].rust.span(), + "mismatched state types {} vs. {}", state_type, state_type2 + ) + } + + let parameters = args.iter().map(|p| p.clone()).collect(); + + Ok(Self { + ident, + parameters, + return_type, + body, + }) + } +} + +pub struct Generator { + schema: Option, + function: SourceFunction, +} + +impl Generator { + pub(crate) fn new( + _attributes: syn::AttributeArgs, + function: SourceFunction, + ) -> syn::Result { + // TODO Default None but `schema=` attribute overrides; or just don't + // support `schema=` and instead require using pg_extern's treating + // enclosing mod as schema. Why have more than one way to do things? 
+ let schema = Some(syn::Ident::new( + "toolkit_experimental", + function.ident.span(), + )); + + Ok(Self { schema, function }) + } + + pub fn generate(self) -> proc_macro2::TokenStream { + let Self { schema, function } = self; + + let fn_name = function.ident; + + let impl_fn_name = syn::Ident::new( + &format!("{}__impl", fn_name), + proc_macro2::Span::call_site(), + ); + + let inner_arg_signatures = function.parameters.iter().map(|arg| &arg.rust); + + let ret = function.return_type; + let body = function.body; + + // TODO default to this but `name=` attribute overrides + let name = format!("{}", fn_name); + let name = quote!(, name = #name); + + let schema = schema.as_ref().map(|s| { + let s = format!("{}", s); + quote!(, schema = #s) + }); + + quote! { + fn #impl_fn_name( + #(#inner_arg_signatures,)* + ) #ret { + #body + } + + #[::pgx::pg_extern(immutable, parallel_safe #name #schema)] + fn #fn_name( + state1: crate::palloc::Internal, + state2: crate::palloc::Internal, + fcinfo: pgx::pg_sys::FunctionCallInfo + ) -> Option { + unsafe { + crate::aggregate_utils::combine( + state1, + state2, + fcinfo, + #impl_fn_name, + ) + } + } + } + } +} + +fn get_state_type(arg: &crate::AggregateArg) -> syn::Result<&syn::Ident> { + match arg.rust.ty.as_ref() { + syn::Type::Path(path) => { + // TODO want `match path.path.segments.as_slice() { [segment] => ...` but they don't have as_slice :( + match path.path.segments.iter().collect::>().as_slice() { + [segment] => { + // TODO This erroneously accepts local types also named Option. + if segment.ident.to_string() == "Option" { + match &segment.arguments { + syn::PathArguments::AngleBracketed(arguments) => { + match arguments.args.iter().collect::>().as_slice() { + [syn::GenericArgument::Type(syn::Type::Reference( + syn::TypeReference { elem, .. 
}, + ))] => match elem.as_ref() { + syn::Type::Path(path) => { + match path + .path + .segments + .iter() + .collect::>() + .as_slice() + { + [segment] => return Ok(&segment.ident), + _ => {} + } + } + _ => {} + }, + _ => {} + } + } + _ => {} + } + } + } + _ => {} + } + } + _ => {} + } + error!(arg.rust.span(), "parameters must be Option<&T>") +} diff --git a/crates/aggregate_builder/src/lib.rs b/crates/aggregate_builder/src/lib.rs index 4703dc16..e3e19085 100644 --- a/crates/aggregate_builder/src/lib.rs +++ b/crates/aggregate_builder/src/lib.rs @@ -15,6 +15,48 @@ use syn::{ Token, }; +mod aggregate; +mod combine; + +#[proc_macro_attribute] +pub fn aggregate2(attr: TokenStream, item: TokenStream) -> TokenStream { + let attributes; + match aggregate::Attributes::parse(attr) { + Err(e) => return TokenStream::from(e.to_compile_error()), + Ok(value) => attributes = value, + } + let generator; + match aggregate::Generator::new( + attributes, + parse_macro_input!(item as aggregate::SourceFunction), + ) { + Err(e) => return TokenStream::from(e.to_compile_error()), + Ok(value) => generator = value, + } + let generated = generator.generate(); + if cfg!(feature = "print-generated") { + println!("{}", generated); + } + generated.into() +} + +#[proc_macro_attribute] +pub fn combine(attr: TokenStream, item: TokenStream) -> TokenStream { + let generator; + match combine::Generator::new( + parse_macro_input!(attr as syn::AttributeArgs), + parse_macro_input!(item as combine::SourceFunction), + ) { + Err(e) => return TokenStream::from(e.to_compile_error()), + Ok(value) => generator = value, + } + let generated = generator.generate(); + if cfg!(feature = "print-generated") { + println!("{}", generated); + } + generated.into() +} + #[proc_macro_attribute] pub fn aggregate(_attr: TokenStream, item: TokenStream) -> TokenStream { // Parse the input tokens into a syntax tree @@ -32,7 +74,7 @@ pub fn aggregate(_attr: TokenStream, item: TokenStream) -> TokenStream { // like ItemImpl except 
that we allow `name: Type "SqlType"` for `fn transition`
 struct Aggregate {
-    schema: Option<syn::Ident>,
+    schema: std::option::Option<syn::Ident>,
 
     name: syn::Ident,
     state_ty: AggregateTy,
@@ -72,7 +114,7 @@ struct AggregateFn {
     fcinfo: Option<syn::Ident>,
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 struct AggregateArg {
     rust: syn::PatType,
     sql: Option<syn::LitStr>,
diff --git a/docs/safe_aggregates.md b/docs/safe_aggregates.md
new file mode 100644
index 00000000..35f08678
--- /dev/null
+++ b/docs/safe_aggregates.md
@@ -0,0 +1,111 @@
+# Building aggregates safely
+
+## Goals
+
+1. Memory Safety: no memory corruption, which can lead to corrupted results or worse.
+2. Correctness: mostly down to business logic, but to the extent a framework can help or hinder, at least do not hinder.
+3. Robustness: crashes are not as bad as incorrect results, but still undesirable.
+4. Performance: correct results returned quickly and without excessive resource consumption.
+5. Developer productivity
+
+We chose Rust because it gives us powerful tools to meet all 5 goals.
+Unfortunately, we are not yet taking advantage of those tools.
+
+Most of our aggregates naively call into unsafe code without any checks that
+their invariants aren't invalidated.
+
+## Next steps
+
+I estimate two or three more weeks of effort to finish off the macros, plus
+two hours or so of toil to convert each aggregate.
+
+I suggest aggressively attacking experimental aggregates, and then converting
+just one stabilized aggregate in a release before proceeding further,
+out of an abundance of caution.
+
+0. While at least some of us (me!) weren't looking, pgx added a new trait that
+   may address some or all of our goals. Evaluate that first.
+1. Adapt at least a few more of our existing experimental aggregates to use
+   the two new macros, in case that should turn up any show-stoppers.
+2. Build out the rest of the macros:
+   - finalfunc
+   - serializefunc
+   - deserializefunc
+3. 
Finish building `aggregate` and `combine` macros:
+   - support name override (vs. default of rust fn name)
+   - support schema (haven't tested the non-schema case and hard-coded `toolkit_experimental` in one place)
+   - immutable and parallel_safe (currently parsed but ignored)
+   - copy the missing features from aggregate_builder (type assertions, test counters)
+   - eliminate #body duplicate in aggregate.rs (see TODO)
+   - fix bug about accepting any type named `Option` (require `std::option::Option`)
+   - address clippy's complaints and other code cleanup
+4. Nice to haves:
+   - get rid of `#[sql_type]`
+   - unduplicate the error! macro
+   - tidy attribute-parsing error-handling (it's not wrong, just messy)
+
+## Examples
+
+### Illegal mutation
+
+The PostgreSQL manual includes this big warning:
+
+    Never modify the contents of a pass-by-reference input value. If you
+    do so you are likely to corrupt on-disk data
+
+Rust lets us express that in the type system such that code attempting to
+modify that input does not compile.
+
+Yet we pass those raw references into our business logic without such protection.
+
+The ohlc bug was the inevitable result.
+
+### Accidental unsafe
+
+The primitives we currently use to build our aggregates encourage including
+large blocks of code in unsafe blocks. They don't require it; it is possible
+to separate them. But that's going against the grain.
+
+In some cases we have business logic of high cyclomatic complexity and dozens
+of lines all inside unsafe blocks.
+
+Addressing this doesn't require building new primitives, but if we are, we
+need to get this part right, too.
+
+### Invalid cast
+
+[I THINK pgx is able to put enough type information into the `CREATE FUNCTION`
+and `CREATE AGGREGATE` such that postgresql can prevent this. I THINK.
+It still makes my hair stand on end, and many security disasters can be traced
+back to "it's probably fine"... and it wasn't.]
+
+ +This class of bug seems likely to be a developer productivity issue and less +likely to manifest in production, except our test coverage is not great and we +may ship a variant of an aggregate that isn't tested. + +In any case: we want a clear compiler error, not a mysterious crash (if we're +lucky) or mysteriously corrupt data (if we're unlucky) in testing. + +What happens here is we implement a function accepting `pgx::Internal` and +then cast it to our internal type, and then later we bundle it with `CREATE +AGGREGATE` without any assurance that the types match: + +```rust +fn foo_transition(state: pgx::Internal, value: Option, fcinfo: pg_sys::FunctionCallInfo) { + let state: Option> = unsafe { state.to_inner() }; + // ... +} + +fn bar_final(state: pgx::Internal, fcinfo: pg_sys::FunctionCallInfo) -> Option { + let state = Option> = unsafe { state.to_inner() }; + // ... +} +``` + +```sql +CREATE AGGREGATE foo() ( + sfunc = foo_transition, + stype = internal, + finalfunc = bar_final, +); +``` diff --git a/extension/src/aggregate_utils.rs b/extension/src/aggregate_utils.rs index 2eb51b24..0cda3506 100644 --- a/extension/src/aggregate_utils.rs +++ b/extension/src/aggregate_utils.rs @@ -2,6 +2,9 @@ use std::ptr::null_mut; use pgx::pg_sys; +use crate::palloc::InternalAsValue as _; +use crate::palloc::ToInternal as _; + // TODO move to func_utils once there are enough function to warrant one pub unsafe fn get_collation(fcinfo: pg_sys::FunctionCallInfo) -> Option { if (*fcinfo).fncollation == 0 { @@ -11,10 +14,75 @@ pub unsafe fn get_collation(fcinfo: pg_sys::FunctionCallInfo) -> Option, Option<&State>) -> Option>( + state1: pgx::Internal, + state2: pgx::Internal, + fcinfo: pg_sys::FunctionCallInfo, + f: F, +) -> Option { + unsafe_combine(state1, state2, fcinfo, f) +} + +fn unsafe_combine, Option<&State>) -> Option>( + state1: pgx::Internal, + state2: pgx::Internal, + fcinfo: pg_sys::FunctionCallInfo, + f: F, +) -> Option { + let state1 = unsafe { state1.to_inner() 
}; + let state2 = unsafe { state2.to_inner() }; + let state1 = match &state1 { + None => None, + Some(inner) => Some(&**inner), + }; + let state2 = match &state2 { + None => None, + Some(inner) => Some(&**inner), + }; + let f = || f(state1, state2); + unsafe { in_aggregate_context(fcinfo, f) } + .map(|internal| internal.into()) + .internal() +} + +pub unsafe fn transition) -> Option>( + state: pgx::Internal, + fcinfo: pg_sys::FunctionCallInfo, + f: F, +) -> Option { + unsafe_transition(state, fcinfo, f) +} + +fn unsafe_transition) -> Option>( + state: pgx::Internal, + fcinfo: pg_sys::FunctionCallInfo, + f: F, +) -> Option { + let mut inner = unsafe { state.to_inner() }; + let state: Option = match &mut inner { + None => None, + Some(inner) => Option::take(&mut **inner), + }; + let f = || { + let result: Option = f(state); + inner = match (inner, result) { + (None, None) => None, + (None, result @ Some(..)) => Some(result.into()), + (Some(mut inner), result) => { + *inner = result; + Some(inner) + } + }; + inner.internal() + }; + unsafe { in_aggregate_context(fcinfo, f) } +} + pub unsafe fn in_aggregate_context T>( fcinfo: pg_sys::FunctionCallInfo, f: F, ) -> T { + // TODO Is this unsafe for any reason other than "all FFI is unsafe"? let mctx = aggregate_mctx(fcinfo).unwrap_or_else(|| pgx::error!("cannot call as non-aggregate")); crate::palloc::in_memory_context(mctx, f) @@ -22,9 +90,11 @@ pub unsafe fn in_aggregate_context T>( pub unsafe fn aggregate_mctx(fcinfo: pg_sys::FunctionCallInfo) -> Option { if fcinfo.is_null() { + // TODO Is this unsafe for any reason other than "all FFI is unsafe"? return Some(pg_sys::CurrentMemoryContext); } let mut mctx = null_mut(); + // TODO Is this unsafe for any reason other than "all FFI is unsafe"? 
let is_aggregate = pg_sys::AggCheckCallContext(fcinfo, &mut mctx); if is_aggregate == 0 { None diff --git a/extension/src/ohlc.rs b/extension/src/ohlc.rs index 02edf34a..2dee93a4 100644 --- a/extension/src/ohlc.rs +++ b/extension/src/ohlc.rs @@ -1,3 +1,4 @@ +#![allow(non_snake_case)] use pgx::*; use serde::{Deserialize, Serialize}; @@ -215,51 +216,54 @@ pub fn candlestick( } } -#[pg_extern(immutable, parallel_safe, schema = "toolkit_experimental")] -pub fn tick_data_no_vol_transition( - state: Internal, - ts: Option, - price: Option, - fcinfo: pg_sys::FunctionCallInfo, -) -> Option { - tick_data_transition_inner(unsafe { state.to_inner() }, ts, price, None, fcinfo).internal() -} - -#[pg_extern(immutable, parallel_safe, schema = "toolkit_experimental")] -pub fn tick_data_transition( - state: Internal, - ts: Option, - price: Option, - volume: Option, - fcinfo: pg_sys::FunctionCallInfo, -) -> Option { - tick_data_transition_inner(unsafe { state.to_inner() }, ts, price, volume, fcinfo).internal() +#[aggregate_builder::aggregate2( + name = ohlc, + // TODO Can we derive namespace from mod? + finalfunc = toolkit_experimental::candlestick_final, + combinefunc = toolkit_experimental::candlestick_combine, + serialfunc = toolkit_experimental::candlestick_serialize, + deserialfunc = toolkit_experimental::candlestick_deserialize, + parallel = safe, + schema = toolkit_experimental, +)] +fn tick_data_no_vol_transition( + // TODO Teach AggregateFn parser to handle lifetime generics and change this + // back to a scoped lifetime ('input), not static. 
+ state: Option>, + #[sql_type("timestamptz")] ts: Option, + #[sql_type("double precision")] price: Option, +) -> Option> { + tick_data_transition(state, ts, price, None) } -pub fn tick_data_transition_inner( - state: Option>, - ts: Option, - price: Option, - volume: Option, - fcinfo: pg_sys::FunctionCallInfo, -) -> Option> { - unsafe { - in_aggregate_context(fcinfo, || { - if let (Some(ts), Some(price)) = (ts, price) { - match state { - None => { - let cs = Candlestick::from_tick(ts.into(), price, volume); - Some(cs.into()) - } - Some(mut cs) => { - cs.add_tick_data(ts.into(), price, volume); - Some(cs) - } - } - } else { - state +#[aggregate_builder::aggregate2( + name = candlestick_agg, + schema = toolkit_experimental, + finalfunc = toolkit_experimental::candlestick_final, + combinefunc = toolkit_experimental::candlestick_combine, + serialfunc = toolkit_experimental::candlestick_serialize, + deserialfunc = toolkit_experimental::candlestick_deserialize, + parallel = safe, +)] +fn tick_data_transition( + state: Option>, + #[sql_type("timestamptz")] ts: Option, + #[sql_type("double precision")] price: Option, + #[sql_type("double precision")] volume: Option, +) -> Option> { + if let (Some(ts), Some(price)) = (ts, price) { + match state { + None => { + let cs = Candlestick::from_tick(ts.into(), price, volume); + Some(cs.into()) } - }) + Some(mut cs) => { + cs.add_tick_data(ts.into(), price, volume); + Some(cs) + } + } + } else { + state } } @@ -295,7 +299,7 @@ pub fn candlestick_final( state: Internal, fcinfo: pg_sys::FunctionCallInfo, ) -> Option> { - unsafe { candlestick_final_inner(state.to_inner(), fcinfo) } + candlestick_final_inner(unsafe { state.to_inner() }, fcinfo) } pub fn candlestick_final_inner( @@ -313,36 +317,27 @@ pub fn candlestick_final_inner( } } -#[pg_extern(immutable, parallel_safe, schema = "toolkit_experimental")] -pub fn candlestick_combine( - state1: Internal, - state2: Internal, - fcinfo: pg_sys::FunctionCallInfo, -) -> Option { - unsafe { 
candlestick_combine_inner(state1.to_inner(), state2.to_inner(), fcinfo).internal() } -} - -pub fn candlestick_combine_inner<'input>( - state1: Option>>, - state2: Option>>, - fcinfo: pg_sys::FunctionCallInfo, -) -> Option>> { - unsafe { - in_aggregate_context(fcinfo, || match (state1, state2) { - (None, None) => None, - (None, Some(only)) | (Some(only), None) => Some((*only).into()), - (Some(a), Some(b)) => { - let (mut a, b) = (*a, *b); - a.combine(&b); - Some(a.into()) - } - }) +#[aggregate_builder::combine(immutable, parallel_safe, schema = "toolkit_experimental")] +fn candlestick_combine( + // TODO Teach AggregateFn parser to handle lifetime generics and change this + // back to a scoped lifetime ('input), not static. + state1: Option<&Candlestick<'static>>, + state2: Option<&Candlestick<'static>>, +) -> Option> { + match (state1, state2) { + (None, None) => None, + (None, Some(only)) | (Some(only), None) => Some(*only), + (Some(a), Some(b)) => { + let (mut a, b) = (*a, *b); + a.combine(&b); + Some(a) + } } } #[pg_extern(immutable, parallel_safe, strict, schema = "toolkit_experimental")] pub fn candlestick_serialize(state: Internal) -> bytea { - let cs: &mut Candlestick = unsafe { state.get_mut().unwrap() }; + let cs: &Candlestick = unsafe { state.get() }.unwrap(); let ser = &**cs; crate::do_serialize!(ser) } @@ -359,54 +354,7 @@ pub fn candlestick_deserialize_inner(bytes: bytea) -> Inner cs.into() } -extension_sql!( - "\n\ - CREATE AGGREGATE toolkit_experimental.ohlc( ts timestamptz, price DOUBLE PRECISION )\n\ - (\n\ - sfunc = toolkit_experimental.tick_data_no_vol_transition,\n\ - stype = internal,\n\ - finalfunc = toolkit_experimental.candlestick_final,\n\ - combinefunc = toolkit_experimental.candlestick_combine,\n\ - serialfunc = toolkit_experimental.candlestick_serialize,\n\ - deserialfunc = toolkit_experimental.candlestick_deserialize,\n\ - parallel = safe\n\ - );\n", - name = "ohlc", - requires = [ - tick_data_no_vol_transition, - candlestick_final, - 
candlestick_combine, - candlestick_serialize, - candlestick_deserialize - ], -); - -extension_sql!( - "\n\ - CREATE AGGREGATE toolkit_experimental.candlestick_agg( \n\ - ts TIMESTAMPTZ,\n\ - price DOUBLE PRECISION,\n\ - volume DOUBLE PRECISION\n\ - )\n\ - (\n\ - sfunc = toolkit_experimental.tick_data_transition,\n\ - stype = internal,\n\ - finalfunc = toolkit_experimental.candlestick_final,\n\ - combinefunc = toolkit_experimental.candlestick_combine,\n\ - serialfunc = toolkit_experimental.candlestick_serialize,\n\ - deserialfunc = toolkit_experimental.candlestick_deserialize,\n\ - parallel = safe\n\ - );\n", - name = "candlestick_agg", - requires = [ - tick_data_transition, - candlestick_final, - candlestick_combine, - candlestick_serialize, - candlestick_deserialize - ], -); - +// TODO Automate generation of this too. extension_sql!( "\n\ CREATE AGGREGATE toolkit_experimental.rollup( candlestick toolkit_experimental.Candlestick)\n\ diff --git a/extension/src/raw.rs b/extension/src/raw.rs index 099b4b67..b990e165 100644 --- a/extension/src/raw.rs +++ b/extension/src/raw.rs @@ -24,6 +24,8 @@ extension_sql!( bootstrap, ); +// What's this about? I don't think we've been migrating anything. If this +// raw_type stuff is old n busted, let's bite the bullet now. // TODO temporary holdover types while we migrate from nominal types to actual macro_rules! raw_type {