Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions prost-validate-derive-core/src/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ impl ToValidationTokens for AnyRules {
fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream {
let field = &ctx.name;
let rules = prost_validate_types::AnyRules::from(self.to_owned());
let maybe_return = ctx.maybe_return();
let r#in = rules.r#in.is_empty().not().then(|| {
let v = rules.r#in;
quote! {
let values = vec![#(#v),*];
if !values.contains(&#name.type_url.as_str()) {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::any::Error::In(values.iter().map(|v|v.to_string()).collect())));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::any::Error::In(values.iter().map(|v|v.to_string()).collect())));
}
}
});
Expand All @@ -31,7 +32,7 @@ impl ToValidationTokens for AnyRules {
quote! {
let values = vec![#(#v),*];
if values.contains(&#name.type_url.as_str()) {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::any::Error::NotIn(values.iter().map(|v|v.to_string()).collect())));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::any::Error::NotIn(values.iter().map(|v|v.to_string()).collect())));
}
}
});
Expand Down
3 changes: 2 additions & 1 deletion prost-validate-derive-core/src/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ pub struct BoolRules {
impl ToValidationTokens for BoolRules {
fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream {
let field = &ctx.name;
let maybe_return = ctx.maybe_return();
let r#const = self.r#const.map(|v| {
quote! {
if *#name != #v {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::r#bool::Error::Const(#v)));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::r#bool::Error::Const(#v)));
}
}
});
Expand Down
29 changes: 15 additions & 14 deletions prost-validate-derive-core/src/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,35 +27,36 @@ impl ToValidationTokens for BytesRules {
fn to_validation_tokens(&self, ctx: &Context, name: &Ident) -> TokenStream {
let field = &ctx.name;
let rules = prost_validate_types::BytesRules::from(self.to_owned());
let maybe_return = ctx.maybe_return();
let r#const = rules.r#const.map(|v| {
let v = LitByteStr::new(v.as_slice(), Span::call_site());
quote! {
if !#name.iter().eq(#v.iter()) {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Const(#v.to_vec())));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Const(#v.to_vec())));
}
}
});
let len = rules.len.map(|v| {
let v = v as usize;
quote! {
if #name.len() != #v {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Len(#v)));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Len(#v)));
}
}
});
let min_len = rules.min_len.map(|v| {
let v = v as usize;
quote! {
if #name.len() < #v {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MinLen(#v)));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MinLen(#v)));
}
}
});
let max_len = rules.max_len.map(|v| {
let v = v as usize;
quote! {
if #name.len() > #v {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MaxLen(#v)));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::MaxLen(#v)));
}
}
});
Expand All @@ -65,10 +66,10 @@ impl ToValidationTokens for BytesRules {
}
quote! {
match ::regex::bytes::Regex::new(#v) {
Err(e) => return Err(::prost_validate::Error::new(#field, format!("Invalid regex pattern: {e}"))),
Err(e) => #maybe_return(::prost_validate::Error::new(#field, format!("Invalid regex pattern: {e}"))),
Ok(regex) => {
if !regex.is_match(#name.iter().as_slice()) {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Pattern(#v.to_string())));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Pattern(#v.to_string())));
}
}
}
Expand All @@ -78,23 +79,23 @@ impl ToValidationTokens for BytesRules {
let v = LitByteStr::new(v.as_slice(), Span::call_site());
quote! {
if !#name.starts_with(#v) {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Prefix(#v.to_vec())));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Prefix(#v.to_vec())));
}
}
});
let suffix = rules.suffix.map(|v| {
let v = LitByteStr::new(v.as_slice(), Span::call_site());
quote! {
if !#name.ends_with(#v) {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Suffix(#v.to_vec())));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Suffix(#v.to_vec())));
}
}
});
let contains = rules.contains.map(|v| {
let v = LitByteStr::new(v.as_slice(), Span::call_site());
quote! {
if !::prost_validate::ValidateBytesExt::contains(&#name, #v.as_slice()) {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Contains(#v.to_vec())));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Contains(#v.to_vec())));
}
}
});
Expand All @@ -107,7 +108,7 @@ impl ToValidationTokens for BytesRules {
quote! {
let values = [#(#v.to_vec()),*];
if !values.contains(&#name) {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::In(values.iter().map(|v| v.to_vec()).collect())));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::In(values.iter().map(|v| v.to_vec()).collect())));
}
}
});
Expand All @@ -120,29 +121,29 @@ impl ToValidationTokens for BytesRules {
quote! {
let values = [#(#v.to_vec()),*];
if values.contains(&#name) {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::NotIn(values.iter().map(|v| v.to_vec()).collect())));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::NotIn(values.iter().map(|v| v.to_vec()).collect())));
}
}
});
let well_known = rules.well_known.map(|v| match v {
bytes_rules::WellKnown::Ip(true) => {
quote! {
if #name.len() != 4 && #name.len() != 16 {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ip));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ip));
}
}
}
bytes_rules::WellKnown::Ipv4(true) => {
quote! {
if #name.len() != 4 {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv4));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv4));
}
}
}
bytes_rules::WellKnown::Ipv6(true) => {
quote! {
if #name.len() != 16 {
return Err(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv6));
#maybe_return(::prost_validate::Error::new(#field, ::prost_validate::errors::bytes::Error::Ipv6));
}
}
}
Expand Down
54 changes: 35 additions & 19 deletions prost-validate-derive-core/src/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,11 @@ pub fn derive_with_module(
let opts = Opts::from_derive_input(&input).expect("Wrong validate options");
let DeriveInput { ident, .. } = input;

let implementation = match opts.data {
Data::Enum(e) => e
.iter()
.map(|v| Field {
module: module.clone().map(|v| v.to_string()),
..v.clone()
})
.map(|v| v.to_token_stream())
.collect::<proc_macro2::TokenStream>(),
Data::Struct(s) => s
.fields
.iter()
.map(|v| Field {
module: module.clone().map(|v| v.to_string()),
..v.clone()
})
.map(|field| field.into_token_stream())
.collect::<proc_macro2::TokenStream>(),
};
let implementation = body_tokens(&opts, module.clone(), false);
let impl_multierrs = body_tokens(&opts, module.clone(), true);

let allow = quote! {
#[allow(clippy::collapsible_if)]
#[allow(clippy::regex_creation_in_loops)]
#[allow(irrefutable_let_patterns)]
#[allow(unused_variables)]
Expand All @@ -62,6 +46,13 @@ pub fn derive_with_module(
#implementation
Ok(())
}

#allow
fn validate_all(&self) -> ::core::result::Result<(), Vec<::prost_validate::Error>> {
let mut errs = vec![];
#impl_multierrs
if errs.is_empty() { Ok(()) } else { Err(errs) }
}
}
}
} else {
Expand All @@ -71,6 +62,30 @@ pub fn derive_with_module(
}
}

fn body_tokens(opts: &Opts, module: Option<TokenStream>, multierrs: bool) -> TokenStream {
match &opts.data {
Data::Enum(e) => e
.iter()
.map(|v| Field {
module: module.clone().map(|v| v.to_string()),
multierrs,
..v.clone()
})
.map(|v| v.to_token_stream())
.collect(),
Data::Struct(s) => s
.fields
.iter()
.map(|v| Field {
module: module.clone().map(|v| v.to_string()),
multierrs,
..v.clone()
})
.map(|field| field.into_token_stream())
.collect(),
}
}

#[cfg(test)]
pub fn derive_2(input: proc_macro2::TokenStream) -> String {
let output = derive(input);
Expand Down Expand Up @@ -107,6 +122,7 @@ mod tests {
use super::*;

#[test]
// cargo test --all-targets --all-features --locked --frozen --offline --no-fail-fast --manifest-path prost-validate-derive-core/Cargo.toml derive::tests -- --nocapture
fn tests() {
let input = quote! {
pub struct WrapperRequiredFloat {
Expand Down
Loading