Skip to content

Commit 8cc4a63

Browse files
committed
alternative aggregate builder experiment
1 parent 218c099 commit 8cc4a63

File tree

7 files changed

+842
-118
lines changed

7 files changed

+842
-118
lines changed
Lines changed: 379 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,379 @@
1+
use proc_macro::TokenStream;
2+
use quote::quote;
3+
use syn::parse::Parser as _;
4+
use syn::{
5+
parse::{Parse, ParseStream},
6+
punctuated::Punctuated,
7+
Token,
8+
};
9+
10+
// TODO move to crate rather than duplicating
11+
macro_rules! error {
12+
($span: expr, $fmt: literal, $($arg:expr),* $(,)?) => {
13+
return Err(syn::Error::new($span, format!($fmt, $($arg),*)))
14+
};
15+
($span: expr, $msg: literal) => {
16+
return Err(syn::Error::new($span, $msg))
17+
};
18+
}
19+
20+
/// Parsed representation of the source function we generate from.
21+
#[derive(Debug)]
22+
pub struct SourceFunction {
23+
ident: syn::Ident,
24+
state_parameter: crate::AggregateArg,
25+
extra_parameters: Vec<crate::AggregateArg>,
26+
return_type: syn::ReturnType,
27+
body: syn::Block,
28+
}
29+
impl Parse for SourceFunction {
30+
fn parse(input: ParseStream) -> syn::Result<Self> {
31+
let crate::AggregateFn {
32+
ident,
33+
parens,
34+
args,
35+
ret: return_type,
36+
body,
37+
..
38+
} = input.parse()?;
39+
let mut iter = args.iter();
40+
let state_parameter = iter
41+
.next()
42+
.ok_or_else(|| syn::Error::new(parens.span, "state parameter required"))?
43+
.clone();
44+
let extra_parameters = iter.map(|p| p.clone()).collect();
45+
Ok(Self {
46+
ident,
47+
state_parameter,
48+
extra_parameters,
49+
return_type,
50+
body,
51+
})
52+
}
53+
}
54+
55+
#[derive(Debug)]
56+
pub struct Attributes {
57+
name: syn::Ident,
58+
schema: Option<syn::Ident>,
59+
immutable: bool,
60+
parallel: Parallel,
61+
strict: bool,
62+
63+
finalfunc: Option<Func>,
64+
combinefunc: Option<Func>,
65+
serialfunc: Option<Func>,
66+
deserialfunc: Option<Func>,
67+
}
68+
69+
impl Attributes {
70+
pub fn parse(input: TokenStream) -> syn::Result<Self> {
71+
let mut aggregate_name = None;
72+
let mut schema = None;
73+
let mut immutable = false;
74+
let mut parallel = Parallel::default();
75+
let mut strict = false;
76+
let mut finalfunc = None;
77+
let mut combinefunc = None;
78+
let mut serialfunc = None;
79+
let mut deserialfunc = None;
80+
81+
let parser = Punctuated::<Attr, Token![,]>::parse_terminated;
82+
for attr in parser.parse2(input.into())?.iter_mut() {
83+
assert!(
84+
!attr.value.is_empty(),
85+
"Attr::Parse should not allow empty attribute value"
86+
);
87+
let name = attr.name.to_string();
88+
match name.as_str() {
89+
"name" | "schema" | "immutable" | "parallel" | "strict" => {
90+
if attr.value.len() > 1 {
91+
error!(attr.name.span(), "{} requires simple identifier", name);
92+
}
93+
let value = attr.value.pop().ok_or_else(|| {
94+
syn::Error::new(
95+
attr.name.span(),
96+
format!("{} requires simple identifier", name),
97+
)
98+
})?;
99+
match name.as_str() {
100+
"name" => aggregate_name = Some(value),
101+
"schema" => schema = Some(value),
102+
"parallel" => {
103+
parallel = match value.to_string().as_str() {
104+
"restricted" => Parallel::Restricted,
105+
"safe" => Parallel::Safe,
106+
"unsafe" => Parallel::Unsafe,
107+
_ => error!(value.span(), "illegal parallel"),
108+
}
109+
}
110+
"immutable" | "strict" => {
111+
let value = match value.to_string().as_str() {
112+
"true" => true,
113+
"false" => false,
114+
_ => {
115+
error!(attr.value[0].span(), "{} requires true or false", name)
116+
}
117+
};
118+
match name.as_str() {
119+
"immutable" => immutable = value,
120+
"strict" => strict = value,
121+
_ => unreachable!("processing subset here"),
122+
}
123+
}
124+
_ => unreachable!("processing subset here"),
125+
}
126+
}
127+
128+
"finalfunc" | "combinefunc" | "serialfunc" | "deserialfunc" => {
129+
if attr.value.len() > 2 {
130+
error!(
131+
attr.name.span(),
132+
"{} requires one or two path segments only (`foo` or `foo::bar`)", name
133+
);
134+
}
135+
let func = {
136+
let name = attr.value.pop().ok_or_else(||syn::Error::new(
137+
attr.name.span(),
138+
format!("{} requires one or two path segments only (`foo` or `foo::bar`)", name)
139+
))?;
140+
match attr.value.pop() {
141+
None => Func { name, schema: None },
142+
schema => Func { name, schema },
143+
}
144+
};
145+
match name.as_str() {
146+
"finalfunc" => finalfunc = Some(func),
147+
"combinefunc" => combinefunc = Some(func),
148+
"serialfunc" => serialfunc = Some(func),
149+
"deserialfunc" => deserialfunc = Some(func),
150+
_ => unreachable!("processing subset here"),
151+
}
152+
}
153+
_ => error!(attr.name.span(), "unexpected"),
154+
};
155+
}
156+
let name = aggregate_name
157+
.ok_or_else(|| syn::Error::new(proc_macro2::Span::call_site(), "name required"))?;
158+
Ok(Self {
159+
name,
160+
schema,
161+
immutable,
162+
parallel,
163+
strict,
164+
finalfunc,
165+
combinefunc,
166+
serialfunc,
167+
deserialfunc,
168+
})
169+
}
170+
}
171+
172+
#[derive(Debug)]
173+
pub struct Generator {
174+
attributes: Attributes,
175+
schema: Option<syn::Ident>,
176+
function: SourceFunction,
177+
}
178+
179+
impl Generator {
180+
pub(crate) fn new(attributes: Attributes, function: SourceFunction) -> syn::Result<Self> {
181+
// TODO Default None but `schema=` attribute overrides; or just don't
182+
// support `schema=` and instead require using pg_extern's treating
183+
// enclosing mod as schema. Why have more than one way to do things?
184+
let schema = match &attributes.schema {
185+
Some(schema) => Some(schema.clone()),
186+
None => None,
187+
};
188+
Ok(Self {
189+
attributes,
190+
schema,
191+
function,
192+
})
193+
}
194+
195+
pub fn generate(self) -> proc_macro2::TokenStream {
196+
let Self {
197+
attributes,
198+
schema,
199+
function,
200+
} = self;
201+
202+
let name = attributes.name.to_string();
203+
204+
let transition_fn_name = function.ident;
205+
206+
// TODO It's redundant to require us to mark every type with its sql
207+
// type. We should do that just once and derive it here.
208+
let mut sql_args = vec![];
209+
let state_signature = function.state_parameter.rust;
210+
let mut all_arg_signatures = vec![&state_signature];
211+
let mut extra_arg_signatures = vec![];
212+
for arg in function.extra_parameters.iter() {
213+
let super::AggregateArg { rust, sql } = arg;
214+
sql_args.push({
215+
let name = match rust.pat.as_ref() {
216+
syn::Pat::Ident(syn::PatIdent { ident, .. }) => ident,
217+
_ => unreachable!("parsing made this name available"),
218+
};
219+
format!(
220+
"{} {}",
221+
name,
222+
match sql {
223+
None => unreachable!("parsing made this sql type available"),
224+
Some(sql) => sql.value(),
225+
}
226+
)
227+
});
228+
extra_arg_signatures.push(rust);
229+
all_arg_signatures.push(rust);
230+
}
231+
232+
let ret = function.return_type;
233+
let body = function.body;
234+
235+
let (sql_schema, pg_extern_schema) = match schema.as_ref() {
236+
None => (String::new(), None),
237+
Some(schema) => {
238+
let schema = schema.to_string();
239+
(format!("{schema}."), Some(quote!(, schema = #schema)))
240+
}
241+
};
242+
243+
let impl_fn_name = syn::Ident::new(
244+
&format!("{}__impl", transition_fn_name),
245+
proc_macro2::Span::call_site(),
246+
);
247+
248+
let mut create = format!(
249+
r#"CREATE AGGREGATE {}{}(
250+
{})
251+
(
252+
stype = internal,
253+
sfunc = {}{},
254+
"#,
255+
sql_schema,
256+
name,
257+
sql_args.join(",\n "),
258+
sql_schema,
259+
transition_fn_name,
260+
);
261+
let final_fn_name = attributes
262+
.finalfunc
263+
.map(|func| fmt_agg_func(&mut create, "final", &func));
264+
let combine_fn_name = attributes
265+
.combinefunc
266+
.map(|func| fmt_agg_func(&mut create, "combine", &func));
267+
let serial_fn_name = attributes
268+
.serialfunc
269+
.map(|func| fmt_agg_func(&mut create, "serial", &func));
270+
let deserial_fn_name = attributes
271+
.deserialfunc
272+
.map(|func| fmt_agg_func(&mut create, "deserial", &func));
273+
let create = format!(
274+
r#"{}
275+
immutable = {},
276+
parallel = {},
277+
strict = {});"#,
278+
create, attributes.immutable, attributes.parallel, attributes.strict
279+
);
280+
281+
let extension_sql_name = format!("{}_extension_sql", name);
282+
283+
let name = format!("{}", transition_fn_name);
284+
let name = quote! { name = #name };
285+
286+
quote! {
287+
// TODO type checks
288+
289+
fn #transition_fn_name(
290+
#(#all_arg_signatures,)*
291+
) #ret {
292+
#body
293+
}
294+
295+
// TODO derive immutable and parallel_safe from above
296+
#[pgx::pg_extern(#name, immutable, parallel_safe #pg_extern_schema)]
297+
fn #impl_fn_name(
298+
state: crate::palloc::Internal,
299+
#(#extra_arg_signatures,)*
300+
fcinfo: pgx::pg_sys::FunctionCallInfo,
301+
) -> Option<crate::palloc::Internal> {
302+
// TODO Extract extra_arg_NAMES so we can call directly into transition_fn above rather than duplicate.
303+
let f = |#state_signature| #body;
304+
unsafe { crate::aggregate_utils::transition(state, fcinfo, f) }
305+
}
306+
307+
pgx::extension_sql!(
308+
#create,
309+
name=#extension_sql_name,
310+
requires = [
311+
#impl_fn_name,
312+
#final_fn_name
313+
#combine_fn_name
314+
#serial_fn_name
315+
#deserial_fn_name
316+
],
317+
);
318+
}
319+
}
320+
}
321+
322+
fn fmt_agg_func(create: &mut String, funcprefix: &str, func: &Func) -> proc_macro2::TokenStream {
323+
create.push_str(&format!(" {}func = ", funcprefix));
324+
if let Some(schema) = func.schema.as_ref() {
325+
create.push_str(&format!("{}.", schema));
326+
}
327+
create.push_str(&format!("{},\n", func.name));
328+
let name = &func.name;
329+
quote! { #name, }
330+
}
331+
332+
#[derive(Debug)]
333+
enum Parallel {
334+
Unsafe,
335+
Restricted,
336+
Safe,
337+
}
338+
impl Default for Parallel {
339+
fn default() -> Self {
340+
Self::Unsafe
341+
}
342+
}
343+
impl std::fmt::Display for Parallel {
344+
fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
345+
formatter.write_str(match self {
346+
Self::Unsafe => "unsafe",
347+
Self::Restricted => "restricted",
348+
Self::Safe => "safe",
349+
})
350+
}
351+
}
352+
353+
#[derive(Debug)]
354+
struct Attr {
355+
name: syn::Ident,
356+
value: Vec<syn::Ident>,
357+
}
358+
impl Parse for Attr {
359+
fn parse(input: ParseStream) -> syn::Result<Self> {
360+
let name = input.parse()?;
361+
let _: Token![=] = input.parse()?;
362+
let path: syn::Path = input.parse()?;
363+
let value;
364+
match path.segments.iter().collect::<Vec<_>>().as_slice() {
365+
[syn::PathSegment { ident, .. }] => value = vec![ident.clone()],
366+
[schema, ident] => {
367+
value = vec![schema.ident.clone(), ident.ident.clone()];
368+
}
369+
what => todo!("hmm got {:?}", what),
370+
}
371+
Ok(Self { name, value })
372+
}
373+
}
374+
375+
#[derive(Debug)]
376+
struct Func {
377+
name: syn::Ident,
378+
schema: Option<syn::Ident>,
379+
}

0 commit comments

Comments
 (0)