Skip to content

Commit 3fac38e

Browse files
committed
feat(derive): add default attr for skipped fields
Prevents the skipped field from needing to implement the `Default` trait. Usage: ```rs struct Test { #[prost(int32, tag = "1")] a: i32, #[prost(skip)] c: i32, #[prost(skip, default = "create_foo")] d: Foo, } struct Foo(i32); pub fn create_foo() -> Foo { Foo(12) } ```
1 parent 26a3d25 commit 3fac38e

File tree

3 files changed

+62
-33
lines changed

3 files changed

+62
-33
lines changed

prost-derive/src/field/mod.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ impl Field {
9494
/// Returns a statement which encodes the field.
9595
pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
9696
match *self {
97-
Field::Skip(ref ignore) => ignore.encode(ident),
97+
Field::Skip(_) => TokenStream::default(),
9898
Field::Scalar(ref scalar) => scalar.encode(prost_path, ident),
9999
Field::Message(ref message) => message.encode(prost_path, ident),
100100
Field::Map(ref map) => map.encode(prost_path, ident),
@@ -107,7 +107,7 @@ impl Field {
107107
/// value into the field.
108108
pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
109109
match *self {
110-
Field::Skip(ref ignore) => ignore.merge(ident),
110+
Field::Skip(_) => TokenStream::default(),
111111
Field::Scalar(ref scalar) => scalar.merge(prost_path, ident),
112112
Field::Message(ref message) => message.merge(prost_path, ident),
113113
Field::Map(ref map) => map.merge(prost_path, ident),
@@ -119,7 +119,7 @@ impl Field {
119119
/// Returns an expression which evaluates to the encoded length of the field.
120120
pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
121121
match *self {
122-
Field::Skip(ref ignore) => ignore.encoded_len(ident),
122+
Field::Skip(_) => quote!(0),
123123
Field::Scalar(ref scalar) => scalar.encoded_len(prost_path, ident),
124124
Field::Map(ref map) => map.encoded_len(prost_path, ident),
125125
Field::Message(ref msg) => msg.encoded_len(prost_path, ident),
@@ -131,7 +131,7 @@ impl Field {
131131
/// Returns a statement which clears the field.
132132
pub fn clear(&self, ident: TokenStream) -> TokenStream {
133133
match *self {
134-
Field::Skip(ref ignore) => ignore.clear(ident),
134+
Field::Skip(ref skip) => skip.clear(ident),
135135
Field::Scalar(ref scalar) => scalar.clear(ident),
136136
Field::Message(ref message) => message.clear(ident),
137137
Field::Map(ref map) => map.clear(ident),

prost-derive/src/field/skip.rs

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,44 @@
11
use anyhow::{bail, Error};
22
use proc_macro2::TokenStream;
33
use quote::quote;
4-
use syn::{Meta};
4+
use syn::{Meta, Lit, MetaNameValue, Path, Expr, ExprLit};
55

66
use crate::field::{set_bool, word_attr};
77

88
#[derive(Clone)]
9-
pub struct Field;
9+
pub struct Field {
10+
pub default_fn: Option<Path>,
11+
}
1012

1113
impl Field {
1214
pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> {
1315
let mut skip = false;
16+
let mut default_fn = None;
1417
let mut unknown_attrs = Vec::new();
1518

1619
for attr in attrs {
1720
if word_attr("skip", attr) {
18-
set_bool(&mut skip, "duplicate ignore attribute")?;
21+
set_bool(&mut skip, "duplicate skip attribute")?;
22+
} else if let Meta::NameValue(MetaNameValue { path, value, .. }) = attr {
23+
if path.is_ident("default") {
24+
let lit_str = match value {
25+
// There has to be a better way...
26+
Expr::Lit(ExprLit { lit: Lit::Str(lit), .. }) => Some(lit),
27+
_ => None,
28+
};
29+
if let Some(lit) = lit_str {
30+
let fn_path: Path = syn::parse_str(&lit.value())
31+
.map_err(|_| anyhow::anyhow!("invalid path for default function"))?;
32+
if default_fn.is_some() {
33+
bail!("duplicate default attribute for skipped field");
34+
}
35+
default_fn = Some(fn_path);
36+
} else {
37+
bail!("default attribute value must be a string literal");
38+
}
39+
} else {
40+
unknown_attrs.push(attr);
41+
}
1942
} else {
2043
unknown_attrs.push(attr);
2144
}
@@ -27,30 +50,24 @@ impl Field {
2750

2851
if !unknown_attrs.is_empty() {
2952
bail!(
30-
"unknown attribute(s) for ignored field: #[prost({})]",
53+
"unknown attribute(s) for skipped field: #[prost({})]",
3154
quote!(#(#unknown_attrs),*)
3255
);
3356
}
3457

35-
Ok(Some(Field))
58+
Ok(Some(Field { default_fn }))
3659
}
3760

38-
/// Returns a statement which non-ops, since the field is ignored.
39-
pub fn encode(&self, _: TokenStream) -> TokenStream {
40-
quote!()
41-
}
42-
43-
/// Returns an expression which evaluates to the default value of the ignored field.
44-
pub fn merge(&self, ident: TokenStream) -> TokenStream {
45-
quote!(#ident.get_or_insert_with(::core::default::Default::default))
46-
}
47-
48-
/// Returns an expression which evaluates to 0
49-
pub fn encoded_len(&self, _: TokenStream) -> TokenStream {
50-
quote!(0)
61+
pub fn clear(&self, ident: TokenStream) -> TokenStream {
62+
let default = self.default_value();
63+
quote!( #ident = #default; )
5164
}
5265

53-
pub fn clear(&self, ident: TokenStream) -> TokenStream {
54-
quote!(#ident = ::core::default::Default::default)
66+
pub fn default_value(&self) -> TokenStream {
67+
if let Some(ref path) = self.default_fn {
68+
quote! { #path() }
69+
} else {
70+
quote! { ::core::default::Default::default() }
71+
}
5572
}
5673
}

prost-derive/src/lib.rs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,19 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
8585
let unsorted_fields = fields.clone();
8686

8787
// Filter out ignored fields
88-
fields.retain(|(_, field)| matches!(field, Field::Skip(..)));
88+
fields.retain(|(_, field)| !matches!(field, Field::Skip(..)));
8989

9090
// Sort the fields by tag number so that fields will be encoded in tag order.
9191
// TODO: This encodes oneof fields in the position of their lowest tag,
9292
// regardless of the currently occupied variant, is that consequential?
93+
let all_fields = unsorted_fields.clone();
94+
let mut active_fields = all_fields.clone();
95+
// Filter out skipped fields for encoding/decoding/length
96+
active_fields.retain(|(_, field)| !matches!(field, Field::Skip(_)));
97+
// Sort the active fields by tag number so that fields will be encoded in tag order.
9398
// See: https://protobuf.dev/programming-guides/encoding/#order
94-
fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
95-
let fields = fields;
99+
active_fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
100+
let fields = active_fields;
96101

97102
if let Some(duplicate_tag) = fields
98103
.iter()
@@ -131,29 +136,36 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
131136
}
132137
});
133138

134-
let struct_name = if fields.is_empty() {
139+
let struct_name = if all_fields.is_empty() {
135140
quote!()
136141
} else {
137142
quote!(
138143
const STRUCT_NAME: &'static str = stringify!(#ident);
139144
)
140145
};
141146

142-
let clear = fields
147+
let clear = all_fields
143148
.iter()
144149
.map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
145150

151+
// For Default implementation, use all_fields (including skipped)
146152
let default = if is_struct {
147-
let default = fields.iter().map(|(field_ident, field)| {
148-
let value = field.default(&prost_path);
153+
let default = all_fields.iter().map(|(field_ident, field)| {
154+
let value = match field {
155+
Field::Skip(skip_field) => skip_field.default_value(),
156+
_ => field.default(&prost_path),
157+
};
149158
quote!(#field_ident: #value,)
150159
});
151160
quote! {#ident {
152161
#(#default)*
153162
}}
154163
} else {
155-
let default = fields.iter().map(|(_, field)| {
156-
let value = field.default(&prost_path);
164+
let default = all_fields.iter().map(|(_, field)| {
165+
let value = match field {
166+
Field::Skip(skip_field) => skip_field.default_value(),
167+
_ => field.default(&prost_path),
168+
};
157169
quote!(#value,)
158170
});
159171
quote! {#ident (

0 commit comments

Comments
 (0)