From 4f92552d47af7950559aee9e36a1f5d9689a02a2 Mon Sep 17 00:00:00 2001 From: BigFish2086 Date: Thu, 3 Jul 2025 20:23:52 +0300 Subject: [PATCH 1/5] feat: adopt macros to apply to struct-impl block as well as traits #43 --- macros/src/ast.rs | 436 ++++++++++++++++++++++++++++++++++++----- macros/src/gen.rs | 489 +++++++++++++++++++++++++++++++++++++++++----- macros/src/lib.rs | 41 +++- 3 files changed, 867 insertions(+), 99 deletions(-) diff --git a/macros/src/ast.rs b/macros/src/ast.rs index 7a9aadf..feefe70 100644 --- a/macros/src/ast.rs +++ b/macros/src/ast.rs @@ -14,10 +14,10 @@ use syn::ext::IdentExt; use syn::parse::{Parse, ParseStream}; use syn::spanned::Spanned; -use syn::token::Comma; use syn::{ - braced, parenthesized, parse_quote, Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, - Ident, Lit, Pat, PatType, Path, PathArguments, Result, ReturnType, Token, Type, Visibility, + parse_quote, Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, Ident, ImplItem, + ImplItemFn, Item, ItemImpl, ItemTrait, Lit, Meta, Pat, PatType, Path, PathArguments, Result, + ReturnType, TraitItem, TraitItemFn, Type, Visibility, }; /// Accumulates multiple errors into a result. @@ -63,7 +63,75 @@ impl Parse for Workflow { } } -pub(crate) struct ServiceInner { +pub(crate) struct ValidArgs { + pub(crate) vis: Visibility, + pub(crate) restate_name: Option, +} + +impl Parse for ValidArgs { + fn parse(input: ParseStream) -> Result { + let mut vis = None; + let mut restate_name = None; + + let punctuated = + syn::punctuated::Punctuated::::parse_terminated(input)?; + + for meta in punctuated { + match meta { + Meta::NameValue(name_value) if name_value.path.is_ident("vis") => { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = &name_value.value + { + let vis_str = lit_str.value(); + vis = Some(syn::parse_str::(&vis_str).map_err(|e| { + Error::new( + name_value.value.span(), + format!( + "Invalid visibility modifier '{}'. Expected \"pub\", \"pub(crate)\", etc.: {}", + vis_str, e + ), + ) + })?); + } else { + return Err(Error::new( + name_value.value.span(), + "Expected a string literal for 'vis' (e.g., vis = \"pub\", vis = \"pub(crate)\")", + )); + } + } + Meta::NameValue(name_value) if name_value.path.is_ident("name") => { + if let Expr::Lit(ExprLit { + lit: Lit::Str(lit_str), + .. + }) = &name_value.value + { + restate_name = Some(lit_str.value()); + } else { + return Err(Error::new( + name_value.span(), + "Expected a string literal for 'name'", + )); + } + } + bad_meta => { + return Err(Error::new( + bad_meta.span(), + "Invalid attribute format. Expected #[service(vis = pub(crate), name = \"...\")]", + )); + } + } + } + + Ok(Self { + vis: vis.unwrap_or(Visibility::Inherited), + restate_name, + }) + } +} + +pub(crate) struct TraitBlockServiceInner { pub(crate) attrs: Vec, pub(crate) restate_name: String, pub(crate) vis: Visibility, @@ -71,26 +139,23 @@ pub(crate) struct ServiceInner { pub(crate) handlers: Vec, } -impl ServiceInner { - fn parse(service_type: ServiceType, input: ParseStream) -> Result { - let parsed_attrs = input.call(Attribute::parse_outer)?; - let vis = input.parse()?; - input.parse::()?; - let ident: Ident = input.parse()?; - let content; - braced!(content in input); +impl TraitBlockServiceInner { + fn parse(service_type: ServiceType, input: ItemTrait) -> Result { + let parsed_attrs = input.attrs; + let vis = input.vis; + let ident: Ident = input.ident; let mut rpcs = Vec::::new(); - while !content.is_empty() { - let h: Handler = content.parse()?; - - if h.is_shared && service_type == ServiceType::Service { - return Err(Error::new( - h.ident.span(), - "Service handlers cannot be annotated with #[shared]", - )); + for item in input.items { + if let TraitItem::Fn(handler) = item { + let handler: Handler = Handler::parse(handler)?; + if handler.is_shared && service_type == ServiceType::Service { + return Err(Error::new( + handler.ident.span(), + "Service handlers cannot be annotated with #[shared]", + )); + } + rpcs.push(handler); } - - rpcs.push(h); } let mut ident_errors = Ok(()); for rpc in &rpcs { @@ -139,6 +204,137 @@ impl ServiceInner { } } +pub(crate) struct ImplBlockServiceInner { + pub(crate) attrs: Vec, + pub(crate) restate_name: String, + pub(crate) vis: Visibility, + pub(crate) ident: Ident, + pub(crate) handlers: Vec, + pub(crate) impl_block: ItemImpl, +} + +impl ImplBlockServiceInner { + fn parse(service_type: ServiceType, mut input: ItemImpl) -> Result { + let ident = match input.self_ty.as_ref() { + Type::Path(path) => path.path.segments[0].ident.clone(), + bad_path => { + return Err(Error::new(bad_path.span(), "Only on impl blocks")); + } + }; + + let mut rpcs = Vec::new(); + for item in input.items.iter_mut() { + match item { + ImplItem::Const(_) => {} + ImplItem::Fn(handler) => { + let mut is_handler = false; + let mut is_shared = false; + let mut restate_name = None; + + let mut attrs = Vec::with_capacity(handler.attrs.len()); + for attr in &handler.attrs { + if attr.path().is_ident("handler") { + if is_handler { + return Err(Error::new( + attr.span(), + "Multiple `#[handler]` attributes found.", + )); + } + if handler.sig.asyncness.is_none() { + return Err(Error::new( + handler.sig.fn_token.span(), + "expected async, handlers are async fn", + )); + } + is_handler = true; + (is_shared, restate_name) = + extract_handler_attributes(service_type, attr)?; + } else { + attrs.push(attr.clone()); + } + } + + if is_handler { + let handler_arg = + validate_handler_arguments(service_type, is_shared, handler)?; + + let return_type: ReturnType = handler.sig.output.clone(); + let (output_ok, output_err) = match &return_type { + ReturnType::Default => { + return Err(Error::new( + return_type.span(), + "The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type", + )); + } + ReturnType::Type(_, ty) => { + if let Some((ok_ty, err_ty)) = extract_handler_result_parameter(ty) + { + (ok_ty, err_ty) + } else { + return Err(Error::new( + return_type.span(), + "Only Result or restate_sdk::prelude::HandlerResult is supported as return type", + )); + } + } + }; + + handler.attrs = attrs.clone(); + + rpcs.push(Handler { + attrs, + is_shared, + ident: handler.sig.ident.clone(), + restate_name: restate_name.unwrap_or(handler.sig.ident.to_string()), + arg: handler_arg, + output_ok, + output_err, + }); + } + } + bad_impl_item => { + return Err(Error::new(bad_impl_item.span(), "Only on consts and fns")); + } + } + } + + Ok(Self { + attrs: input.attrs.clone(), + restate_name: "".to_string(), + ident, + vis: Visibility::Inherited, + handlers: rpcs, + impl_block: input, + }) + } +} + +pub(crate) enum ServiceInner { + Trait(TraitBlockServiceInner), + Impl(ImplBlockServiceInner), +} + +impl ServiceInner { + fn parse(service_type: ServiceType, input: ParseStream) -> Result { + let item = input.parse()?; + + match item { + Item::Trait(trait_block) => Ok(Self::Trait(TraitBlockServiceInner::parse( + service_type, + trait_block, + )?)), + Item::Impl(impl_block) => Ok(Self::Impl(ImplBlockServiceInner::parse( + service_type, + impl_block, + )?)), + other => Err(syn::Error::new_spanned( + other, + "expected `impl` or `struct`", + )), + } + } +} + pub(crate) struct Handler { pub(crate) attrs: Vec, pub(crate) is_shared: bool, @@ -149,20 +345,17 @@ pub(crate) struct Handler { pub(crate) output_err: Type, } -impl Parse for Handler { - fn parse(input: ParseStream) -> Result { - let parsed_attrs = input.call(Attribute::parse_outer)?; - - input.parse::()?; - input.parse::()?; - let ident: Ident = input.parse()?; +impl Handler { + fn parse(input: TraitItemFn) -> Result { + let parsed_attrs = input.attrs; + let ident: Ident = input.sig.ident; + if input.sig.asyncness.is_none() { + return Err(Error::new(ident.span(), "Handlers must be `async`")); + } - // Parse arguments - let content; - parenthesized!(content in input); let mut args = Vec::new(); let mut errors = Ok(()); - for arg in content.parse_terminated(FnArg::parse, Comma)? { + for arg in &input.sig.inputs { match arg { FnArg::Typed(captured) if matches!(&*captured.pat, Pat::Ident(_)) => { args.push(captured); @@ -180,21 +373,29 @@ impl Parse for Handler { ); } } - } - if args.len() > 1 { - extend_errors!( - errors, - Error::new(content.span(), "Only one input argument is supported") - ); + if args.len() > 1 { + extend_errors!( + errors, + Error::new( + input.sig.inputs.span(), + "Only one input argument is supported" + ) // TODO: is this a correct span + ); + break; + } } errors?; - // Parse return type - let return_type: ReturnType = input.parse()?; - input.parse::()?; + let return_type: ReturnType = input.sig.output; + if input.default.is_some() { + return Err(Error::new( + ident.span(), + "Default trait method impl isn't supported", + )); + } let (ok_ty, err_ty) = match &return_type { - ReturnType::Default => return Err(Error::new( + ReturnType::Default => return Err(Error::new( return_type.span(), "The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type", )), @@ -230,7 +431,7 @@ impl Parse for Handler { is_shared, restate_name, ident, - arg: args.pop(), + arg: args.pop().cloned(), output_ok: ok_ty, output_err: err_ty, }) @@ -266,6 +467,153 @@ fn read_literal_attribute_name(attr: &Attribute) -> Result> { .transpose() } +fn extract_handler_attributes( + service_type: ServiceType, + attr: &Attribute, +) -> Result<(bool, Option)> { + let mut is_shared = false; + let mut restate_name = None; + + match &attr.meta { + Meta::Path(_) => {} + Meta::List(meta_list) => { + let mut seen_shared = false; + let mut seen_name = false; + meta_list.parse_nested_meta(|meta| { + if meta.path.is_ident("shared") { + if seen_shared { + return Err(Error::new(meta.path.span(), "Duplicate `shared`")); + } + if service_type == ServiceType::Service { + return Err(Error::new( + meta.path.span(), + "Service handlers cannot be annotated with #[handler(shared)]", + )); + } + is_shared = true; + seen_shared = true; + } else if meta.path.is_ident("name") { + if seen_name { + return Err(Error::new(meta.path.span(), "Duplicate `name`")); + } + let lit: Lit = meta.value()?.parse()?; + if let Lit::Str(lit_str) = lit { + seen_name = true; + restate_name = Some(lit_str.value()); + } else { + return Err(Error::new( + lit.span(), + "Expected `name` to be a string literal", + )); + } + } else { + return Err(Error::new( + meta.path.span(), + "Invalid attribute inside #[handler]", + )); + } + Ok(()) + })?; + } + Meta::NameValue(_) => { + return Err(Error::new( + attr.meta.span(), + "Invalid attribute format for #[handler]", + )); + } + } + Ok((is_shared, restate_name)) +} + +fn validate_handler_arguments( + service_type: ServiceType, + is_shared: bool, + handler: &ImplItemFn, +) -> Result> { + let mut args_iter = handler.sig.inputs.iter(); + + match args_iter.next() { + Some(FnArg::Receiver(_)) => {} + Some(arg) => { + return Err(Error::new( + arg.span(), + "handler should have a `self` argument", + )); + } + None => { + return Err(Error::new( + handler.sig.ident.span(), + "Invalid handler arguments. It should be like (`self`, `ctx`, optional arg)", + )); + } + }; + + let valid_ctx: Ident = match (&service_type, is_shared) { + (ServiceType::Service, _) => parse_quote! { Context }, + (ServiceType::Object, true) => parse_quote! { SharedObjectContext }, + (ServiceType::Object, false) => parse_quote! { ObjectContext }, + (ServiceType::Workflow, true) => parse_quote! { SharedWorkflowContext }, + (ServiceType::Workflow, false) => parse_quote! { WorkflowContext }, + }; + + // TODO: allow the user to have unused context like _:Context in the handler + match args_iter.next() { + Some(arg @ FnArg::Typed(typed_arg)) if matches!(&*typed_arg.pat, Pat::Ident(_)) => { + if let Type::Path(type_path) = &*typed_arg.ty { + let ctx_ident = &type_path.path.segments.last().unwrap().ident; + + if ctx_ident != &valid_ctx { + let service_desc = match service_type { + ServiceType::Service => "service", + ServiceType::Object => { + if is_shared { + "shared object" + } else { + "object" + } + } + ServiceType::Workflow => { + if is_shared { + "shared workflow" + } else { + "workflow" + } + } + }; + + return Err(Error::new( + ctx_ident.span(), + format!( + "Expects `{}` type for this `{}`, but `{}` was provided.", + valid_ctx, service_desc, ctx_ident + ), + )); + } + } else { + return Err(Error::new( + arg.span(), + "Second argument must be one of the allowed context types", + )); + } + } + _ => { + return Err(Error::new( + handler.sig.ident.span(), + "Invalid handler arguments. It should be like (`self`, `ctx`, optional arg)", + )); + } + }; + + match args_iter.next() { + Some(FnArg::Typed(type_arg)) => Ok(Some(type_arg.clone())), + Some(FnArg::Receiver(arg)) => Err(Error::new( + arg.span(), + "Invalid handler arguments. It should be like (`self`, `ctx`, arg)", + )), + None => Ok(None), + } +} + fn extract_handler_result_parameter(ty: &Type) -> Option<(Type, Type)> { let path = match ty { Type::Path(ty) => &ty.path, diff --git a/macros/src/gen.rs b/macros/src/gen.rs index 18f2c25..2085628 100644 --- a/macros/src/gen.rs +++ b/macros/src/gen.rs @@ -1,31 +1,45 @@ -use crate::ast::{Handler, Object, Service, ServiceInner, ServiceType, Workflow}; +use crate::ast::{ + Handler, ImplBlockServiceInner, Object, Service, ServiceInner, ServiceType, + TraitBlockServiceInner, Workflow, +}; use proc_macro2::TokenStream as TokenStream2; use proc_macro2::{Ident, Literal}; use quote::{format_ident, quote, ToTokens}; -use syn::{Attribute, PatType, Visibility}; +use syn::{Attribute, ItemImpl, PatType, Visibility}; -pub(crate) struct ServiceGenerator<'a> { - pub(crate) service_ty: ServiceType, - pub(crate) restate_name: &'a str, - pub(crate) service_ident: &'a Ident, - pub(crate) client_ident: Ident, - pub(crate) serve_ident: Ident, - pub(crate) vis: &'a Visibility, - pub(crate) attrs: &'a [Attribute], - pub(crate) handlers: &'a [Handler], +pub(crate) enum ServiceGenerator<'a> { + Trait(TraitBlockServiceGenerator<'a>), + Impl(ImplBlockServiceGenerator<'a>), } impl<'a> ServiceGenerator<'a> { fn new(service_ty: ServiceType, s: &'a ServiceInner) -> Self { - ServiceGenerator { - service_ty, - restate_name: &s.restate_name, - service_ident: &s.ident, - client_ident: format_ident!("{}Client", s.ident), - serve_ident: format_ident!("Serve{}", s.ident), - vis: &s.vis, - attrs: &s.attrs, - handlers: &s.handlers, + match s { + ServiceInner::Trait(s @ TraitBlockServiceInner { .. }) => { + Self::Trait(TraitBlockServiceGenerator { + service_ty, + restate_name: &s.restate_name, + service_ident: &s.ident, + client_ident: format_ident!("{}Client", s.ident), + serve_ident: format_ident!("Serve{}", s.ident), + vis: &s.vis, + attrs: &s.attrs, + handlers: &s.handlers, + }) + } + ServiceInner::Impl(s @ ImplBlockServiceInner { .. }) => { + Self::Impl(ImplBlockServiceGenerator { + service_ty, + restate_name: &s.restate_name, + service_ident: &s.ident, + client_ident: format_ident!("{}Client", s.ident), + serve_ident: format_ident!("Serve{}", s.ident), + vis: &s.vis, + attrs: &s.attrs, + handlers: &s.handlers, + impl_block: &s.impl_block, + }) + } } } @@ -41,6 +55,74 @@ impl<'a> ServiceGenerator<'a> { Self::new(ServiceType::Workflow, &s.0) } + fn trait_service(&self) -> TokenStream2 { + match self { + Self::Trait(s @ TraitBlockServiceGenerator { .. }) => s.trait_service(), + Self::Impl(s @ ImplBlockServiceGenerator { .. }) => s.trait_service(), + } + } + + fn struct_serve(&self) -> TokenStream2 { + match self { + Self::Trait(s @ TraitBlockServiceGenerator { .. }) => s.struct_serve(), + Self::Impl(s @ ImplBlockServiceGenerator { .. }) => s.struct_serve(), + } + } + + fn impl_service_for_serve(&self) -> TokenStream2 { + match self { + Self::Trait(s @ TraitBlockServiceGenerator { .. }) => s.impl_service_for_serve(), + Self::Impl(s @ ImplBlockServiceGenerator { .. }) => s.impl_service_for_serve(), + } + } + + fn impl_discoverable(&self) -> TokenStream2 { + match self { + Self::Trait(s @ TraitBlockServiceGenerator { .. }) => s.impl_discoverable(), + Self::Impl(s @ ImplBlockServiceGenerator { .. }) => s.impl_discoverable(), + } + } + + fn struct_client(&self) -> TokenStream2 { + match self { + Self::Trait(s @ TraitBlockServiceGenerator { .. }) => s.struct_client(), + Self::Impl(s @ ImplBlockServiceGenerator { .. }) => s.struct_client(), + } + } + + fn impl_client(&self) -> TokenStream2 { + match self { + Self::Trait(s @ TraitBlockServiceGenerator { .. }) => s.impl_client(), + Self::Impl(s @ ImplBlockServiceGenerator { .. }) => s.impl_client(), + } + } +} + +impl<'a> ToTokens for ServiceGenerator<'a> { + fn to_tokens(&self, output: &mut TokenStream2) { + output.extend(vec![ + self.trait_service(), + self.struct_serve(), + self.impl_service_for_serve(), + self.impl_discoverable(), + self.struct_client(), + self.impl_client(), + ]); + } +} + +pub(crate) struct TraitBlockServiceGenerator<'a> { + pub(crate) service_ty: ServiceType, + pub(crate) restate_name: &'a str, + pub(crate) service_ident: &'a Ident, + pub(crate) client_ident: Ident, + pub(crate) serve_ident: Ident, + pub(crate) vis: &'a Visibility, + pub(crate) attrs: &'a [Attribute], + pub(crate) handlers: &'a [Handler], +} + +impl<'a> TraitBlockServiceGenerator<'a> { fn trait_service(&self) -> TokenStream2 { let Self { attrs, @@ -151,8 +233,8 @@ impl<'a> ServiceGenerator<'a> { #( #match_arms ),* _ => { return Err(::restate_sdk::endpoint::Error::unknown_handler( - ctx.service_name(), - ctx.handler_name(), + ctx.service_name(), + ctx.handler_name(), )) } } @@ -237,21 +319,345 @@ impl<'a> ServiceGenerator<'a> { quote! { impl ::restate_sdk::service::Discoverable for #serve_ident where S: #service_ident, + { + fn discover() -> ::restate_sdk::discovery::Service { + ::restate_sdk::discovery::Service { + ty: #service_ty_token, + name: ::restate_sdk::discovery::ServiceName::try_from(#service_literal.to_string()) + .expect("Service name valid"), + handlers: vec![#( #handlers ),*], + documentation: None, + metadata: Default::default(), + abort_timeout: None, + inactivity_timeout: None, + journal_retention: None, + idempotency_retention: None, + enable_lazy_state: None, + ingress_private: None, + } + } + } + } + } + + fn struct_client(&self) -> TokenStream2 { + let &Self { + vis, + ref client_ident, + // service_ident, + ref service_ty, + .. + } = self; + + let key_field = match service_ty { + ServiceType::Service => quote! {}, + ServiceType::Object | ServiceType::Workflow => quote! { + key: String, + }, + }; + + let into_client_impl = match service_ty { + ServiceType::Service => { + quote! { + impl<'ctx> ::restate_sdk::context::IntoServiceClient<'ctx> for #client_ident<'ctx> { + fn create_client(ctx: &'ctx ::restate_sdk::endpoint::ContextInternal) -> Self { + Self { ctx } + } + } + } + } + ServiceType::Object => quote! { + impl<'ctx> ::restate_sdk::context::IntoObjectClient<'ctx> for #client_ident<'ctx> { + fn create_client(ctx: &'ctx ::restate_sdk::endpoint::ContextInternal, key: String) -> Self { + Self { ctx, key } + } + } + }, + ServiceType::Workflow => quote! { + impl<'ctx> ::restate_sdk::context::IntoWorkflowClient<'ctx> for #client_ident<'ctx> { + fn create_client(ctx: &'ctx ::restate_sdk::endpoint::ContextInternal, key: String) -> Self { + Self { ctx, key } + } + } + }, + }; + + quote! { + /// Struct exposing the client to invoke [#service_ident] from another service. + #vis struct #client_ident<'ctx> { + ctx: &'ctx ::restate_sdk::endpoint::ContextInternal, + #key_field + } + + #into_client_impl + } + } + + fn impl_client(&self) -> TokenStream2 { + let &Self { + vis, + ref client_ident, + service_ident, + handlers, + restate_name, + service_ty, + .. + } = self; + + let service_literal = Literal::string(restate_name); + + let handlers_fns = handlers.iter().map(|handler| { + let handler_ident = &handler.ident; + let handler_literal = Literal::string(&handler.restate_name); + + let argument = match &handler.arg { + None => quote! {}, + Some(PatType { + ty, .. + }) => quote! { req: #ty } + }; + let argument_ty = match &handler.arg { + None => quote! { () }, + Some(PatType { + ty, .. + }) => quote! { #ty } + }; + let res_ty = &handler.output_ok; + let input = match &handler.arg { + None => quote! { () }, + Some(_) => quote! { req } + }; + let request_target = match service_ty { + ServiceType::Service => quote! { + ::restate_sdk::context::RequestTarget::service(#service_literal, #handler_literal) + }, + ServiceType::Object => quote! { + ::restate_sdk::context::RequestTarget::object(#service_literal, &self.key, #handler_literal) + }, + ServiceType::Workflow => quote! { + ::restate_sdk::context::RequestTarget::workflow(#service_literal, &self.key, #handler_literal) + } + }; + + quote! { + #vis fn #handler_ident(&self, #argument) -> ::restate_sdk::context::Request<'ctx, #argument_ty, #res_ty> { + self.ctx.request(#request_target, #input) + } + } + }); + + let doc_msg = format!( + "Struct exposing the client to invoke [`{service_ident}`] from another service." + ); + quote! { + #[doc = #doc_msg] + impl<'ctx> #client_ident<'ctx> { + #( #handlers_fns )* + } + } + } +} + +pub(crate) struct ImplBlockServiceGenerator<'a> { + pub(crate) service_ty: ServiceType, + pub(crate) restate_name: &'a str, + pub(crate) service_ident: &'a Ident, + pub(crate) client_ident: Ident, + pub(crate) serve_ident: Ident, + pub(crate) vis: &'a Visibility, + pub(crate) attrs: &'a [Attribute], + pub(crate) handlers: &'a [Handler], + pub(crate) impl_block: &'a ItemImpl, +} + +impl<'a> ImplBlockServiceGenerator<'a> { + fn trait_service(&self) -> TokenStream2 { + let Self { + attrs, + serve_ident, + service_ident, + impl_block, + vis, + .. + } = self; + + quote! { + #impl_block + + #( #attrs )* + impl #service_ident { + /// Returns a serving function to use with [::restate_sdk::endpoint::Builder::with_service]. + #vis fn serve(self) -> #serve_ident { + #serve_ident { service: ::std::sync::Arc::new(self) } + } + } + } + } + + fn struct_serve(&self) -> TokenStream2 { + let &Self { + ref serve_ident, + vis, + .. + } = self; + + quote! { + /// Struct implementing [::restate_sdk::service::Service], to be used with [::restate_sdk::endpoint::Builder::with_service]. + #[derive(Clone)] + #vis struct #serve_ident { + service: ::std::sync::Arc, + } + } + } + + fn impl_service_for_serve(&self) -> TokenStream2 { + let Self { + serve_ident, + service_ident, + handlers, + .. + } = self; + + let match_arms = handlers.iter().map(|handler| { + let handler_ident = &handler.ident; + + let get_input_and_call = if handler.arg.is_some() { + quote! { + let (input, metadata) = ctx.input().await; + let fut = #service_ident::#handler_ident(&service_clone, (&ctx, metadata).into(), input); + } + } else { + quote! { + let (_, metadata) = ctx.input::<()>().await; + let fut = #service_ident::#handler_ident(&service_clone, (&ctx, metadata).into()); + } + }; + + let handler_literal = Literal::string(&handler.restate_name); + + quote! { + #handler_literal => { + #get_input_and_call + let res = fut.await.map_err(::restate_sdk::errors::HandlerError::from); + ctx.handle_handler_result(res); + ctx.end(); + Ok(()) + } + } + }); + + quote! { + impl ::restate_sdk::service::Service for #serve_ident<#service_ident> + { + type Future = ::restate_sdk::service::ServiceBoxFuture; + + fn handle(&self, ctx: ::restate_sdk::endpoint::ContextInternal) -> Self::Future { + let service_clone = ::std::sync::Arc::clone(&self.service); + Box::pin(async move { + match ctx.handler_name() { + #( #match_arms ),* + _ => { + return Err(::restate_sdk::endpoint::Error::unknown_handler( + ctx.service_name(), + ctx.handler_name(), + )) + } + } + }) + } + } + } + } + + fn impl_discoverable(&self) -> TokenStream2 { + let Self { + service_ty, + serve_ident, + service_ident, + handlers, + restate_name, + .. + } = self; + + let service_literal = Literal::string(restate_name); + + let service_ty_token = match service_ty { + ServiceType::Service => quote! { ::restate_sdk::discovery::ServiceType::Service }, + ServiceType::Object => { + quote! { ::restate_sdk::discovery::ServiceType::VirtualObject } + } + ServiceType::Workflow => quote! { ::restate_sdk::discovery::ServiceType::Workflow }, + }; + + let handlers = handlers.iter().map(|handler| { + let handler_literal = Literal::string(&handler.restate_name); + + let handler_ty = if handler.is_shared { + quote! { Some(::restate_sdk::discovery::HandlerType::Shared) } + } else if *service_ty == ServiceType::Workflow { + quote! { Some(::restate_sdk::discovery::HandlerType::Workflow) } + } else { + // Macro has same defaulting rules of the discovery manifest + quote! { None } + }; + + let input_schema = match &handler.arg { + Some(PatType { ty, .. }) => { + quote! { + Some(::restate_sdk::discovery::InputPayload::from_metadata::<#ty>()) + } + } + None => quote! { + Some(::restate_sdk::discovery::InputPayload::empty()) + } + }; + + let output_ty = &handler.output_ok; + let output_schema = match output_ty { + syn::Type::Tuple(tuple) if tuple.elems.is_empty() => quote! { + Some(::restate_sdk::discovery::OutputPayload::empty()) + }, + _ => quote! { + Some(::restate_sdk::discovery::OutputPayload::from_metadata::<#output_ty>()) + } + }; + + quote! { + ::restate_sdk::discovery::Handler { + name: ::restate_sdk::discovery::HandlerName::try_from(#handler_literal).expect("Handler name valid"), + input: #input_schema, + output: #output_schema, + ty: #handler_ty, + documentation: None, + metadata: Default::default(), + abort_timeout: None, + inactivity_timeout: None, + journal_retention: None, + idempotency_retention: None, + workflow_completion_retention: None, + enable_lazy_state: None, + ingress_private: None, + } + } + }); + + quote! { + impl::restate_sdk::service::Discoverable for #serve_ident<#service_ident> { fn discover() -> ::restate_sdk::discovery::Service { ::restate_sdk::discovery::Service { ty: #service_ty_token, name: ::restate_sdk::discovery::ServiceName::try_from(#service_literal.to_string()) .expect("Service name valid"), - handlers: vec![#( #handlers ),*], - documentation: None, - metadata: Default::default(), - abort_timeout: None, - inactivity_timeout: None, - journal_retention: None, - idempotency_retention: None, - enable_lazy_state: None, - ingress_private: None, + handlers: vec![#( #handlers ),*], + documentation: None, + metadata: Default::default(), + abort_timeout: None, + inactivity_timeout: None, + journal_retention: None, + idempotency_retention: None, + enable_lazy_state: None, + ingress_private: None, } } } @@ -331,14 +737,14 @@ impl<'a> ServiceGenerator<'a> { let argument = match &handler.arg { None => quote! {}, Some(PatType { - ty, .. - }) => quote! { req: #ty } + ty, .. + }) => quote! { req: #ty } }; let argument_ty = match &handler.arg { None => quote! { () }, Some(PatType { - ty, .. - }) => quote! { #ty } + ty, .. + }) => quote! { #ty } }; let res_ty = &handler.output_ok; let input = match &handler.arg { @@ -375,16 +781,3 @@ impl<'a> ServiceGenerator<'a> { } } } - -impl<'a> ToTokens for ServiceGenerator<'a> { - fn to_tokens(&self, output: &mut TokenStream2) { - output.extend(vec![ - self.trait_service(), - self.struct_serve(), - self.impl_service_for_serve(), - self.impl_discoverable(), - self.struct_client(), - self.impl_client(), - ]); - } -} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index 8290af3..58c1fba 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -16,15 +16,24 @@ extern crate proc_macro; mod ast; mod gen; -use crate::ast::{Object, Service, Workflow}; +use crate::ast::{Object, Service, ServiceInner, ValidArgs, Workflow}; use crate::gen::ServiceGenerator; use proc_macro::TokenStream; use quote::ToTokens; use syn::parse_macro_input; #[proc_macro_attribute] -pub fn service(_: TokenStream, input: TokenStream) -> TokenStream { - let svc = parse_macro_input!(input as Service); +pub fn service(args: TokenStream, input: TokenStream) -> TokenStream { + let mut svc = parse_macro_input!(input as Service); + + match &mut svc.0 { + ServiceInner::Trait(..) => {} + ServiceInner::Impl(inner) => { + let args = parse_macro_input!(args as ValidArgs); + inner.restate_name = args.restate_name.unwrap_or(inner.ident.to_string()); + inner.vis = args.vis; + } + } ServiceGenerator::new_service(&svc) .into_token_stream() @@ -32,8 +41,17 @@ pub fn service(_: TokenStream, input: TokenStream) -> TokenStream { } #[proc_macro_attribute] -pub fn object(_: TokenStream, input: TokenStream) -> TokenStream { - let svc = parse_macro_input!(input as Object); +pub fn object(args: TokenStream, input: TokenStream) -> TokenStream { + let mut svc = parse_macro_input!(input as Object); + + match &mut svc.0 { + ServiceInner::Trait(..) => {} + ServiceInner::Impl(inner) => { + let args = parse_macro_input!(args as ValidArgs); + inner.restate_name = args.restate_name.unwrap_or(inner.ident.to_string()); + inner.vis = args.vis; + } + } ServiceGenerator::new_object(&svc) .into_token_stream() @@ -41,8 +59,17 @@ pub fn object(_: TokenStream, input: TokenStream) -> TokenStream { } #[proc_macro_attribute] -pub fn workflow(_: TokenStream, input: TokenStream) -> TokenStream { - let svc = parse_macro_input!(input as Workflow); +pub fn workflow(args: TokenStream, input: TokenStream) -> TokenStream { + let mut svc = parse_macro_input!(input as Workflow); + + match &mut svc.0 { + ServiceInner::Trait(..) => {} + ServiceInner::Impl(inner) => { + let args = parse_macro_input!(args as ValidArgs); + inner.restate_name = args.restate_name.unwrap_or(inner.ident.to_string()); + inner.vis = args.vis; + } + } ServiceGenerator::new_workflow(&svc) .into_token_stream() From a19e2cbb4324f491d301a1576cad535076b6d9fa Mon Sep 17 00:00:00 2001 From: BigFish2086 Date: Thu, 3 Jul 2025 20:26:25 +0300 Subject: [PATCH 2/5] update examples to show case how to use new impl syntax with the same old examples #43 --- Cargo.toml | 4 +- examples/impl_block/counter.rs | 40 +++++++++ examples/impl_block/cron.rs | 84 +++++++++++++++++++ examples/impl_block/failures.rs | 34 ++++++++ examples/impl_block/greeter.rs | 20 +++++ examples/impl_block/run.rs | 44 ++++++++++ examples/{ => impl_block}/services/mod.rs | 0 examples/impl_block/services/my_service.rs | 19 +++++ .../impl_block/services/my_virtual_object.rs | 36 ++++++++ examples/impl_block/services/my_workflow.rs | 32 +++++++ examples/impl_block/tracing.rs | 34 ++++++++ examples/{ => trait_block}/counter.rs | 0 examples/{ => trait_block}/cron.rs | 0 examples/{ => trait_block}/failures.rs | 0 examples/{ => trait_block}/greeter.rs | 0 examples/{ => trait_block}/run.rs | 0 examples/{ => trait_block}/schema.rs | 0 examples/trait_block/services/mod.rs | 3 + .../{ => trait_block}/services/my_service.rs | 0 .../services/my_virtual_object.rs | 0 .../{ => trait_block}/services/my_workflow.rs | 0 examples/{ => trait_block}/tracing.rs | 0 src/context/mod.rs | 8 +- src/http_server.rs | 4 +- 24 files changed, 354 insertions(+), 8 deletions(-) create mode 100644 examples/impl_block/counter.rs create mode 100644 examples/impl_block/cron.rs create mode 100644 examples/impl_block/failures.rs create mode 100644 examples/impl_block/greeter.rs create mode 100644 examples/impl_block/run.rs rename examples/{ => impl_block}/services/mod.rs (100%) create mode 100644 examples/impl_block/services/my_service.rs create mode 100644 examples/impl_block/services/my_virtual_object.rs create mode 100644 examples/impl_block/services/my_workflow.rs create mode 100644 examples/impl_block/tracing.rs rename examples/{ => trait_block}/counter.rs (100%) rename examples/{ => trait_block}/cron.rs (100%) rename examples/{ => trait_block}/failures.rs (100%) rename examples/{ => trait_block}/greeter.rs (100%) rename examples/{ => trait_block}/run.rs (100%) rename examples/{ => trait_block}/schema.rs (100%) create mode 100644 examples/trait_block/services/mod.rs rename examples/{ => trait_block}/services/my_service.rs (100%) rename examples/{ => trait_block}/services/my_virtual_object.rs (100%) rename examples/{ => trait_block}/services/my_workflow.rs (100%) rename examples/{ => trait_block}/tracing.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index 0005a5d..3f907bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,12 +9,12 @@ rust-version = "1.76.0" [[example]] name = "tracing" -path = "examples/tracing.rs" +path = "examples/trait_block/tracing.rs" required-features = ["tracing-span-filter"] [[example]] name = "schema" -path = "examples/schema.rs" +path = "examples/trait_block/schema.rs" required-features = ["schemars"] [features] diff --git a/examples/impl_block/counter.rs b/examples/impl_block/counter.rs new file mode 100644 index 0000000..f41d8c0 --- /dev/null +++ b/examples/impl_block/counter.rs @@ -0,0 +1,40 @@ +use restate_sdk::prelude::*; + +const COUNT: &str = "count"; + +struct Counter; + +#[restate_sdk::object] +impl Counter { + #[handler(shared)] + async fn get(&self, ctx: SharedObjectContext<'_>) -> Result { + Ok(ctx.get::(COUNT).await?.unwrap_or(0)) + } + + #[handler] + async fn add(&self, ctx: ObjectContext<'_>, val: u64) -> Result { + let current = ctx.get::(COUNT).await?.unwrap_or(0); + let new = current + val; + ctx.set(COUNT, new); + Ok(new) + } + + #[handler] + async fn increment(&self, ctx: ObjectContext<'_>) -> Result { + self.add(ctx, 1).await + } + + #[handler] + async fn reset(&self, ctx: ObjectContext<'_>) -> Result<(), TerminalError> { + ctx.clear(COUNT); + Ok(()) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new(Endpoint::builder().bind(Counter.serve()).build()) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/examples/impl_block/cron.rs b/examples/impl_block/cron.rs new file mode 100644 index 0000000..81badc4 --- /dev/null +++ b/examples/impl_block/cron.rs @@ -0,0 +1,84 @@ +use restate_sdk::prelude::*; +use std::time::Duration; + +/// This example shows how to implement a periodic task, by invoking itself in a loop. +/// +/// The `start()` handler schedules the first call to `run()`, and then each `run()` will re-schedule itself. +/// +/// To "break" the loop, we use a flag we persist in state, which is removed when `stop()` is invoked. +/// Its presence determines whether the task is active or not. +/// +/// To start it: +/// +/// ```shell +/// $ curl -v http://localhost:8080/PeriodicTask/my-periodic-task/start +/// ``` +struct PeriodicTask; + +const ACTIVE: &str = "active"; + +#[restate_sdk::object] +impl PeriodicTask { + #[handler] + async fn start(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> { + if context + .get::(ACTIVE) + .await? + .is_some_and(|enabled| enabled) + { + // If it's already activated, just do nothing + return Ok(()); + } + + // Schedule the periodic task + PeriodicTask::schedule_next(&context); + + // Mark the periodic task as active + context.set(ACTIVE, true); + + Ok(()) + } + + #[handler] + async fn stop(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> { + // Remove the active flag + context.clear(ACTIVE); + + Ok(()) + } + + #[handler] + async fn run(&self, context: ObjectContext<'_>) -> Result<(), TerminalError> { + if context.get::(ACTIVE).await?.is_none() { + // Task is inactive, do nothing + return Ok(()); + } + + // --- Periodic task business logic! + println!("Triggered the periodic task!"); + + // Schedule the periodic task + PeriodicTask::schedule_next(&context); + + Ok(()) + } +} + +impl PeriodicTask { + fn schedule_next(context: &ObjectContext<'_>) { + // To schedule, create a client to the callee handler (in this case, we're calling ourselves) + context + .object_client::(context.key()) + .run() + // And send with a delay + .send_after(Duration::from_secs(10)); + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new(Endpoint::builder().bind(PeriodicTask.serve()).build()) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/examples/impl_block/failures.rs b/examples/impl_block/failures.rs new file mode 100644 index 0000000..5545d94 --- /dev/null +++ b/examples/impl_block/failures.rs @@ -0,0 +1,34 @@ +use rand::RngCore; +use restate_sdk::prelude::*; + +#[derive(Debug, thiserror::Error)] +#[error("I'm very bad, retry me")] +struct MyError; + +struct FailureExample; + +#[restate_sdk::service] +impl FailureExample { + #[handler(name = "doRun")] + async fn do_run(&self, context: Context<'_>) -> Result<(), TerminalError> { + context + .run::<_, _, ()>(|| async move { + if rand::thread_rng().next_u32() % 4 == 0 { + Err(TerminalError::new("Failed!!!"))? + } + + Err(MyError)? + }) + .await?; + + Ok(()) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new(Endpoint::builder().bind(FailureExample.serve()).build()) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/examples/impl_block/greeter.rs b/examples/impl_block/greeter.rs new file mode 100644 index 0000000..8d537ac --- /dev/null +++ b/examples/impl_block/greeter.rs @@ -0,0 +1,20 @@ +use restate_sdk::prelude::*; +use std::convert::Infallible; + +struct Greeter; + +#[restate_sdk::service] +impl Greeter { + #[handler] + async fn greet(&self, _ctx: Context<'_>, name: String) -> Result { + Ok(format!("Greetings {name}")) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new(Endpoint::builder().bind(Greeter.serve()).build()) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/examples/impl_block/run.rs b/examples/impl_block/run.rs new file mode 100644 index 0000000..0accdda --- /dev/null +++ b/examples/impl_block/run.rs @@ -0,0 +1,44 @@ +use restate_sdk::prelude::*; +use std::collections::HashMap; + +struct RunExample(reqwest::Client); + +#[restate_sdk::service] +impl RunExample { + #[handler] + async fn do_run( + &self, + context: Context<'_>, + ) -> Result>, HandlerError> { + let res = context + .run(|| async move { + let req = self.0.get("https://httpbin.org/ip").build()?; + + let res = self + .0 + .execute(req) + .await? + .json::>() + .await?; + + Ok(Json::from(res)) + }) + .name("get_ip") + .await? + .into_inner(); + + Ok(res.into()) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new( + Endpoint::builder() + .bind(RunExample(reqwest::Client::new()).serve()) + .build(), + ) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/examples/services/mod.rs b/examples/impl_block/services/mod.rs similarity index 100% rename from examples/services/mod.rs rename to examples/impl_block/services/mod.rs diff --git a/examples/impl_block/services/my_service.rs b/examples/impl_block/services/my_service.rs new file mode 100644 index 0000000..1f2231d --- /dev/null +++ b/examples/impl_block/services/my_service.rs @@ -0,0 +1,19 @@ +use restate_sdk::prelude::*; + +pub struct MyService; + +#[restate_sdk::service(vis = "pub(crate)")] +impl MyService { + #[handler] + async fn my_handler(&self, _ctx: Context<'_>, greeting: String) -> Result { + Ok(format!("{greeting}!")) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new(Endpoint::builder().bind(MyService.serve()).build()) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/examples/impl_block/services/my_virtual_object.rs b/examples/impl_block/services/my_virtual_object.rs new file mode 100644 index 0000000..635dc4c --- /dev/null +++ b/examples/impl_block/services/my_virtual_object.rs @@ -0,0 +1,36 @@ +use restate_sdk::prelude::*; + +pub struct MyVirtualObject; + +#[restate_sdk::object(vis = "pub(crate)")] +impl MyVirtualObject { + #[handler] + async fn my_handler( + &self, + ctx: ObjectContext<'_>, + greeting: String, + ) -> Result { + Ok(format!("Greetings {} {}", greeting, ctx.key())) + } + + #[handler(shared)] + async fn my_concurrent_handler( + &self, + ctx: SharedObjectContext<'_>, + greeting: String, + ) -> Result { + Ok(format!("Greetings {} {}", greeting, ctx.key())) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new( + Endpoint::builder() + .bind(MyVirtualObject.serve()) + .build(), + ) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/examples/impl_block/services/my_workflow.rs b/examples/impl_block/services/my_workflow.rs new file mode 100644 index 0000000..8cf5f65 --- /dev/null +++ b/examples/impl_block/services/my_workflow.rs @@ -0,0 +1,32 @@ +use restate_sdk::prelude::*; + +pub struct MyWorkflow; + +#[restate_sdk::workflow(vis = "pub(crate)")] +impl MyWorkflow { + #[handler] + async fn run(&self, _ctx: WorkflowContext<'_>, _req: String) -> Result { + // implement workflow logic here + + Ok(String::from("success")) + } + + #[handler(shared)] + async fn interact_with_workflow( + &self, + _ctx: SharedWorkflowContext<'_>, + ) -> Result<(), HandlerError> { + // implement interaction logic here + // e.g. resolve a promise that the workflow is waiting on + + Ok(()) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new(Endpoint::builder().bind(MyWorkflow.serve()).build()) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/examples/impl_block/tracing.rs b/examples/impl_block/tracing.rs new file mode 100644 index 0000000..dec3a9a --- /dev/null +++ b/examples/impl_block/tracing.rs @@ -0,0 +1,34 @@ +use restate_sdk::prelude::*; +use std::time::Duration; +use tracing::info; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer}; + +struct Greeter; + +#[restate_sdk::service] +impl Greeter { + #[handler] + async fn greet(&self, ctx: Context<'_>, name: String) -> Result { + info!("Before sleep"); + ctx.sleep(Duration::from_secs(61)).await?; // More than suspension timeout to trigger replay + info!("After sleep"); + Ok(format!("Greetings {name}")) + } +} + +#[tokio::main] +async fn main() { + let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "info,restate_sdk=debug".into()); + let replay_filter = restate_sdk::filter::ReplayAwareFilter; + tracing_subscriber::registry() + .with( + tracing_subscriber::fmt::layer() + .with_filter(env_filter) + .with_filter(replay_filter), + ) + .init(); + HttpServer::new(Endpoint::builder().bind(Greeter.serve()).build()) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/examples/counter.rs b/examples/trait_block/counter.rs similarity index 100% rename from examples/counter.rs rename to examples/trait_block/counter.rs diff --git a/examples/cron.rs b/examples/trait_block/cron.rs similarity index 100% rename from examples/cron.rs rename to examples/trait_block/cron.rs diff --git a/examples/failures.rs b/examples/trait_block/failures.rs similarity index 100% rename from examples/failures.rs rename to examples/trait_block/failures.rs diff --git a/examples/greeter.rs b/examples/trait_block/greeter.rs similarity index 100% rename from examples/greeter.rs rename to examples/trait_block/greeter.rs diff --git a/examples/run.rs b/examples/trait_block/run.rs similarity index 100% rename from examples/run.rs rename to examples/trait_block/run.rs diff --git a/examples/schema.rs b/examples/trait_block/schema.rs similarity index 100% rename from examples/schema.rs rename to examples/trait_block/schema.rs diff --git a/examples/trait_block/services/mod.rs b/examples/trait_block/services/mod.rs new file mode 100644 index 0000000..fc734f9 --- /dev/null +++ b/examples/trait_block/services/mod.rs @@ -0,0 +1,3 @@ +pub mod my_service; +pub mod my_virtual_object; +pub mod my_workflow; \ No newline at end of file diff --git a/examples/services/my_service.rs b/examples/trait_block/services/my_service.rs similarity index 100% rename from examples/services/my_service.rs rename to examples/trait_block/services/my_service.rs diff --git a/examples/services/my_virtual_object.rs b/examples/trait_block/services/my_virtual_object.rs similarity index 100% rename from examples/services/my_virtual_object.rs rename to examples/trait_block/services/my_virtual_object.rs diff --git a/examples/services/my_workflow.rs b/examples/trait_block/services/my_workflow.rs similarity index 100% rename from examples/services/my_workflow.rs rename to examples/trait_block/services/my_workflow.rs diff --git a/examples/tracing.rs b/examples/trait_block/tracing.rs similarity index 100% rename from examples/tracing.rs rename to examples/trait_block/tracing.rs diff --git a/src/context/mod.rs b/src/context/mod.rs index d9b66fd..15fdfa7 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -273,7 +273,7 @@ impl<'ctx, CTX: private::SealedContext<'ctx>> ContextTimers<'ctx> for CTX {} /// You can do request-response calls to Services, Virtual Objects, and Workflows, in the following way: /// /// ```rust,no_run -/// # #[path = "../../examples/services/mod.rs"] +/// # #[path = "../../examples/trait_block/services/mod.rs"] /// # mod services; /// # use services::my_virtual_object::MyVirtualObjectClient; /// # use services::my_service::MyServiceClient; @@ -324,7 +324,7 @@ impl<'ctx, CTX: private::SealedContext<'ctx>> ContextTimers<'ctx> for CTX {} /// Handlers can send messages (a.k.a. one-way calls, or fire-and-forget calls), as follows: /// /// ```rust,no_run -/// # #[path = "../../examples/services/mod.rs"] +/// # #[path = "../../examples/trait_block/services/mod.rs"] /// # mod services; /// # use services::my_virtual_object::MyVirtualObjectClient; /// # use services::my_service::MyServiceClient; @@ -364,7 +364,7 @@ impl<'ctx, CTX: private::SealedContext<'ctx>> ContextTimers<'ctx> for CTX {} /// To schedule a delayed call, send a message with a delay parameter, as follows: /// /// ```rust,no_run -/// # #[path = "../../examples/services/mod.rs"] +/// # #[path = "../../examples/trait_block/services/mod.rs"] /// # mod services; /// # use services::my_virtual_object::MyVirtualObjectClient; /// # use services::my_service::MyServiceClient; @@ -403,7 +403,7 @@ impl<'ctx, CTX: private::SealedContext<'ctx>> ContextTimers<'ctx> for CTX {} /// For example, assume a handler calls the same Virtual Object twice: /// /// ```rust,no_run -/// # #[path = "../../examples/services/my_virtual_object.rs"] +/// # #[path = "../../examples/trait_block/services/my_virtual_object.rs"] /// # mod my_virtual_object; /// # use my_virtual_object::MyVirtualObjectClient; /// # use restate_sdk::prelude::*; diff --git a/src/http_server.rs b/src/http_server.rs index fb78a7a..46c8f32 100644 --- a/src/http_server.rs +++ b/src/http_server.rs @@ -7,7 +7,7 @@ //! 3. Listen on the specified port (default `9080`) for connections and requests. //! //! ```rust,no_run -//! # #[path = "../examples/services/mod.rs"] +//! # #[path = "../examples/trait_block/services/mod.rs"] //! # mod services; //! # use services::my_service::{MyService, MyServiceImpl}; //! # use services::my_virtual_object::{MyVirtualObject, MyVirtualObjectImpl}; @@ -37,7 +37,7 @@ //! Add the identity key to your endpoint as follows: //! //! ```rust,no_run -//! # #[path = "../examples/services/mod.rs"] +//! # #[path = "../examples/trait_block/services/mod.rs"] //! # mod services; //! # use services::my_service::{MyService, MyServiceImpl}; //! # use restate_sdk::endpoint::Endpoint; From 905d9669bf4a4affdbe5b4753e29f46f0654ac88 Mon Sep 17 00:00:00 2001 From: BigFish2086 Date: Thu, 3 Jul 2025 20:28:36 +0300 Subject: [PATCH 3/5] add test-services testcases to test the new impl-syntax #43 - updated Cargo.toml to have mutliple binaries, one that used the old macro trait-syntax, and another one for the new macro struct-impl syntax, so we can test them later, so now `cargo build -p test-services` would generate 2 binaries trait-test-services & impl-test-services - updated Dockerfile: - to cache the building process through the new mount syntax as well as building in multiple stages - to have a --build-arg which is BIN that deteremins which macro syntax to test when running the restate-sdk-test-suite.jar --- test-services/Cargo.toml | 8 + test-services/Dockerfile | 36 ++- .../src/impl_block/awakeable_holder.rs | 31 +++ .../src/impl_block/block_and_wait_workflow.rs | 40 +++ test-services/src/impl_block/cancel_test.rs | 87 +++++++ test-services/src/impl_block/counter.rs | 59 +++++ test-services/src/impl_block/failing.rs | 124 +++++++++ test-services/src/impl_block/kill_test.rs | 46 ++++ test-services/src/impl_block/list_object.rs | 38 +++ test-services/src/impl_block/main.rs | 93 +++++++ test-services/src/impl_block/map_object.rs | 49 ++++ test-services/src/impl_block/mod.rs | 1 + .../src/impl_block/non_deterministic.rs | 88 +++++++ test-services/src/impl_block/proxy.rs | 119 +++++++++ .../src/impl_block/test_utils_service.rs | 94 +++++++ .../virtual_object_command_interpreter.rs | 246 ++++++++++++++++++ test-services/src/mod.rs | 2 + .../src/{ => trait_block}/awakeable_holder.rs | 0 .../block_and_wait_workflow.rs | 0 .../src/{ => trait_block}/cancel_test.rs | 0 .../src/{ => trait_block}/counter.rs | 0 .../src/{ => trait_block}/failing.rs | 0 .../src/{ => trait_block}/kill_test.rs | 0 .../src/{ => trait_block}/list_object.rs | 0 test-services/src/{ => trait_block}/main.rs | 0 .../src/{ => trait_block}/map_object.rs | 0 test-services/src/trait_block/mod.rs | 1 + .../{ => trait_block}/non_deterministic.rs | 0 test-services/src/{ => trait_block}/proxy.rs | 0 .../{ => trait_block}/test_utils_service.rs | 0 .../virtual_object_command_interpreter.rs | 0 31 files changed, 1156 insertions(+), 6 deletions(-) create mode 100644 test-services/src/impl_block/awakeable_holder.rs create mode 100644 test-services/src/impl_block/block_and_wait_workflow.rs create mode 100644 test-services/src/impl_block/cancel_test.rs create mode 100644 test-services/src/impl_block/counter.rs create mode 100644 test-services/src/impl_block/failing.rs create mode 100644 test-services/src/impl_block/kill_test.rs create mode 100644 test-services/src/impl_block/list_object.rs create mode 100644 test-services/src/impl_block/main.rs create mode 100644 test-services/src/impl_block/map_object.rs create mode 100644 test-services/src/impl_block/mod.rs create mode 100644 test-services/src/impl_block/non_deterministic.rs create mode 100644 test-services/src/impl_block/proxy.rs create mode 100644 test-services/src/impl_block/test_utils_service.rs create mode 100644 test-services/src/impl_block/virtual_object_command_interpreter.rs create mode 100644 test-services/src/mod.rs rename test-services/src/{ => trait_block}/awakeable_holder.rs (100%) rename test-services/src/{ => trait_block}/block_and_wait_workflow.rs (100%) rename test-services/src/{ => trait_block}/cancel_test.rs (100%) rename test-services/src/{ => trait_block}/counter.rs (100%) rename test-services/src/{ => trait_block}/failing.rs (100%) rename test-services/src/{ => trait_block}/kill_test.rs (100%) rename test-services/src/{ => trait_block}/list_object.rs (100%) rename test-services/src/{ => trait_block}/main.rs (100%) rename test-services/src/{ => trait_block}/map_object.rs (100%) create mode 100644 test-services/src/trait_block/mod.rs rename test-services/src/{ => trait_block}/non_deterministic.rs (100%) rename test-services/src/{ => trait_block}/proxy.rs (100%) rename test-services/src/{ => trait_block}/test_utils_service.rs (100%) rename test-services/src/{ => trait_block}/virtual_object_command_interpreter.rs (100%) diff --git a/test-services/Cargo.toml b/test-services/Cargo.toml index 477e793..776f128 100644 --- a/test-services/Cargo.toml +++ b/test-services/Cargo.toml @@ -14,3 +14,11 @@ restate-sdk = { path = "..", features = ["schemars"] } schemars = "1.0.0-alpha.17" serde = { version = "1", features = ["derive"] } tracing = "0.1.40" + +[[bin]] +name = "trait-test-services" +path = "src/trait_block/main.rs" + +[[bin]] +name = "impl-test-services" +path = "src/impl_block/main.rs" diff --git a/test-services/Dockerfile b/test-services/Dockerfile index 50e4ff7..ba580b9 100644 --- a/test-services/Dockerfile +++ b/test-services/Dockerfile @@ -1,12 +1,36 @@ -FROM rust:1.81 +# syntax=docker/dockerfile:1.6 +ARG BIN=trait-test-services + +FROM rust:1.81 AS builder +ARG BIN WORKDIR /app +RUN <<'EOSH' + mkdir -p /usr/local/cargo /app/target +EOSH + +ENV CARGO_HOME=/usr/local/cargo + +COPY Cargo.toml Cargo.lock ./ +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + --mount=type=cache,target=/app/target \ + cargo build -p test-services || true + COPY . . -RUN cargo build -p test-services -RUN cp ./target/debug/test-services /bin/server -ENV RUST_LOG="debug,restate_shared_core=trace" -ENV RUST_BACKTRACE=1 +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + --mount=type=cache,target=/app/target \ + cargo build -p test-services && cp /app/target/debug/${BIN} /app/${BIN} + +################################### + +FROM debian:bookworm-slim +ARG BIN + +COPY --from=builder /app/${BIN} /bin/server + +ENV RUST_LOG="debug,restate_shared_core=trace" \ + RUST_BACKTRACE=1 -CMD ["/bin/server"] \ No newline at end of file +CMD ["/bin/server"] diff --git a/test-services/src/impl_block/awakeable_holder.rs b/test-services/src/impl_block/awakeable_holder.rs new file mode 100644 index 0000000..ae7c5d8 --- /dev/null +++ b/test-services/src/impl_block/awakeable_holder.rs @@ -0,0 +1,31 @@ +use restate_sdk::prelude::*; + +pub(crate) struct AwakeableHolder; + +const ID: &str = "id"; + +#[restate_sdk::object(vis = "pub(crate)", name = "AwakeableHolder")] +impl AwakeableHolder { + #[handler(name = "hold")] + async fn hold(&self, context: ObjectContext<'_>, id: String) -> HandlerResult<()> { + context.set(ID, id); + Ok(()) + } + + #[handler(shared, name = "hasAwakeable")] + async fn has_awakeable(&self, context: SharedObjectContext<'_>) -> HandlerResult { + Ok(context.get::(ID).await?.is_some()) + } + + #[handler(name = "unlock")] + async fn unlock(&self, context: ObjectContext<'_>, payload: String) -> HandlerResult<()> { + let k: String = context.get(ID).await?.ok_or_else(|| { + TerminalError::new(format!( + "No awakeable stored for awakeable holder {}", + context.key() + )) + })?; + context.resolve_awakeable(&k, payload); + Ok(()) + } +} diff --git a/test-services/src/impl_block/block_and_wait_workflow.rs b/test-services/src/impl_block/block_and_wait_workflow.rs new file mode 100644 index 0000000..591aa79 --- /dev/null +++ b/test-services/src/impl_block/block_and_wait_workflow.rs @@ -0,0 +1,40 @@ +use restate_sdk::prelude::*; + +pub(crate) struct BlockAndWaitWorkflow; + +const MY_PROMISE: &str = "my-promise"; +const MY_STATE: &str = "my-state"; + +#[restate_sdk::workflow(vis = "pub(crate)", name = "BlockAndWaitWorkflow")] +impl BlockAndWaitWorkflow { + #[handler(name = "run")] + async fn run(&self, context: WorkflowContext<'_>, input: String) -> HandlerResult { + context.set(MY_STATE, input); + + let promise: String = context.promise(MY_PROMISE).await?; + + if context.peek_promise::(MY_PROMISE).await?.is_none() { + return Err(TerminalError::new("Durable promise should be completed").into()); + } + + Ok(promise) + } + + #[handler(shared, name = "unblock")] + async fn unblock( + &self, + context: SharedWorkflowContext<'_>, + output: String, + ) -> HandlerResult<()> { + context.resolve_promise(MY_PROMISE, output); + Ok(()) + } + + #[handler(shared, name = "getState")] + async fn get_state( + &self, + context: SharedWorkflowContext<'_>, + ) -> HandlerResult>> { + Ok(Json(context.get::(MY_STATE).await?)) + } +} diff --git a/test-services/src/impl_block/cancel_test.rs b/test-services/src/impl_block/cancel_test.rs new file mode 100644 index 0000000..694167f --- /dev/null +++ b/test-services/src/impl_block/cancel_test.rs @@ -0,0 +1,87 @@ +use crate::awakeable_holder; +use anyhow::anyhow; +use restate_sdk::prelude::*; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub(crate) enum BlockingOperation { + Call, + Sleep, + Awakeable, +} + +pub(crate) struct CancelTestRunner; + +const CANCELED: &str = "canceled"; + +#[restate_sdk::object(vis = "pub(crate)", name = "CancelTestRunner")] +impl CancelTestRunner { + #[handler(name = "startTest")] + async fn start_test( + &self, + context: ObjectContext<'_>, + op: Json, + ) -> HandlerResult<()> { + let this = context.object_client::(context.key()); + + match this.block(op).call().await { + Ok(_) => Err(anyhow!("Block succeeded, this is unexpected").into()), + Err(e) if e.code() == 409 => { + context.set(CANCELED, true); + Ok(()) + } + Err(e) => Err(e.into()), + } + } + + #[handler(name = "verifyTest")] + async fn verify_test(&self, context: ObjectContext<'_>) -> HandlerResult { + Ok(context.get::(CANCELED).await?.unwrap_or(false)) + } +} + +pub(crate) struct CancelTestBlockingService; + +#[restate_sdk::object(vis = "pub(crate)", name = "CancelTestBlockingService")] +impl CancelTestBlockingService { + #[handler(name = "block")] + async fn block( + &self, + context: ObjectContext<'_>, + op: Json, + ) -> HandlerResult<()> { + let this = context.object_client::(context.key()); + let awakeable_holder_client = + context.object_client::(context.key()); + + let (awk_id, awakeable) = context.awakeable::(); + awakeable_holder_client.hold(awk_id).call().await?; + awakeable.await?; + + match &op.0 { + BlockingOperation::Call => { + this.block(op).call().await?; + } + BlockingOperation::Sleep => { + context + .sleep(Duration::from_secs(60 * 60 * 24 * 1024)) + .await?; + } + BlockingOperation::Awakeable => { + let (_, uncompletable) = context.awakeable::(); + uncompletable.await?; + } + } + + Ok(()) + } + + #[handler(name = "isUnlocked")] + async fn is_unlocked(&self, _ctx: ObjectContext<'_>) -> HandlerResult<()> { + // no-op + Ok(()) + } +} diff --git a/test-services/src/impl_block/counter.rs b/test-services/src/impl_block/counter.rs new file mode 100644 index 0000000..a4ab08a --- /dev/null +++ b/test-services/src/impl_block/counter.rs @@ -0,0 +1,59 @@ +use restate_sdk::prelude::*; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use tracing::info; + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) struct CounterUpdateResponse { + old_value: u64, + new_value: u64, +} + +pub(crate) struct Counter; + +const COUNT: &str = "counter"; + +#[restate_sdk::object(vis = "pub(crate)", name = "Counter")] +impl Counter { + #[handler(shared, name = "get")] + async fn get(&self, ctx: SharedObjectContext<'_>) -> HandlerResult { + Ok(ctx.get::(COUNT).await?.unwrap_or(0)) + } + + #[handler(name = "add")] + async fn add( + &self, + ctx: ObjectContext<'_>, + val: u64, + ) -> HandlerResult> { + let current = ctx.get::(COUNT).await?.unwrap_or(0); + let new = current + val; + ctx.set(COUNT, new); + + info!("Old count {}, new count {}", current, new); + + Ok(CounterUpdateResponse { + old_value: current, + new_value: new, + } + .into()) + } + + #[handler(name = "reset")] + async fn reset(&self, ctx: ObjectContext<'_>) -> HandlerResult<()> { + ctx.clear(COUNT); + Ok(()) + } + + #[handler(name = "addThenFail")] + async fn add_then_fail(&self, ctx: ObjectContext<'_>, val: u64) -> HandlerResult<()> { + let current = ctx.get::(COUNT).await?.unwrap_or(0); + let new = current + val; + ctx.set(COUNT, new); + + info!("Old count {}, new count {}", current, new); + + Err(TerminalError::new(ctx.key()).into()) + } +} diff --git a/test-services/src/impl_block/failing.rs b/test-services/src/impl_block/failing.rs new file mode 100644 index 0000000..58326b0 --- /dev/null +++ b/test-services/src/impl_block/failing.rs @@ -0,0 +1,124 @@ +use anyhow::anyhow; +use restate_sdk::prelude::*; +use std::sync::atomic::{AtomicI32, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +#[derive(Clone, Default)] +pub(crate) struct Failing { + eventual_success_calls: Arc, + eventual_success_side_effects: Arc, + eventual_failure_side_effects: Arc, +} + +#[restate_sdk::object(vis = "pub(crate)", name = "Failing")] +impl Failing { + #[handler(name = "terminallyFailingCall")] + async fn terminally_failing_call( + &self, + _ctx: ObjectContext<'_>, + error_message: String, + ) -> HandlerResult<()> { + Err(TerminalError::new(error_message).into()) + } + + #[handler(name = "callTerminallyFailingCall")] + async fn call_terminally_failing_call( + &self, + mut context: ObjectContext<'_>, + error_message: String, + ) -> HandlerResult { + let uuid = context.rand_uuid().to_string(); + context + .object_client::(uuid) + .terminally_failing_call(error_message) + .call() + .await?; + + unreachable!("This should be unreachable") + } + + #[handler(name = "failingCallWithEventualSuccess")] + async fn failing_call_with_eventual_success( + &self, + _ctx: ObjectContext<'_>, + ) -> HandlerResult { + let current_attempt = self.eventual_success_calls.fetch_add(1, Ordering::SeqCst) + 1; + + if current_attempt >= 4 { + self.eventual_success_calls.store(0, Ordering::SeqCst); + Ok(current_attempt) + } else { + Err(anyhow!("Failed at attempt ${current_attempt}").into()) + } + } + + #[handler(name = "terminallyFailingSideEffect")] + async fn terminally_failing_side_effect( + &self, + context: ObjectContext<'_>, + error_message: String, + ) -> HandlerResult<()> { + context + .run::<_, _, ()>(|| async move { Err(TerminalError::new(error_message))? }) + .await?; + + unreachable!("This should be unreachable") + } + + #[handler(name = "sideEffectSucceedsAfterGivenAttempts")] + async fn side_effect_succeeds_after_given_attempts( + &self, + context: ObjectContext<'_>, + minimum_attempts: i32, + ) -> HandlerResult { + let cloned_counter = Arc::clone(&self.eventual_success_side_effects); + let success_attempt = context + .run(|| async move { + let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; + + if current_attempt >= minimum_attempts { + cloned_counter.store(0, Ordering::SeqCst); + Ok(current_attempt) + } else { + Err(anyhow!("Failed at attempt {current_attempt}"))? + } + }) + .retry_policy( + RunRetryPolicy::new() + .initial_delay(Duration::from_millis(10)) + .exponentiation_factor(1.0), + ) + .name("failing_side_effect") + .await?; + + Ok(success_attempt) + } + + #[handler(name = "sideEffectFailsAfterGivenAttempts")] + async fn side_effect_fails_after_given_attempts( + &self, + context: ObjectContext<'_>, + retry_policy_max_retry_count: i32, + ) -> HandlerResult { + let cloned_counter = Arc::clone(&self.eventual_failure_side_effects); + if context + .run(|| async move { + let current_attempt = cloned_counter.fetch_add(1, Ordering::SeqCst) + 1; + Err::<(), _>(anyhow!("Failed at attempt {current_attempt}").into()) + }) + .retry_policy( + RunRetryPolicy::new() + .initial_delay(Duration::from_millis(10)) + .exponentiation_factor(1.0) + .max_attempts(retry_policy_max_retry_count as u32), + ) + .await + .is_err() + { + Ok(self.eventual_failure_side_effects.load(Ordering::SeqCst)) + } else { + Err(TerminalError::new("Expecting the side effect to fail!"))? + } + } +} diff --git a/test-services/src/impl_block/kill_test.rs b/test-services/src/impl_block/kill_test.rs new file mode 100644 index 0000000..a7a960c --- /dev/null +++ b/test-services/src/impl_block/kill_test.rs @@ -0,0 +1,46 @@ +use crate::awakeable_holder; +use restate_sdk::prelude::*; + +pub(crate) struct KillTestRunner; + +#[restate_sdk::object(vis = "pub(crate)", name = "KillTestRunner")] +impl KillTestRunner { + #[handler(name = "startCallTree")] + async fn start_call_tree(&self, context: ObjectContext<'_>) -> HandlerResult<()> { + context + .object_client::(context.key()) + .recursive_call() + .call() + .await?; + Ok(()) + } +} + +pub(crate) struct KillTestSingleton; + +#[restate_sdk::object(vis = "pub(crate)", name = "KillTestSingleton")] +impl KillTestSingleton { + #[handler(name = "recursiveCall")] + async fn recursive_call(&self, context: ObjectContext<'_>) -> HandlerResult<()> { + let awakeable_holder_client = + context.object_client::(context.key()); + + let (awk_id, awakeable) = context.awakeable::<()>(); + awakeable_holder_client.hold(awk_id).send(); + awakeable.await?; + + context + .object_client::(context.key()) + .recursive_call() + .call() + .await?; + + Ok(()) + } + + #[handler(name = "isUnlocked")] + async fn is_unlocked(&self, _ctx: ObjectContext<'_>) -> HandlerResult<()> { + // no-op + Ok(()) + } +} diff --git a/test-services/src/impl_block/list_object.rs b/test-services/src/impl_block/list_object.rs new file mode 100644 index 0000000..3cd99d1 --- /dev/null +++ b/test-services/src/impl_block/list_object.rs @@ -0,0 +1,38 @@ +use restate_sdk::prelude::*; + +pub(crate) struct ListObject; + +const LIST: &str = "list"; + +#[restate_sdk::object(vis = "pub(crate)", name = "ListObject")] +impl ListObject { + #[handler(name = "append")] + async fn append(&self, ctx: ObjectContext<'_>, value: String) -> HandlerResult<()> { + let mut list = ctx + .get::>>(LIST) + .await? + .unwrap_or_default() + .into_inner(); + list.push(value); + ctx.set(LIST, Json(list)); + Ok(()) + } + + #[handler(name = "get")] + async fn get(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { + Ok(ctx + .get::>>(LIST) + .await? + .unwrap_or_default()) + } + + #[handler(name = "clear")] + async fn clear(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { + let get = ctx + .get::>>(LIST) + .await? + .unwrap_or_default(); + ctx.clear(LIST); + Ok(get) + } +} diff --git a/test-services/src/impl_block/main.rs b/test-services/src/impl_block/main.rs new file mode 100644 index 0000000..249e4a6 --- /dev/null +++ b/test-services/src/impl_block/main.rs @@ -0,0 +1,93 @@ +mod awakeable_holder; +mod block_and_wait_workflow; +mod cancel_test; +mod counter; +mod failing; +mod kill_test; +mod list_object; +mod map_object; +mod non_deterministic; +mod proxy; +mod test_utils_service; +mod virtual_object_command_interpreter; + +use restate_sdk::prelude::{Endpoint, HttpServer}; +use std::env; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + let port = env::var("PORT").ok().unwrap_or("9080".to_string()); + let services = env::var("SERVICES").ok().unwrap_or("*".to_string()); + + let mut builder = Endpoint::builder(); + + if services == "*" || services.contains("Counter") { + builder = builder.bind(counter::Counter::serve(counter::Counter)) + } + if services == "*" || services.contains("Proxy") { + builder = builder.bind(proxy::Proxy::serve(proxy::Proxy)) + } + if services == "*" || services.contains("MapObject") { + builder = builder.bind(map_object::MapObject::serve(map_object::MapObject)) + } + if services == "*" || services.contains("ListObject") { + builder = builder.bind(list_object::ListObject::serve(list_object::ListObject)) + } + if services == "*" || services.contains("AwakeableHolder") { + builder = builder.bind(awakeable_holder::AwakeableHolder::serve( + awakeable_holder::AwakeableHolder, + )) + } + if services == "*" || services.contains("BlockAndWaitWorkflow") { + builder = builder.bind(block_and_wait_workflow::BlockAndWaitWorkflow::serve( + block_and_wait_workflow::BlockAndWaitWorkflow, + )) + } + if services == "*" || services.contains("CancelTestRunner") { + builder = builder.bind(cancel_test::CancelTestRunner::serve( + cancel_test::CancelTestRunner, + )) + } + if services == "*" || services.contains("CancelTestBlockingService") { + builder = builder.bind(cancel_test::CancelTestBlockingService::serve( + cancel_test::CancelTestBlockingService, + )) + } + if services == "*" || services.contains("Failing") { + builder = builder.bind(failing::Failing::serve(failing::Failing::default())) + } + if services == "*" || services.contains("KillTestRunner") { + builder = builder.bind(kill_test::KillTestRunner::serve(kill_test::KillTestRunner)) + } + if services == "*" || services.contains("KillTestSingleton") { + builder = builder.bind(kill_test::KillTestSingleton::serve( + kill_test::KillTestSingleton, + )) + } + if services == "*" || services.contains("NonDeterministic") { + builder = builder.bind(non_deterministic::NonDeterministic::serve( + non_deterministic::NonDeterministic::default(), + )) + } + if services == "*" || services.contains("TestUtilsService") { + builder = builder.bind(test_utils_service::TestUtilsService::serve( + test_utils_service::TestUtilsService, + )) + } + if services == "*" || services.contains("VirtualObjectCommandInterpreter") { + builder = builder.bind( + virtual_object_command_interpreter::VirtualObjectCommandInterpreter::serve( + virtual_object_command_interpreter::VirtualObjectCommandInterpreter, + ), + ) + } + + if let Ok(key) = env::var("E2E_REQUEST_SIGNING_ENV") { + builder = builder.identity_key(&key).unwrap() + } + + HttpServer::new(builder.build()) + .listen_and_serve(format!("0.0.0.0:{port}").parse().unwrap()) + .await; +} diff --git a/test-services/src/impl_block/map_object.rs b/test-services/src/impl_block/map_object.rs new file mode 100644 index 0000000..ebf7640 --- /dev/null +++ b/test-services/src/impl_block/map_object.rs @@ -0,0 +1,49 @@ +use anyhow::anyhow; +use restate_sdk::prelude::*; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) struct Entry { + key: String, + value: String, +} + +pub(crate) struct MapObject; + +#[restate_sdk::object(vis = "pub(crate)", name = "MapObject")] +impl MapObject { + #[handler(name = "set")] + async fn set( + &self, + ctx: ObjectContext<'_>, + Json(Entry { key, value }): Json, + ) -> HandlerResult<()> { + ctx.set(&key, value); + Ok(()) + } + + #[handler(name = "get")] + async fn get(&self, ctx: ObjectContext<'_>, key: String) -> HandlerResult { + Ok(ctx.get(&key).await?.unwrap_or_default()) + } + + #[handler(name = "clearAll")] + async fn clear_all(&self, ctx: ObjectContext<'_>) -> HandlerResult>> { + let keys = ctx.get_keys().await?; + + let mut entries = vec![]; + for k in keys { + let value = ctx + .get(&k) + .await? + .ok_or_else(|| anyhow!("Missing key {k}"))?; + entries.push(Entry { key: k, value }) + } + + ctx.clear_all(); + + Ok(entries.into()) + } +} diff --git a/test-services/src/impl_block/mod.rs b/test-services/src/impl_block/mod.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/test-services/src/impl_block/mod.rs @@ -0,0 +1 @@ + diff --git a/test-services/src/impl_block/non_deterministic.rs b/test-services/src/impl_block/non_deterministic.rs new file mode 100644 index 0000000..32feb77 --- /dev/null +++ b/test-services/src/impl_block/non_deterministic.rs @@ -0,0 +1,88 @@ +use crate::counter::CounterClient; +use restate_sdk::prelude::*; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; + +#[derive(Clone, Default)] +pub(crate) struct NonDeterministic(Arc>>); + +const STATE_A: &str = "a"; +const STATE_B: &str = "b"; + +#[restate_sdk::object(vis = "pub(crate)", name = "NonDeterministic")] +impl NonDeterministic { + #[handler(name = "eitherSleepOrCall")] + async fn either_sleep_or_call(&self, context: ObjectContext<'_>) -> HandlerResult<()> { + if self.do_left_action(&context).await { + context.sleep(Duration::from_millis(100)).await?; + } else { + context + .object_client::("abc") + .get() + .call() + .await?; + } + Self::sleep_then_increment_counter(&context).await + } + + #[handler(name = "callDifferentMethod")] + async fn call_different_method(&self, context: ObjectContext<'_>) -> HandlerResult<()> { + if self.do_left_action(&context).await { + context + .object_client::("abc") + .get() + .call() + .await?; + } else { + context + .object_client::("abc") + .reset() + .call() + .await?; + } + Self::sleep_then_increment_counter(&context).await + } + + #[handler(name = "backgroundInvokeWithDifferentTargets")] + async fn background_invoke_with_different_targets( + &self, + context: ObjectContext<'_>, + ) -> HandlerResult<()> { + if self.do_left_action(&context).await { + context.object_client::("abc").get().send(); + } else { + context.object_client::("abc").reset().send(); + } + Self::sleep_then_increment_counter(&context).await + } + + #[handler(name = "setDifferentKey")] + async fn set_different_key(&self, context: ObjectContext<'_>) -> HandlerResult<()> { + if self.do_left_action(&context).await { + context.set(STATE_A, "my-state".to_owned()); + } else { + context.set(STATE_B, "my-state".to_owned()); + } + Self::sleep_then_increment_counter(&context).await + } +} + +impl NonDeterministic { + async fn do_left_action(&self, ctx: &ObjectContext<'_>) -> bool { + let mut counts = self.0.lock().await; + *(counts + .entry(ctx.key().to_owned()) + .and_modify(|i| *i += 1) + .or_default()) + % 2 + == 1 + } + + async fn sleep_then_increment_counter(ctx: &ObjectContext<'_>) -> HandlerResult<()> { + ctx.sleep(Duration::from_millis(100)).await?; + ctx.object_client::(ctx.key()).add(1).send(); + Ok(()) + } +} diff --git a/test-services/src/impl_block/proxy.rs b/test-services/src/impl_block/proxy.rs new file mode 100644 index 0000000..2604d0e --- /dev/null +++ b/test-services/src/impl_block/proxy.rs @@ -0,0 +1,119 @@ +use futures::future::BoxFuture; +use futures::FutureExt; +use restate_sdk::context::RequestTarget; +use restate_sdk::prelude::*; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ProxyRequest { + service_name: String, + virtual_object_key: Option, + handler_name: String, + idempotency_key: Option, + message: Vec, + delay_millis: Option, +} + +impl ProxyRequest { + fn to_target(&self) -> RequestTarget { + if let Some(key) = &self.virtual_object_key { + RequestTarget::Object { + name: self.service_name.clone(), + key: key.clone(), + handler: self.handler_name.clone(), + } + } else { + RequestTarget::Service { + name: self.service_name.clone(), + handler: self.handler_name.clone(), + } + } + } +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ManyCallRequest { + proxy_request: ProxyRequest, + one_way_call: bool, + await_at_the_end: bool, +} + +pub(crate) struct Proxy; + +#[restate_sdk::service(vis = "pub(crate)", name = "Proxy")] +impl Proxy { + #[handler(name = "call")] + async fn call( + &self, + ctx: Context<'_>, + Json(req): Json, + ) -> HandlerResult>> { + let mut request = ctx.request::, Vec>(req.to_target(), req.message); + if let Some(idempotency_key) = req.idempotency_key { + request = request.idempotency_key(idempotency_key); + } + Ok(request.call().await?.into()) + } + + #[handler(name = "oneWayCall")] + async fn one_way_call( + &self, + ctx: Context<'_>, + Json(req): Json, + ) -> HandlerResult { + let mut request = ctx.request::<_, ()>(req.to_target(), req.message); + if let Some(idempotency_key) = req.idempotency_key { + request = request.idempotency_key(idempotency_key); + } + + let invocation_id = if let Some(delay_millis) = req.delay_millis { + request + .send_after(Duration::from_millis(delay_millis)) + .invocation_id() + .await? + } else { + request.send().invocation_id().await? + }; + + Ok(invocation_id) + } + + #[handler(name = "manyCalls")] + async fn many_calls( + &self, + ctx: Context<'_>, + Json(requests): Json>, + ) -> HandlerResult<()> { + let mut futures: Vec, TerminalError>>> = vec![]; + + for req in requests { + let mut restate_req = + ctx.request::<_, Vec>(req.proxy_request.to_target(), req.proxy_request.message); + if let Some(idempotency_key) = req.proxy_request.idempotency_key { + restate_req = restate_req.idempotency_key(idempotency_key); + } + if req.one_way_call { + if let Some(delay_millis) = req.proxy_request.delay_millis { + restate_req.send_after(Duration::from_millis(delay_millis)); + } else { + restate_req.send(); + } + } else { + let fut = restate_req.call(); + if req.await_at_the_end { + futures.push(fut.boxed()) + } + } + } + + for fut in futures { + fut.await?; + } + + Ok(()) + } +} diff --git a/test-services/src/impl_block/test_utils_service.rs b/test-services/src/impl_block/test_utils_service.rs new file mode 100644 index 0000000..6bd88dd --- /dev/null +++ b/test-services/src/impl_block/test_utils_service.rs @@ -0,0 +1,94 @@ +use futures::future::BoxFuture; +use futures::FutureExt; +use restate_sdk::prelude::*; +use std::collections::HashMap; +use std::convert::Infallible; +use std::sync::atomic::{AtomicU8, Ordering}; +use std::sync::Arc; +use std::time::Duration; + +pub(crate) struct TestUtilsService; + +#[restate_sdk::service(vis = "pub(crate)", name = "TestUtilsService")] +impl TestUtilsService { + #[handler(name = "echo")] + async fn echo(&self, _ctx: Context<'_>, input: String) -> HandlerResult { + Ok(input) + } + + #[handler(name = "uppercaseEcho")] + async fn uppercase_echo(&self, _ctx: Context<'_>, input: String) -> HandlerResult { + Ok(input.to_ascii_uppercase()) + } + + #[handler(name = "rawEcho")] + async fn raw_echo(&self, _ctx: Context<'_>, input: Vec) -> Result, Infallible> { + Ok(input) + } + + #[handler(name = "echoHeaders")] + async fn echo_headers( + &self, + context: Context<'_>, + ) -> HandlerResult>> { + let mut headers = HashMap::new(); + for k in context.headers().keys() { + headers.insert( + k.as_str().to_owned(), + context.headers().get(k).unwrap().clone(), + ); + } + + Ok(headers.into()) + } + + #[handler(name = "sleepConcurrently")] + async fn sleep_concurrently( + &self, + context: Context<'_>, + millis_durations: Json>, + ) -> HandlerResult<()> { + let mut futures: Vec>> = vec![]; + + for duration in millis_durations.into_inner() { + futures.push(context.sleep(Duration::from_millis(duration)).boxed()); + } + + for fut in futures { + fut.await?; + } + + Ok(()) + } + + #[handler(name = "countExecutedSideEffects")] + async fn count_executed_side_effects( + &self, + context: Context<'_>, + increments: u32, + ) -> HandlerResult { + let counter: Arc = Default::default(); + + for _ in 0..increments { + let counter_clone = Arc::clone(&counter); + context + .run(|| async { + counter_clone.fetch_add(1, Ordering::SeqCst); + Ok(()) + }) + .await?; + } + + Ok(counter.load(Ordering::SeqCst) as u32) + } + + #[handler(name = "cancelInvocation")] + async fn cancel_invocation( + &self, + ctx: Context<'_>, + invocation_id: String, + ) -> Result<(), TerminalError> { + ctx.invocation_handle(invocation_id).cancel().await?; + Ok(()) + } +} diff --git a/test-services/src/impl_block/virtual_object_command_interpreter.rs b/test-services/src/impl_block/virtual_object_command_interpreter.rs new file mode 100644 index 0000000..e0a9053 --- /dev/null +++ b/test-services/src/impl_block/virtual_object_command_interpreter.rs @@ -0,0 +1,246 @@ +use anyhow::anyhow; +use futures::TryFutureExt; +use restate_sdk::prelude::*; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) struct InterpretRequest { + commands: Vec, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type")] +#[serde(rename_all_fields = "camelCase")] +pub(crate) enum Command { + #[serde(rename = "awaitAnySuccessful")] + AwaitAnySuccessful { commands: Vec }, + #[serde(rename = "awaitAny")] + AwaitAny { commands: Vec }, + #[serde(rename = "awaitOne")] + AwaitOne { command: AwaitableCommand }, + #[serde(rename = "awaitAwakeableOrTimeout")] + AwaitAwakeableOrTimeout { + awakeable_key: String, + timeout_millis: u64, + }, + #[serde(rename = "resolveAwakeable")] + ResolveAwakeable { + awakeable_key: String, + value: String, + }, + #[serde(rename = "rejectAwakeable")] + RejectAwakeable { + awakeable_key: String, + reason: String, + }, + #[serde(rename = "getEnvVariable")] + GetEnvVariable { env_name: String }, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(tag = "type")] +#[serde(rename_all_fields = "camelCase")] +pub(crate) enum AwaitableCommand { + #[serde(rename = "createAwakeable")] + CreateAwakeable { awakeable_key: String }, + #[serde(rename = "sleep")] + Sleep { timeout_millis: u64 }, + #[serde(rename = "runThrowTerminalException")] + RunThrowTerminalException { reason: String }, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) struct ResolveAwakeable { + awakeable_key: String, + value: String, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "camelCase")] +pub(crate) struct RejectAwakeable { + awakeable_key: String, + reason: String, +} + +pub(crate) struct VirtualObjectCommandInterpreter; + +#[restate_sdk::object(vis = "pub(crate)", name = "VirtualObjectCommandInterpreter")] +impl VirtualObjectCommandInterpreter { + #[handler(name = "interpretCommands")] + async fn interpret_commands( + &self, + context: ObjectContext<'_>, + Json(req): Json, + ) -> HandlerResult { + let mut last_result: String = Default::default(); + + for cmd in req.commands { + match cmd { + Command::AwaitAny { .. } => { + Err(anyhow!("AwaitAny is currently unsupported in the Rust SDK"))? + } + Command::AwaitAnySuccessful { .. } => Err(anyhow!( + "AwaitAnySuccessful is currently unsupported in the Rust SDK" + ))?, + Command::AwaitAwakeableOrTimeout { + awakeable_key, + timeout_millis, + } => { + let (awakeable_id, awk_fut) = context.awakeable::(); + context.set::(&format!("awk-{awakeable_key}"), awakeable_id); + + last_result = restate_sdk::select! { + res = awk_fut => { + res + }, + _ = context.sleep(Duration::from_millis(timeout_millis)) => { + Err(TerminalError::new("await-timeout")) + } + }?; + } + Command::AwaitOne { command } => { + last_result = match command { + AwaitableCommand::CreateAwakeable { awakeable_key } => { + let (awakeable_id, fut) = context.awakeable::(); + context.set::(&format!("awk-{awakeable_key}"), awakeable_id); + fut.await? + } + AwaitableCommand::Sleep { timeout_millis } => { + context + .sleep(Duration::from_millis(timeout_millis)) + .map_ok(|_| "sleep".to_string()) + .await? + } + AwaitableCommand::RunThrowTerminalException { reason } => { + context + .run::<_, _, String>( + || async move { Err(TerminalError::new(reason))? }, + ) + .await? + } + } + } + Command::GetEnvVariable { env_name } => { + last_result = std::env::var(env_name).ok().unwrap_or_default(); + } + Command::ResolveAwakeable { + awakeable_key, + value, + } => { + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.resolve_awakeable(&awakeable_id, value); + last_result = Default::default(); + } + Command::RejectAwakeable { + awakeable_key, + reason, + } => { + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.reject_awakeable(&awakeable_id, TerminalError::new(reason)); + last_result = Default::default(); + } + } + + let mut old_results = context + .get::>>("results") + .await? + .unwrap_or_default() + .into_inner(); + old_results.push(last_result.clone()); + context.set("results", Json(old_results)); + } + + Ok(last_result) + } + + #[handler(name = "resolveAwakeable", shared)] + async fn resolve_awakeable( + &self, + context: SharedObjectContext<'_>, + req: Json, + ) -> Result<(), HandlerError> { + let ResolveAwakeable { + awakeable_key, + value, + } = req.into_inner(); + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.resolve_awakeable(&awakeable_id, value); + + Ok(()) + } + + #[handler(name = "rejectAwakeable", shared)] + async fn reject_awakeable( + &self, + context: SharedObjectContext<'_>, + req: Json, + ) -> Result<(), HandlerError> { + let RejectAwakeable { + awakeable_key, + reason, + } = req.into_inner(); + let Some(awakeable_id) = context + .get::(&format!("awk-{awakeable_key}")) + .await? + else { + Err(TerminalError::new( + "Awakeable is not registered yet".to_string(), + ))? + }; + + context.reject_awakeable(&awakeable_id, TerminalError::new(reason)); + + Ok(()) + } + + #[handler(name = "hasAwakeable", shared)] + async fn has_awakeable( + &self, + context: SharedObjectContext<'_>, + awakeable_key: String, + ) -> Result { + Ok(context + .get::(&format!("awk-{awakeable_key}")) + .await? + .is_some()) + } + + #[handler(name = "getResults", shared)] + async fn get_results( + &self, + context: SharedObjectContext<'_>, + ) -> Result>, HandlerError> { + Ok(context + .get::>>("results") + .await? + .unwrap_or_default()) + } +} diff --git a/test-services/src/mod.rs b/test-services/src/mod.rs new file mode 100644 index 0000000..b07b77e --- /dev/null +++ b/test-services/src/mod.rs @@ -0,0 +1,2 @@ +pub mod impl_block; +pub mod trait_block; diff --git a/test-services/src/awakeable_holder.rs b/test-services/src/trait_block/awakeable_holder.rs similarity index 100% rename from test-services/src/awakeable_holder.rs rename to test-services/src/trait_block/awakeable_holder.rs diff --git a/test-services/src/block_and_wait_workflow.rs b/test-services/src/trait_block/block_and_wait_workflow.rs similarity index 100% rename from test-services/src/block_and_wait_workflow.rs rename to test-services/src/trait_block/block_and_wait_workflow.rs diff --git a/test-services/src/cancel_test.rs b/test-services/src/trait_block/cancel_test.rs similarity index 100% rename from test-services/src/cancel_test.rs rename to test-services/src/trait_block/cancel_test.rs diff --git a/test-services/src/counter.rs b/test-services/src/trait_block/counter.rs similarity index 100% rename from test-services/src/counter.rs rename to test-services/src/trait_block/counter.rs diff --git a/test-services/src/failing.rs b/test-services/src/trait_block/failing.rs similarity index 100% rename from test-services/src/failing.rs rename to test-services/src/trait_block/failing.rs diff --git a/test-services/src/kill_test.rs b/test-services/src/trait_block/kill_test.rs similarity index 100% rename from test-services/src/kill_test.rs rename to test-services/src/trait_block/kill_test.rs diff --git a/test-services/src/list_object.rs b/test-services/src/trait_block/list_object.rs similarity index 100% rename from test-services/src/list_object.rs rename to test-services/src/trait_block/list_object.rs diff --git a/test-services/src/main.rs b/test-services/src/trait_block/main.rs similarity index 100% rename from test-services/src/main.rs rename to test-services/src/trait_block/main.rs diff --git a/test-services/src/map_object.rs b/test-services/src/trait_block/map_object.rs similarity index 100% rename from test-services/src/map_object.rs rename to test-services/src/trait_block/map_object.rs diff --git a/test-services/src/trait_block/mod.rs b/test-services/src/trait_block/mod.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/test-services/src/trait_block/mod.rs @@ -0,0 +1 @@ + diff --git a/test-services/src/non_deterministic.rs b/test-services/src/trait_block/non_deterministic.rs similarity index 100% rename from test-services/src/non_deterministic.rs rename to test-services/src/trait_block/non_deterministic.rs diff --git a/test-services/src/proxy.rs b/test-services/src/trait_block/proxy.rs similarity index 100% rename from test-services/src/proxy.rs rename to test-services/src/trait_block/proxy.rs diff --git a/test-services/src/test_utils_service.rs b/test-services/src/trait_block/test_utils_service.rs similarity index 100% rename from test-services/src/test_utils_service.rs rename to test-services/src/trait_block/test_utils_service.rs diff --git a/test-services/src/virtual_object_command_interpreter.rs b/test-services/src/trait_block/virtual_object_command_interpreter.rs similarity index 100% rename from test-services/src/virtual_object_command_interpreter.rs rename to test-services/src/trait_block/virtual_object_command_interpreter.rs From a69d9d140e7d0ae204d082aa1bb9a2cdd5ccb729 Mon Sep 17 00:00:00 2001 From: BigFish2086 Date: Thu, 3 Jul 2025 20:36:07 +0300 Subject: [PATCH 4/5] update testcontainers to have tests for the new macro syntax #43 --- .../tests/test_container_impl_block.rs | 98 +++++++++++++++++++ ...ainer.rs => test_container_trait_block.rs} | 0 2 files changed, 98 insertions(+) create mode 100644 testcontainers/tests/test_container_impl_block.rs rename testcontainers/tests/{test_container.rs => test_container_trait_block.rs} (100%) diff --git a/testcontainers/tests/test_container_impl_block.rs b/testcontainers/tests/test_container_impl_block.rs new file mode 100644 index 0000000..2c16ee7 --- /dev/null +++ b/testcontainers/tests/test_container_impl_block.rs @@ -0,0 +1,98 @@ +use reqwest::StatusCode; +use restate_sdk::prelude::*; +use restate_sdk_testcontainers::TestEnvironment; +use tracing::info; + +pub(crate) struct MyService; + +#[restate_sdk::service(vis = "pub(crate)")] +impl MyService { + #[handler] + async fn my_handler(&self, _ctx: Context<'_>) -> HandlerResult { + let result = "hello!"; + Ok(result.to_string()) + } +} + +// Should compile +pub(crate) struct MyObject; + +#[allow(dead_code)] +#[restate_sdk::object(vis = "pub(crate)")] +impl MyObject { + #[handler] + async fn my_handler(&self, _ctx: ObjectContext<'_>, _input: String) -> HandlerResult { + unimplemented!() + } + + #[handler(shared)] + async fn my_shared_handler( + &self, + _ctx: SharedObjectContext<'_>, + _input: String, + ) -> HandlerResult { + unimplemented!() + } +} + +pub(crate) struct MyWorkflow; + +#[allow(dead_code)] +#[restate_sdk::workflow(vis = "pub(crate)")] +impl MyWorkflow { + #[handler] + async fn my_handler(&self, _ctx: WorkflowContext<'_>, _input: String) -> HandlerResult { + unimplemented!() + } + + #[handler(shared)] + async fn my_shared_handler( + &self, + _ctx: SharedWorkflowContext<'_>, + _input: String, + ) -> HandlerResult { + unimplemented!() + } +} + +#[tokio::test] +async fn test_container() { + tracing_subscriber::fmt::fmt() + .with_max_level(tracing::Level::INFO) // Set the maximum log level + .init(); + + let endpoint = Endpoint::builder().bind(MyService.serve()).build(); + + // simple test container intialization with default configuration + //let test_container = TestContainer::default().start(endpoint).await.unwrap(); + + // custom test container initialization with builder + let test_environment = TestEnvironment::new() + // optional passthrough logging from the resstate server testcontainer + // prints container logs to tracing::info level + .with_container_logging() + .with_container( + "docker.io/restatedev/restate".to_string(), + "latest".to_string(), + ) + .start(endpoint) + .await + .unwrap(); + + let ingress_url = test_environment.ingress_url(); + + // call container ingress url for /MyService/my_handler + let response = reqwest::Client::new() + .post(format!("{}/MyService/my_handler", ingress_url)) + .header("idempotency-key", "abc") + .send() + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + info!( + "/MyService/my_handler response: {:?}", + response.text().await.unwrap() + ); +} diff --git a/testcontainers/tests/test_container.rs b/testcontainers/tests/test_container_trait_block.rs similarity index 100% rename from testcontainers/tests/test_container.rs rename to testcontainers/tests/test_container_trait_block.rs From 9e4590757da96449aa0eff2237bbfb7b5e51fa3d Mon Sep 17 00:00:00 2001 From: BigFish2086 Date: Thu, 3 Jul 2025 20:36:40 +0300 Subject: [PATCH 5/5] add unit-testing for the new macro syntax #43 --- tests/{ => impl_block}/compiletest.rs | 0 tests/impl_block/service.rs | 103 ++++++++++++++++++ .../ui/shared_handler_in_service.rs | 23 ++++ .../ui/shared_handler_in_service.stderr | 23 ++++ tests/trait_block/compiletest.rs | 5 + tests/{ => trait_block}/schema.rs | 0 tests/{ => trait_block}/service.rs | 0 .../ui/shared_handler_in_service.rs | 2 +- .../ui/shared_handler_in_service.stderr | 0 9 files changed, 155 insertions(+), 1 deletion(-) rename tests/{ => impl_block}/compiletest.rs (100%) create mode 100644 tests/impl_block/service.rs create mode 100644 tests/impl_block/ui/shared_handler_in_service.rs create mode 100644 tests/impl_block/ui/shared_handler_in_service.stderr create mode 100644 tests/trait_block/compiletest.rs rename tests/{ => trait_block}/schema.rs (100%) rename tests/{ => trait_block}/service.rs (100%) rename tests/{ => trait_block}/ui/shared_handler_in_service.rs (99%) rename tests/{ => trait_block}/ui/shared_handler_in_service.stderr (100%) diff --git a/tests/compiletest.rs b/tests/impl_block/compiletest.rs similarity index 100% rename from tests/compiletest.rs rename to tests/impl_block/compiletest.rs diff --git a/tests/impl_block/service.rs b/tests/impl_block/service.rs new file mode 100644 index 0000000..9271bcc --- /dev/null +++ b/tests/impl_block/service.rs @@ -0,0 +1,103 @@ +use restate_sdk::prelude::*; + +// Should compile + +pub(crate) struct MyService; + +#[allow(dead_code)] +#[restate_sdk::service(vis = "pub(crate)")] +impl MyService { + #[handler] + async fn my_handler(&self, _ctx: Context<'_>, _input: String) -> HandlerResult { + unimplemented!() + } + + #[handler] + async fn no_input(&self, _ctx: Context<'_>) -> HandlerResult { + unimplemented!() + } + + #[handler] + async fn no_output(&self, _ctx: Context<'_>) -> HandlerResult<()> { + unimplemented!() + } + + #[handler] + async fn no_input_no_output(&self, _ctx: Context<'_>) -> HandlerResult<()> { + unimplemented!() + } + + #[handler] + async fn std_result(&self, _ctx: Context<'_>) -> Result<(), std::io::Error> { + unimplemented!() + } + + #[handler] + async fn std_result_with_terminal_error(&self, _ctx: Context<'_>) -> Result<(), TerminalError> { + unimplemented!() + } + + #[handler] + async fn std_result_with_handler_error(&self, _ctx: Context<'_>) -> Result<(), HandlerError> { + unimplemented!() + } +} + +pub(crate) struct MyObject; + +#[allow(dead_code)] +#[restate_sdk::object(vis = "pub(crate)")] +impl MyObject { + #[handler] + async fn my_handler(&self, _ctx: ObjectContext<'_>, _input: String) -> HandlerResult { + unimplemented!() + } + + #[handler(shared)] + async fn my_shared_handler( + &self, + _ctx: SharedObjectContext<'_>, + _input: String, + ) -> HandlerResult { + unimplemented!() + } +} + +pub(crate) struct MyWorkflow; + +#[allow(dead_code)] +#[restate_sdk::workflow(vis = "pub(crate)")] +impl MyWorkflow { + #[handler] + async fn my_handler(&self, _ctx: WorkflowContext<'_>, _input: String) -> HandlerResult { + unimplemented!() + } + + #[handler(shared)] + async fn my_shared_handler( + &self, + _ctx: SharedWorkflowContext<'_>, + _input: String, + ) -> HandlerResult { + unimplemented!() + } +} + +pub(crate) struct MyRenamedService; + +#[restate_sdk::service(vis = "pub(crate)", name = "myRenamedService")] +impl MyRenamedService { + #[handler(name = "myRenamedHandler")] + async fn my_handler(&self, _ctx: Context<'_>) -> HandlerResult<()> { + Ok(()) + } +} + +#[test] +fn renamed_service_handler() { + use restate_sdk::service::Discoverable; + + let discovery = ServeMyRenamedService::::discover(); + assert_eq!(discovery.name.to_string(), "myRenamedService"); + assert_eq!(discovery.handlers[0].name.to_string(), "myRenamedHandler"); +} diff --git a/tests/impl_block/ui/shared_handler_in_service.rs b/tests/impl_block/ui/shared_handler_in_service.rs new file mode 100644 index 0000000..e6cd57d --- /dev/null +++ b/tests/impl_block/ui/shared_handler_in_service.rs @@ -0,0 +1,23 @@ +use restate_sdk::prelude::*; + +struct SharedHandlerInService; + +#[restate_sdk::service] +impl SharedHandlerInService { + #[handler(shared)] + async fn my_handler(&self, _ctx: Context<'_>) -> HandlerResult<()> { + Ok(()) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + HttpServer::new( + Endpoint::builder() + .with_service(SharedHandlerInService.serve()) + .build(), + ) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/tests/impl_block/ui/shared_handler_in_service.stderr b/tests/impl_block/ui/shared_handler_in_service.stderr new file mode 100644 index 0000000..44a4386 --- /dev/null +++ b/tests/impl_block/ui/shared_handler_in_service.stderr @@ -0,0 +1,23 @@ +error: Service handlers cannot be annotated with #[handler(shared)] + --> tests/ui/shared_handler_in_service.rs:7:15 + | +7 | #[handler(shared)] + | ^^^^^^ + +error[E0599]: no method named `with_service` found for struct `restate_sdk::endpoint::Builder` in the current scope + --> tests/ui/shared_handler_in_service.rs:18:14 + | +17 | / Endpoint::builder() +18 | | .with_service(SharedHandlerInService.serve()) + | | -^^^^^^^^^^^^ method not found in `Builder` + | |_____________| + | + +error[E0599]: no method named `serve` found for struct `SharedHandlerInService` in the current scope + --> tests/ui/shared_handler_in_service.rs:18:50 + | +3 | struct SharedHandlerInService; + | ----------------------------- method `serve` not found for this struct +... +18 | .with_service(SharedHandlerInService.serve()) + | ^^^^^ method not found in `SharedHandlerInService` diff --git a/tests/trait_block/compiletest.rs b/tests/trait_block/compiletest.rs new file mode 100644 index 0000000..870c2f9 --- /dev/null +++ b/tests/trait_block/compiletest.rs @@ -0,0 +1,5 @@ +#[test] +fn ui() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/*.rs"); +} diff --git a/tests/schema.rs b/tests/trait_block/schema.rs similarity index 100% rename from tests/schema.rs rename to tests/trait_block/schema.rs diff --git a/tests/service.rs b/tests/trait_block/service.rs similarity index 100% rename from tests/service.rs rename to tests/trait_block/service.rs diff --git a/tests/ui/shared_handler_in_service.rs b/tests/trait_block/ui/shared_handler_in_service.rs similarity index 99% rename from tests/ui/shared_handler_in_service.rs rename to tests/trait_block/ui/shared_handler_in_service.rs index ef98a45..f345271 100644 --- a/tests/ui/shared_handler_in_service.rs +++ b/tests/trait_block/ui/shared_handler_in_service.rs @@ -24,4 +24,4 @@ async fn main() { ) .listen_and_serve("0.0.0.0:9080".parse().unwrap()) .await; -} \ No newline at end of file +} diff --git a/tests/ui/shared_handler_in_service.stderr b/tests/trait_block/ui/shared_handler_in_service.stderr similarity index 100% rename from tests/ui/shared_handler_in_service.stderr rename to tests/trait_block/ui/shared_handler_in_service.stderr