diff --git a/prost-validate-derive-core/src/any.rs b/prost-validate-derive-core/src/any.rs index 9b76d45..903c3c8 100644 --- a/prost-validate-derive-core/src/any.rs +++ b/prost-validate-derive-core/src/any.rs @@ -17,12 +17,13 @@ impl ToValidationTokens for AnyRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; let rules = prost_validate_types::AnyRules::from(self.to_owned()); + let maybe_return = ctx.maybe_return(); let r#in = rules.r#in.is_empty().not().then(|| { let v = rules.r#in; quote! { let values = vec![#(#v),*]; if !values.contains(&#name.type_url.as_str()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::any::Error::In(values.iter().map(|v|v.to_string()).collect()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::any::Error::In(values.iter().map(|v|v.to_string()).collect()))); } } }); @@ -31,7 +32,7 @@ impl ToValidationTokens for AnyRules { quote! { let values = vec![#(#v),*]; if values.contains(&#name.type_url.as_str()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::any::Error::NotIn(values.iter().map(|v|v.to_string()).collect()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::any::Error::NotIn(values.iter().map(|v|v.to_string()).collect()))); } } }); diff --git a/prost-validate-derive-core/src/bool.rs b/prost-validate-derive-core/src/bool.rs index a1fa27f..cf9ac25 100644 --- a/prost-validate-derive-core/src/bool.rs +++ b/prost-validate-derive-core/src/bool.rs @@ -11,10 +11,11 @@ pub struct BoolRules { impl ToValidationTokens for BoolRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; + let maybe_return = ctx.maybe_return(); let r#const = self.r#const.map(|v| { quote! { if *#name != #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::r#bool::Error::Const(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::r#bool::Error::Const(#v))); } } }); diff --git a/prost-validate-derive-core/src/bytes.rs b/prost-validate-derive-core/src/bytes.rs index 2186922..ded4bdf 100644 --- a/prost-validate-derive-core/src/bytes.rs +++ b/prost-validate-derive-core/src/bytes.rs @@ -27,11 +27,12 @@ impl ToValidationTokens for BytesRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; let rules = prost_validate_types::BytesRules::from(self.to_owned()); + let maybe_return = ctx.maybe_return(); let r#const = rules.r#const.map(|v| { let v = LitByteStr::new(v.as_slice(), Span::call_site()); quote! { if !#name.iter().eq(#v.iter()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Const(#v.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Const(#v.to_vec()))); } } }); @@ -39,7 +40,7 @@ impl ToValidationTokens for BytesRules { let v = v as usize; quote! { if #name.len() != #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Len(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Len(#v))); } } }); @@ -47,7 +48,7 @@ impl ToValidationTokens for BytesRules { let v = v as usize; quote! { if #name.len() < #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MinLen(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MinLen(#v))); } } }); @@ -55,7 +56,7 @@ impl ToValidationTokens for BytesRules { let v = v as usize; quote! { if #name.len() > #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MaxLen(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MaxLen(#v))); } } }); @@ -65,10 +66,10 @@ impl ToValidationTokens for BytesRules { } quote! { match ::regex::bytes::Regex::new(#v) { - Err(e) => return Err(::prost_validate::Error::new(#field, format!("Invalid regex pattern: {e}"))), + Err(e) => #maybe_return(::prost_validate::Error::new(#field, format!("Invalid regex pattern: {e}"))), Ok(regex) => { if !regex.is_match(#name.iter().as_slice()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Pattern(#v.to_string()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Pattern(#v.to_string()))); } } } @@ -78,7 +79,7 @@ impl ToValidationTokens for BytesRules { let v = LitByteStr::new(v.as_slice(), Span::call_site()); quote! { if !#name.starts_with(#v) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Prefix(#v.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Prefix(#v.to_vec()))); } } }); @@ -86,7 +87,7 @@ impl ToValidationTokens for BytesRules { let v = LitByteStr::new(v.as_slice(), Span::call_site()); quote! { if !#name.ends_with(#v) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Suffix(#v.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Suffix(#v.to_vec()))); } } }); @@ -94,7 +95,7 @@ impl ToValidationTokens for BytesRules { let v = LitByteStr::new(v.as_slice(), Span::call_site()); quote! { if !::prost_validate::ValidateBytesExt::contains(&#name, #v.as_slice()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Contains(#v.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Contains(#v.to_vec()))); } } }); @@ -107,7 +108,7 @@ impl ToValidationTokens for BytesRules { quote! { let values = [#(#v.to_vec()),*]; if !values.contains(&#name) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::In(values.iter().map(|v| v.to_vec()).collect()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::In(values.iter().map(|v| v.to_vec()).collect()))); } } }); @@ -120,7 +121,7 @@ impl ToValidationTokens for BytesRules { quote! { let values = [#(#v.to_vec()),*]; if values.contains(&#name) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::NotIn(values.iter().map(|v| v.to_vec()).collect()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::NotIn(values.iter().map(|v| v.to_vec()).collect()))); } } }); @@ -128,21 +129,21 @@ impl ToValidationTokens for BytesRules { bytes_rules::WellKnown::Ip(true) => { quote! { if #name.len() != 4 && #name.len() != 16 { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ip)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ip)); } } } bytes_rules::WellKnown::Ipv4(true) => { quote! { if #name.len() != 4 { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv4)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv4)); } } } bytes_rules::WellKnown::Ipv6(true) => { quote! { if #name.len() != 16 { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv6)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv6)); } } } diff --git a/prost-validate-derive-core/src/derive.rs b/prost-validate-derive-core/src/derive.rs index 7b2d483..4b0bc06 100644 --- a/prost-validate-derive-core/src/derive.rs +++ b/prost-validate-derive-core/src/derive.rs @@ -23,27 +23,11 @@ pub fn derive_with_module( let opts = Opts::from_derive_input(&input).expect("Wrong validate options"); let DeriveInput { ident, .. } = input; - let implementation = match opts.data { - Data::Enum(e) => e - .iter() - .map(|v| Field { - module: module.clone().map(|v| v.to_string()), - ..v.clone() - }) - .map(|v| v.to_token_stream()) - .collect::(), - Data::Struct(s) => s - .fields - .iter() - .map(|v| Field { - module: module.clone().map(|v| v.to_string()), - ..v.clone() - }) - .map(|field| field.into_token_stream()) - .collect::(), - }; + let implementation = body_tokens(&opts, module.clone(), false); + let impl_multierrs = body_tokens(&opts, module.clone(), true); let allow = quote! { + #[allow(clippy::collapsible_if)] #[allow(clippy::regex_creation_in_loops)] #[allow(irrefutable_let_patterns)] #[allow(unused_variables)] @@ -62,6 +46,13 @@ pub fn derive_with_module( #implementation Ok(()) } + + #allow + fn validate_all(&self) -> ::core::result::Result<(), Vec<::prost_validate::Error>> { + let mut errs = vec![]; + #impl_multierrs + if errs.is_empty() { Ok(()) } else { Err(errs) } + } } } } else { @@ -71,6 +62,30 @@ pub fn derive_with_module( } } +fn body_tokens(opts: &Opts, module: Option, multierrs: bool) -> TokenStream { + match &opts.data { + Data::Enum(e) => e + .iter() + .map(|v| Field { + module: module.clone().map(|v| v.to_string()), + multierrs, + ..v.clone() + }) + .map(|v| v.to_token_stream()) + .collect(), + Data::Struct(s) => s + .fields + .iter() + .map(|v| Field { + module: module.clone().map(|v| v.to_string()), + multierrs, + ..v.clone() + }) + .map(|field| field.into_token_stream()) + .collect(), + } +} + #[cfg(test)] pub fn derive_2(input: proc_macro2::TokenStream) -> String { let output = derive(input); @@ -107,6 +122,7 @@ mod tests { use super::*; #[test] + // cargo test --all-targets --all-features --locked --frozen --offline --no-fail-fast --manifest-path prost-validate-derive-core/Cargo.toml derive::tests -- --nocapture fn tests() { let input = quote! { pub struct WrapperRequiredFloat { diff --git a/prost-validate-derive-core/src/duration.rs b/prost-validate-derive-core/src/duration.rs index 1b4a8ea..3c15c4c 100644 --- a/prost-validate-derive-core/src/duration.rs +++ b/prost-validate-derive-core/src/duration.rs @@ -34,11 +34,12 @@ impl ToValidationTokens for DurationRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; let rules = prost_validate_types::DurationRules::from(self.clone()); + let maybe_return = ctx.maybe_return(); let r#const = rules.r#const.map(|v| v.as_duration()).map(|v| { let (got, want) = duration_to_tokens(name, &v); quote! { if #got != #want { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Const(#want))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Const(#want))); } } }); @@ -50,7 +51,7 @@ impl ToValidationTokens for DurationRules { let (_, gt) = duration_to_tokens(name, >); quote! { if #val <= #gt || #val >= #lt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::in_range(false, #gt, #lt, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::in_range(false, #gt, #lt, false))); } } } else { @@ -58,7 +59,7 @@ impl ToValidationTokens for DurationRules { let (_, gt) = duration_to_tokens(name, >); quote! { if #val >= #lt && #val <= #gt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::not_in_range(true, #lt, #gt, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::not_in_range(true, #lt, #gt, true))); } } } @@ -68,7 +69,7 @@ impl ToValidationTokens for DurationRules { let (_, gte) = duration_to_tokens(name, >e); quote! { if #val < #gte || #val >= #lt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::in_range(true, #gte, #lt, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::in_range(true, #gte, #lt, false))); } } } else { @@ -76,7 +77,7 @@ impl ToValidationTokens for DurationRules { let (_, gte) = duration_to_tokens(name, >e); quote! { if #val >= #lt && #val < #gte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::not_in_range(true, #lt, #gte, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::not_in_range(true, #lt, #gte, false))); } } } @@ -84,7 +85,7 @@ impl ToValidationTokens for DurationRules { let (val, lt) = duration_to_tokens(name, <); quote! { if #val >= #lt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Lt(#lt))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Lt(#lt))); } } } @@ -95,7 +96,7 @@ impl ToValidationTokens for DurationRules { let (_, gt) = duration_to_tokens(name, >); quote! { if #val <= #gt || #val > #lte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::in_range(false, #gt, #lte, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::in_range(false, #gt, #lte, true))); } } } else { @@ -103,7 +104,7 @@ impl ToValidationTokens for DurationRules { let (_, gt) = duration_to_tokens(name, >); quote! { if #val >= #lte && #val < #gt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::not_in_range(false, #lte, #gt, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::not_in_range(false, #lte, #gt, true))); } } } @@ -113,7 +114,7 @@ impl ToValidationTokens for DurationRules { let (_, gte) = duration_to_tokens(name, >e); quote! { if #val < #gte || #val > #lte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::in_range(true, #gte, #lte, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::in_range(true, #gte, #lte, true))); } } } else { @@ -121,7 +122,7 @@ impl ToValidationTokens for DurationRules { let (_, gte) = duration_to_tokens(name, >e); quote! { if #val > #lte && #val < #gte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::not_in_range(false, #lte, #gte, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::not_in_range(false, #lte, #gte, false))); } } } @@ -129,7 +130,7 @@ impl ToValidationTokens for DurationRules { let (val, lte) = duration_to_tokens(name, <e); quote! { if #val > #lte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Lte(#lte))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Lte(#lte))); } } } @@ -137,14 +138,14 @@ impl ToValidationTokens for DurationRules { let (val, gt) = duration_to_tokens(name, >); quote! { if #val <= #gt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Gt(#gt))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Gt(#gt))); } } } else if let Some(gte) = rules.gte.map(|v| v.as_duration()) { let (val, gte) = duration_to_tokens(name, >e); quote! { if #val < #gte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Gte(#gte))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::Gte(#gte))); } } } else { @@ -165,7 +166,7 @@ impl ToValidationTokens for DurationRules { quote! { let values = [#(#vals),*]; if !values.contains(&#val) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::In(values.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::In(values.to_vec()))); } } }); @@ -184,7 +185,7 @@ impl ToValidationTokens for DurationRules { quote! { let values = [#(#vals),*]; if values.contains(&#val) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::NotIn(values.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::duration::Error::NotIn(values.to_vec()))); } } }); diff --git a/prost-validate-derive-core/src/enum.rs b/prost-validate-derive-core/src/enum.rs index 3316675..8352c78 100644 --- a/prost-validate-derive-core/src/enum.rs +++ b/prost-validate-derive-core/src/enum.rs @@ -19,10 +19,11 @@ impl ToValidationTokens for EnumRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; let rules = prost_validate_types::EnumRules::from(self.to_owned()); + let maybe_return = ctx.maybe_return(); let r#const = rules.r#const.map(|v| { quote! { if (*#name as i32) != #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::Const(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::Const(#v))); } } }); @@ -47,7 +48,7 @@ impl ToValidationTokens for EnumRules { .expect("Invalid enum path"); quote! { if !#enum_type::is_valid(*#name) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::DefinedOnly)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::DefinedOnly)); } } }); @@ -56,7 +57,7 @@ impl ToValidationTokens for EnumRules { quote! { let values = [#(#v),*]; if !values.contains(&#name) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::In(values.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::In(values.to_vec()))); } } }); @@ -65,7 +66,7 @@ impl ToValidationTokens for EnumRules { quote! { let values = [#(#v),*]; if values.contains(#name) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::NotIn(values.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::r#enum::Error::NotIn(values.to_vec()))); } } }); diff --git a/prost-validate-derive-core/src/field.rs b/prost-validate-derive-core/src/field.rs index da14557..16ebcc7 100644 --- a/prost-validate-derive-core/src/field.rs +++ b/prost-validate-derive-core/src/field.rs @@ -27,9 +27,20 @@ pub struct Context<'a> { pub oneof: bool, pub prost_types: bool, pub wrapper: bool, + pub multierrs: bool, pub module: Option, } +impl<'a> Context<'a> { + pub(crate) fn maybe_return(&self) -> TokenStream { + if self.multierrs { + quote! { errs.push } + } else { + quote! { return Err } + } + } +} + pub trait ToValidationTokens { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream; } @@ -42,6 +53,7 @@ pub struct Field { pub prost: ProstField, pub oneof: bool, pub map: bool, + pub multierrs: bool, pub module: Option, } @@ -74,6 +86,7 @@ impl Field { validation, oneof, map: map.is_some(), + multierrs: false, module: None, } } @@ -558,6 +571,7 @@ impl ToTokens for Field { map: self.map, oneof: self.oneof, prost_types: self.is_prost_types(), + multierrs: self.multierrs, module: self.module.clone(), }; let name = if self.oneof { @@ -574,9 +588,10 @@ impl ToTokens for Field { let body = self.validation.to_validation_tokens(&ctx, name); let required = ctx.required.then(|| { let field = &ctx.name; + let maybe_return = ctx.maybe_return(); quote! { if self.#name.is_none() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::message::Error::Required)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::message::Error::Required)); } } }); @@ -775,10 +790,9 @@ pub fn with_ignore_empty(name: &syn::Ident, ignore_empty: bool, body: TokenStrea } if ignore_empty { quote! { - if #name.is_empty() { - return Ok(()); + if !#name.is_empty() { + #body } - #body } } else { quote! { diff --git a/prost-validate-derive-core/src/list.rs b/prost-validate-derive-core/src/list.rs index a2d4a97..2641dc0 100644 --- a/prost-validate-derive-core/src/list.rs +++ b/prost-validate-derive-core/src/list.rs @@ -17,11 +17,12 @@ pub struct RepeatedRules { impl ToValidationTokens for RepeatedRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; + let maybe_return = ctx.maybe_return(); let min_items = self.min_items.map(|v| { let v = v as usize; quote! { if #name.len() < #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::list::Error::MinItems(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::list::Error::MinItems(#v))); } } }); @@ -29,20 +30,33 @@ impl ToValidationTokens for RepeatedRules { let v = v as usize; quote! { if #name.len() > #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::list::Error::MaxItems(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::list::Error::MaxItems(#v))); } } }); let unique = self.unique.is_true_and(|| { quote! { if ::prost_validate::VecExt::unique(#name).len() != #name.len() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::list::Error::Unique)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::list::Error::Unique)); } } }); let map = quote! { |e| ::prost_validate::Error::new(format!("{}[{i}]", #field), ::prost_validate::errors::list::Error::Item(Box::new(e))) }; let items = self.items.as_ref().map(|v| { let validation = v.to_validation_tokens(ctx, &format_ident!("item")); + if ctx.multierrs { + return quote! { + for (i, item) in #name.iter().enumerate() { + if let Err(es) = || -> ::core::result::Result<(), Vec<::prost_validate::Error>> { + let mut errs = vec![]; + #validation + if errs.is_empty() { Ok(()) } else { Err(errs) } + }() { + errs.extend(es.into_iter().map(#map)); + } + } + }; + } quote! { for (i, item) in #name.iter().enumerate() { || -> ::prost_validate::Result<_> { @@ -65,6 +79,16 @@ impl ToValidationTokens for RepeatedRules { } else { (quote! { #name.iter() }, quote! {}) }; + if ctx.multierrs { + return quote! { + for (i, item) in #name_iter.enumerate() { + #item_ref + if let Err(es) = ::prost_validate::validate_all!(item) { + errs.extend(es.into_iter().map(#map)); + } + } + }; + } quote! { for (i, item) in #name_iter.enumerate() { #item_ref diff --git a/prost-validate-derive-core/src/map.rs b/prost-validate-derive-core/src/map.rs index 7589dcb..294d7de 100644 --- a/prost-validate-derive-core/src/map.rs +++ b/prost-validate-derive-core/src/map.rs @@ -20,11 +20,12 @@ impl ToValidationTokens for MapRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; let rules = prost_validate_types::MapRules::from(self.to_owned()); + let maybe_return = ctx.maybe_return(); let min_pairs = rules.min_pairs.map(|v| { let v = v as usize; quote! { if #name.len() < #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::map::Error::MinPairs(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::map::Error::MinPairs(#v))); } } }); @@ -32,27 +33,54 @@ impl ToValidationTokens for MapRules { let v = v as usize; quote! { if #name.len() > #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::map::Error::MaxPairs(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::map::Error::MaxPairs(#v))); } } }); - let key = format_ident!("key"); let keys = self.keys.as_ref().map(|rules| { + let key = format_ident!("key"); let validate = rules.to_validation_tokens(ctx, &key); + let map = quote! { |e| ::prost_validate::Error::new(format!("{}[{}]", #field, #key), ::prost_validate::errors::map::Error::Keys(Box::new(e))) }; validate.is_empty().not().then(|| { + if ctx.multierrs { + return quote! { + for #key in #name.keys() { + if let Err(es) = || -> ::core::result::Result<(), Vec<::prost_validate::Error>> { + let mut errs = vec![]; + #validate + if errs.is_empty() { Ok(()) } else { Err(errs) } + }() { + errs.extend(es.into_iter().map(#map)); + } + } + }; + } quote! { for #key in #name.keys() { || -> ::prost_validate::Result<_> { #validate Ok(()) - }().map_err(|e| ::prost_validate::Error::new(format!("{}[{}]", #field, #key), ::prost_validate::errors::map::Error::Keys(Box::new(e))))?; + }().map_err(#map)?; } } }) }); let value = format_ident!("value"); - let map = quote! { |e| ::prost_validate::Error::new(format!("{}[{k}]", #field), ::prost_validate::errors::map::Error::Values(Box::new(e))) }; let quote_values = |validation: TokenStream| { + let map = quote! { |e| ::prost_validate::Error::new(format!("{}[{k}]", #field), ::prost_validate::errors::map::Error::Values(Box::new(e))) }; + if ctx.multierrs { + return quote! { + for (k, #value) in #name.iter() { + if let Err(es) = || -> ::core::result::Result<(), Vec<::prost_validate::Error>> { + let mut errs = vec![]; + #validation + if errs.is_empty() { Ok(()) } else { Err(errs) } + }() { + errs.extend(es.into_iter().map(#map)); + } + } + }; + } quote! { for (k, #value) in #name.iter() { || -> ::prost_validate::Result<_> { diff --git a/prost-validate-derive-core/src/message.rs b/prost-validate-derive-core/src/message.rs index 8655f9f..926f0e2 100644 --- a/prost-validate-derive-core/src/message.rs +++ b/prost-validate-derive-core/src/message.rs @@ -21,6 +21,14 @@ impl ToValidationTokens for MessageRules { let validate = self.skip.not().then(|| { let map = quote! { |e| ::prost_validate::Error::new(#field, ::prost_validate::errors::message::Error::Message(Box::new(e))) }; let name_ref = ctx.boxed.then(|| quote! { let #name = #name.as_ref(); }); + if ctx.multierrs { + return quote! { + #name_ref + if let Err(es) = ::prost_validate::validate_all!(#name) { + errs.extend(es.into_iter().map(#map)); + } + }; + } quote! { #name_ref ::prost_validate::validate!(#name).map_err(#map)?; diff --git a/prost-validate-derive-core/src/number.rs b/prost-validate-derive-core/src/number.rs index 31eacf0..d07d9a4 100644 --- a/prost-validate-derive-core/src/number.rs +++ b/prost-validate-derive-core/src/number.rs @@ -24,10 +24,11 @@ macro_rules! make_number_rules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; let rules = prost_validate_types::$name::from(self.to_owned()); + let maybe_return = ctx.maybe_return(); let r#const = rules.r#const.map(|v| { quote! { if *#name != #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Const(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Const(#v))); } } }); @@ -37,13 +38,13 @@ macro_rules! make_number_rules { if lt > gt { quote! { if *#name <= #gt || *#name >= #lt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(false, #gt, #lt, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(false, #gt, #lt, false))); } } } else { quote! { if *#name >= #lt && *#name <= #gt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(true, #lt, #gt, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(true, #lt, #gt, true))); } } } @@ -51,20 +52,20 @@ macro_rules! make_number_rules { if lt > gte { quote! { if *#name < #gte || *#name >= #lt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(true, #gte, #lt, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(true, #gte, #lt, false))); } } } else { quote! { if *#name >= #lt && *#name < #gte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(true, #lt, #gte, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(true, #lt, #gte, false))); } } } } else { quote! { if *#name >= #lt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Lt(#lt))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Lt(#lt))); } } } @@ -73,13 +74,13 @@ macro_rules! make_number_rules { if lte > gt { quote! { if *#name <= #gt || *#name > #lte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(false, #gt, #lte, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(false, #gt, #lte, true))); } } } else { quote! { if *#name > #lte && *#name <= #gt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(false, #lte, #gt, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(false, #lte, #gt, true))); } } } @@ -87,33 +88,33 @@ macro_rules! make_number_rules { if lte > gte { quote! { if *#name < #gte || *#name > #lte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(true, #gte, #lte, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::in_range(true, #gte, #lte, true))); } } } else { quote! { if *#name > #lte && *#name < #gte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(false, #lte, #gte, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::not_in_range(false, #lte, #gte, false))); } } } } else { quote! { if *#name > #lte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Lte(#lte))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Lte(#lte))); } } } } else if let Some(gt) = rules.gt { quote! { if *#name <= #gt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Gt(#gt))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Gt(#gt))); } } } else if let Some(gte) = rules.gte { quote! { if *#name < #gte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Gte(#gte))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::Gte(#gte))); } } } else { @@ -124,7 +125,7 @@ macro_rules! make_number_rules { quote! { let values = vec![#(#v),*]; if !values.contains(#name) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::In(values.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::In(values.to_vec()))); } } }); @@ -133,7 +134,7 @@ macro_rules! make_number_rules { quote! { let values = vec![#(#v),*]; if values.contains(#name) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::NotIn(values.to_vec()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::$module::Error::NotIn(values.to_vec()))); } } }); diff --git a/prost-validate-derive-core/src/oneof.rs b/prost-validate-derive-core/src/oneof.rs index 9b3f9e0..a003793 100644 --- a/prost-validate-derive-core/src/oneof.rs +++ b/prost-validate-derive-core/src/oneof.rs @@ -10,7 +10,14 @@ pub struct OneOfRules { } impl ToValidationTokens for OneOfRules { - fn to_validation_tokens(&self, _: &Context, name: &Ident) -> TokenStream { + fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { + if ctx.multierrs { + return quote! { + if let Err(es) = ::prost_validate::validate_all!(#name) { + errs.extend(es.into_iter()); + } + }; + } quote! { ::prost_validate::validate!(#name)?; } diff --git a/prost-validate-derive-core/src/string.rs b/prost-validate-derive-core/src/string.rs index 16c58ea..5f8f34f 100644 --- a/prost-validate-derive-core/src/string.rs +++ b/prost-validate-derive-core/src/string.rs @@ -31,10 +31,11 @@ impl ToValidationTokens for StringRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; let rules = prost_validate_types::StringRules::from(self.to_owned()); + let maybe_return = ctx.maybe_return(); let r#const = rules.r#const.map(|v| { quote! { if #name != #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Const(#v.to_string()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Const(#v.to_string()))); } } }); @@ -42,7 +43,7 @@ impl ToValidationTokens for StringRules { let v = v as usize; quote! { if #name.chars().count() != #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Len(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Len(#v))); } } }); @@ -50,7 +51,7 @@ impl ToValidationTokens for StringRules { let v = v as usize; quote! { if #name.chars().count() < #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MinLen(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MinLen(#v))); } } }); @@ -58,7 +59,7 @@ impl ToValidationTokens for StringRules { let v = v as usize; quote! { if #name.chars().count() > #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MaxLen(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MaxLen(#v))); } } }); @@ -66,7 +67,7 @@ impl ToValidationTokens for StringRules { let v = v as usize; quote! { if #name.len() != #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::LenBytes(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::LenBytes(#v))); } } }); @@ -74,7 +75,7 @@ impl ToValidationTokens for StringRules { let v = v as usize; quote! { if #name.len() < #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MinLenBytes(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MinLenBytes(#v))); } } }); @@ -82,7 +83,7 @@ impl ToValidationTokens for StringRules { let v = v as usize; quote! { if #name.len() > #v { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MaxLenBytes(#v))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::MaxLenBytes(#v))); } } }); @@ -92,10 +93,10 @@ impl ToValidationTokens for StringRules { } quote! { match ::regex::Regex::new(#v) { - Err(e) => return Err(::prost_validate::Error::new(#field, format!("Invalid regex pattern: {e}"))), + Err(e) => #maybe_return(::prost_validate::Error::new(#field, format!("Invalid regex pattern: {e}"))), Ok(regex) => { if !regex.is_match(#name.as_str()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Pattern(#v.to_string()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Pattern(#v.to_string()))); } } } @@ -104,28 +105,28 @@ impl ToValidationTokens for StringRules { let prefix = rules.prefix.map(|v| { quote! { if !#name.starts_with(#v) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Prefix(#v.to_string()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Prefix(#v.to_string()))); } } }); let suffix = rules.suffix.map(|v| { quote! { if !#name.ends_with(#v) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Suffix(#v.to_string()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Suffix(#v.to_string()))); } } }); let contains = rules.contains.map(|v| { quote! { if !#name.contains(#v) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Contains(#v.to_string()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Contains(#v.to_string()))); } } }); let not_contains = rules.not_contains.map(|v| { quote! { if #name.contains(#v) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::NotContains(#v.to_string()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::NotContains(#v.to_string()))); } } }); @@ -134,7 +135,7 @@ impl ToValidationTokens for StringRules { quote! { let values = [#(#v),*]; if !values.contains(&#name.as_str()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::In(values.iter().map(|v| v.to_string()).collect()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::In(values.iter().map(|v| v.to_string()).collect()))); } } }); @@ -143,7 +144,7 @@ impl ToValidationTokens for StringRules { quote! { let values = [#(#v),*]; if values.contains(&#name.as_str()) { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::NotIn(values.iter().map(|v| v.to_string()).collect()))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::NotIn(values.iter().map(|v| v.to_string()).collect()))); } } }); @@ -152,63 +153,63 @@ impl ToValidationTokens for StringRules { string_rules::WellKnown::Email(true) => { quote! { if ::prost_validate::ValidateStringExt::validate_email(&#name).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Email)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Email)); } } } string_rules::WellKnown::Hostname(true) => { quote! { if ::prost_validate::ValidateStringExt::validate_hostname(&#name).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Hostname)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Hostname)); } } } string_rules::WellKnown::Ip(true) => { quote! { if ::prost_validate::ValidateStringExt::validate_ip(&#name).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Ip)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Ip)); } } } string_rules::WellKnown::Ipv4(true) => { quote! { if ::prost_validate::ValidateStringExt::validate_ipv4(&#name).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Ipv4)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Ipv4)); } } } string_rules::WellKnown::Ipv6(true) => { quote! { if ::prost_validate::ValidateStringExt::validate_ipv6(&#name).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Ipv6)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Ipv6)); } } } string_rules::WellKnown::Uri(true) => { quote! { if ::prost_validate::ValidateStringExt::validate_uri(&#name).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Uri)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Uri)); } } } string_rules::WellKnown::UriRef(true) => { quote! { if ::prost_validate::ValidateStringExt::validate_uri_ref(&#name).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::UriRef)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::UriRef)); } } } string_rules::WellKnown::Address(true) => { quote! { if ::prost_validate::ValidateStringExt::validate_address(&#name).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Address)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Address)); } } } string_rules::WellKnown::Uuid(true) => { quote! { if ::prost_validate::ValidateStringExt::validate_uuid(&#name).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Uuid)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::Uuid)); } } } @@ -218,14 +219,14 @@ impl ToValidationTokens for StringRules { Ok(prost_validate_types::KnownRegex::HttpHeaderName) => { quote! { if ::prost_validate::ValidateStringExt::validate_header_name(&#name, #strict).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::HttpHeaderName)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::HttpHeaderName)); } } } Ok(prost_validate_types::KnownRegex::HttpHeaderValue) => { quote! { if ::prost_validate::ValidateStringExt::validate_header_value(&#name, #strict).is_err() { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::HttpHeaderValue)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::string::Error::HttpHeaderValue)); } } } diff --git a/prost-validate-derive-core/src/timestamp.rs b/prost-validate-derive-core/src/timestamp.rs index f3d9e54..c6280ed 100644 --- a/prost-validate-derive-core/src/timestamp.rs +++ b/prost-validate-derive-core/src/timestamp.rs @@ -50,11 +50,12 @@ impl ToValidationTokens for TimestampRules { fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream { let field = &ctx.name; let rules = prost_validate_types::TimestampRules::from(self.clone()); + let maybe_return = ctx.maybe_return(); let r#const = rules.r#const.map(|v| v.as_datetime()).map(|v| { let (got, want) = datetime_to_tokens(name, &v); quote! { if #got != #want { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Const(#want))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Const(#want))); } } }); @@ -66,7 +67,7 @@ impl ToValidationTokens for TimestampRules { let (_, gt) = datetime_to_tokens(name, >); quote! { if #val <= #gt || #val >= #lt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::in_range(false, #gt, #lt, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::in_range(false, #gt, #lt, false))); } } } else { @@ -74,7 +75,7 @@ impl ToValidationTokens for TimestampRules { let (_, gt) = datetime_to_tokens(name, >); quote! { if #val >= #lt && #val <= #gt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::not_in_range(true, #lt, #gt, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::not_in_range(true, #lt, #gt, true))); } } } @@ -84,7 +85,7 @@ impl ToValidationTokens for TimestampRules { let (_, gte) = datetime_to_tokens(name, >e); quote! { if #val < #gte || #val >= #lt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::in_range(true, #gte, #lt, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::in_range(true, #gte, #lt, false))); } } } else { @@ -92,7 +93,7 @@ impl ToValidationTokens for TimestampRules { let (_, gte) = datetime_to_tokens(name, >e); quote! { if #val >= #lt && #val < #gte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::not_in_range(true, #lt, #gte, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::not_in_range(true, #lt, #gte, false))); } } } @@ -100,7 +101,7 @@ impl ToValidationTokens for TimestampRules { let (val, lt) = datetime_to_tokens(name, <); quote! { if #val >= #lt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Lt(#lt))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Lt(#lt))); } } } @@ -111,7 +112,7 @@ impl ToValidationTokens for TimestampRules { let (_, gt) = datetime_to_tokens(name, >); quote! { if #val <= #gt || #val > #lte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::in_range(false, #gt, #lte, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::in_range(false, #gt, #lte, true))); } } } else { @@ -119,7 +120,7 @@ impl ToValidationTokens for TimestampRules { let (_, gt) = datetime_to_tokens(name, >); quote! { if #val >= #lte && #val < #gt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::not_in_range(false, #lte, #gt, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::not_in_range(false, #lte, #gt, true))); } } } @@ -129,7 +130,7 @@ impl ToValidationTokens for TimestampRules { let (_, gte) = datetime_to_tokens(name, >e); quote! { if #val < #gte || #val > #lte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::in_range(true, #gte, #lte, true))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::in_range(true, #gte, #lte, true))); } } } else { @@ -137,7 +138,7 @@ impl ToValidationTokens for TimestampRules { let (_, gte) = datetime_to_tokens(name, >e); quote! { if #val > #lte && #val < #gte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::not_in_range(false, #lte, #gte, false))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::not_in_range(false, #lte, #gte, false))); } } } @@ -145,7 +146,7 @@ impl ToValidationTokens for TimestampRules { let (val, lte) = datetime_to_tokens(name, <e); quote! { if #val > #lte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Lte(#lte))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Lte(#lte))); } } } @@ -153,14 +154,14 @@ impl ToValidationTokens for TimestampRules { let (val, gt) = datetime_to_tokens(name, >); quote! { if #val <= #gt { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Gt(#gt))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Gt(#gt))); } } } else if let Some(gte) = rules.gte.map(|v| v.as_datetime()) { let (val, gte) = datetime_to_tokens(name, >e); quote! { if #val < #gte { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Gte(#gte))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Gte(#gte))); } } } else if let Some(true) = rules.lt_now { @@ -171,7 +172,7 @@ impl ToValidationTokens for TimestampRules { let now = ::time::OffsetDateTime::now_utc(); let d = #d; if #val >= now || #val < now - d { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::LtNowWithin(d))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::LtNowWithin(d))); } } } else { @@ -179,7 +180,7 @@ impl ToValidationTokens for TimestampRules { quote! { let now = ::time::OffsetDateTime::now_utc(); if #val >= now { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::LtNow)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::LtNow)); } } } @@ -191,7 +192,7 @@ impl ToValidationTokens for TimestampRules { let now = ::time::OffsetDateTime::now_utc(); let d = #d; if #val <= now || #val > now + d { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::GtNowWithin(d))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::GtNowWithin(d))); } } } else { @@ -199,7 +200,7 @@ impl ToValidationTokens for TimestampRules { quote! { let now = ::time::OffsetDateTime::now_utc(); if #val <= now { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::GtNow)); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::GtNow)); } } } @@ -210,7 +211,7 @@ impl ToValidationTokens for TimestampRules { let now = ::time::OffsetDateTime::now_utc(); let d = #d; if #val < now - d || #val > now + d { - return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Within(d))); + #maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::timestamp::Error::Within(d))); } } } else { diff --git a/prost-validate-tests/benches/harness.rs b/prost-validate-tests/benches/harness.rs index e44ff76..d4a8476 100644 --- a/prost-validate-tests/benches/harness.rs +++ b/prost-validate-tests/benches/harness.rs @@ -16,18 +16,37 @@ fn reflect_validate() { fn derive_validate() { for (name, f) in CASES.iter() { let (message, failures) = f(); - match ::prost_validate::Validator::validate(message.as_ref()) { + match prost_validate::Validator::validate(message.as_ref()) { Ok(_) => assert_eq!(failures, 0, "{name}: unexpected validation success"), Err(err) => assert!(failures > 0, "{name}: unexpected validation failure: {err}"), } } } +#[cfg(feature = "derive")] +fn derive_validate_all() { + for (name, f) in CASES.iter() { + let (message, failures) = f(); + match prost_validate::Validator::validate_all(message.as_ref()) { + Ok(()) => assert_eq!(failures, 0, "{name}: unexpected validation success"), + Err(errs) => { + assert_eq!( + failures as usize, + errs.len(), + "{name}: unexpected validation failures: {errs:?}" + ); + } + } + } +} + fn criterion_benchmark(c: &mut Criterion) { #[cfg(feature = "reflect")] c.bench_function("harness reflect", |b| b.iter(reflect_validate)); #[cfg(feature = "derive")] c.bench_function("harness derive", |b| b.iter(derive_validate)); + #[cfg(feature = "derive")] + c.bench_function("harness derive_all", |b| b.iter(derive_validate_all)); } criterion_group!(benches, criterion_benchmark); diff --git a/prost-validate-tests/src/test_cases.rs b/prost-validate-tests/src/test_cases.rs index f7c42f3..604d357 100644 --- a/prost-validate-tests/src/test_cases.rs +++ b/prost-validate-tests/src/test_cases.rs @@ -34,6 +34,19 @@ macro_rules! test_cases { Ok(_) => assert_eq!(failures, 0, "unexpected validation success"), } } + + #[cfg(feature = "derive")] + #[test] + fn derive_all() { + let (message, failures) = crate::cases::CASES.get(stringify!($name)).unwrap()(); + match ValidatorDerive::validate_all(message.as_ref()) { + Ok(()) => assert_eq!(failures, 0, "unexpected validation success"), + Err(errs) => { + println!("{errs:?}"); + assert_eq!(failures as usize, errs.len(), "unexpected validation failures: {errs:?}"); + }, + } + } } )* } diff --git a/prost-validate-tests/src/test_pbjson_cases.rs b/prost-validate-tests/src/test_pbjson_cases.rs index 189c7c0..26e2ca6 100644 --- a/prost-validate-tests/src/test_pbjson_cases.rs +++ b/prost-validate-tests/src/test_pbjson_cases.rs @@ -34,6 +34,19 @@ macro_rules! test_cases { Ok(_) => assert_eq!(failures, 0, "unexpected validation success"), } } + + #[cfg(feature = "derive")] + #[test] + fn derive_all() { + let (message, failures) = crate::cases_pbjson::CASES.get(stringify!($name)).unwrap()(); + match ValidatorDerive::validate_all(message.as_ref()) { + Ok(()) => assert_eq!(failures, 0, "unexpected validation success"), + Err(errs) => { + println!("{errs:?}"); + assert_eq!(failures as usize, errs.len(), "unexpected validation failures: {errs:?}"); + }, + } + } } )* } diff --git a/prost-validate/src/lib.rs b/prost-validate/src/lib.rs index 3975491..294c624 100644 --- a/prost-validate/src/lib.rs +++ b/prost-validate/src/lib.rs @@ -25,6 +25,10 @@ pub trait Validator: Send + Sync { fn validate(&self) -> Result { Ok(()) } + + fn validate_all(&self) -> Result<(), Vec> { + Ok(()) + } } // NoopValidator is the same trait as `Validator`. @@ -34,6 +38,10 @@ pub trait NoopValidator { fn validate(&self) -> Result { Ok(()) } + + fn validate_all(&self) -> Result<(), Vec> { + Ok(()) + } } // Implement `NoopValidator` for any type. @@ -49,6 +57,10 @@ impl SafeValidator<'_, T> { pub fn validate(&self) -> Result { Validator::validate(self.0) } + + pub fn validate_all(&self) -> Result<(), Vec> { + Validator::validate_all(self.0) + } } /// Validate any value if it implements the Validator trait. @@ -62,6 +74,17 @@ macro_rules! validate { }}; } +/// Validate any value if it implements the Validator trait, returning all errors. +/// If the value does not implement the Validator trait, it will return vec![]. +#[macro_export] +macro_rules! validate_all { + ($value:tt) => {{ + use ::prost_validate::NoopValidator; + use std::ops::Deref; + ::prost_validate::SafeValidator($value.deref()).validate_all() + }}; +} + #[cfg(test)] mod tests { pub struct A {} @@ -73,6 +96,13 @@ mod tests { prost_validate::errors::Error::InvalidRules("failed".to_string()), )) } + + fn validate_all(&self) -> Result<(), Vec> { + Err(vec![prost_validate::Error::new( + "", + prost_validate::errors::Error::InvalidRules("failed".to_string()), + )]) + } } pub struct B {} @@ -81,20 +111,24 @@ mod tests { fn test_validator() { let a = &A {}; assert!(prost_validate::validate!(a).is_err()); + assert!(prost_validate::validate_all!(a).is_err()); } #[test] fn test_validator_double_ref() { let a = &&A {}; assert!(prost_validate::validate!(a).is_err()); + assert!(prost_validate::validate_all!(a).is_err()); } #[test] fn test_non_validator() { let b = &B {}; assert!(prost_validate::validate!(b).is_ok()); + assert!(prost_validate::validate_all!(b).is_ok()); } #[test] fn test_scalar() { let c = &42; assert!(prost_validate::validate!(c).is_ok()); + assert!(prost_validate::validate_all!(c).is_ok()); } }