diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index b41904761..e69de29bb 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,3 +0,0 @@ -# TODO(tokio-rs/prost#2): prost doesn't preserve unknown fields. -Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput -Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 7314c69e6..d4601c942 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -279,6 +279,11 @@ impl<'b> CodeGenerator<'_, 'b> { } self.path.pop(); } + if let Some(unknown_fields) = &self.config().include_unknown_fields { + if let Some(field_name) = unknown_fields.get_first(&fq_message_name).cloned() { + self.append_unknown_field_set(&fq_message_name, &field_name); + } + } self.path.pop(); self.path.push(8); @@ -581,6 +586,14 @@ impl<'b> CodeGenerator<'_, 'b> { )); } + fn append_unknown_field_set(&mut self, fq_message_name: &str, field_name: &str) { + self.buf.push_str("#[prost(unknown_fields)]\n"); + self.append_field_attributes(fq_message_name, field_name); + self.push_indent(); + self.buf + .push_str(&format!("pub {}: ::prost::UnknownFieldList,\n", field_name,)); + } + fn append_oneof_field( &mut self, message_name: &str, diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index 8c6a6d93e..0e5455cce 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -36,6 +36,7 @@ pub struct Config { pub(crate) message_attributes: PathMap, pub(crate) enum_attributes: PathMap, pub(crate) field_attributes: PathMap, + pub(crate) include_unknown_fields: Option>, pub(crate) boxed: PathMap<()>, pub(crate) prost_types: bool, pub(crate) strip_enum_prefix: bool, @@ -266,6 +267,36 @@ impl Config { self } + /// Preserve unknown fields for the message type. + /// + /// # Arguments + /// + /// **`paths`** - paths to specific messages, or packages which should preserve unknown + /// fields during deserialization. + /// + /// **`field_name`** - the name of the field to place unknown fields in. A field with this + /// name and type `prost::UnknownFieldList` will be added to the generated struct + /// + /// # Examples + /// + /// ```rust + /// # let mut config = prost_build::Config::new(); + /// config.include_unknown_fields(".my_messages.MyMessageType", "unknown_fields"); + /// ``` + pub fn include_unknown_fields(&mut self, path: P, field_name: A) -> &mut Self + where + P: AsRef, + A: AsRef, + { + if self.include_unknown_fields.is_none() { + self.include_unknown_fields = Some(PathMap::default()); + } + if let Some(unknown_fields) = &mut self.include_unknown_fields { + unknown_fields.insert(path.as_ref().to_string(), field_name.as_ref().to_string()); + } + self + } + /// Add additional attribute to matched messages. /// /// # Arguments @@ -1202,6 +1233,7 @@ impl default::Default for Config { message_attributes: PathMap::default(), enum_attributes: PathMap::default(), field_attributes: PathMap::default(), + include_unknown_fields: None, boxed: PathMap::default(), prost_types: true, strip_enum_prefix: true, @@ -1234,6 +1266,7 @@ impl fmt::Debug for Config { .field("bytes_type", &self.bytes_type) .field("type_attributes", &self.type_attributes) .field("field_attributes", &self.field_attributes) + .field("include_unknown_fields", &self.include_unknown_fields) .field("prost_types", &self.prost_types) .field("strip_enum_prefix", &self.strip_enum_prefix) .field("out_dir", &self.out_dir) diff --git a/prost-build/src/context.rs b/prost-build/src/context.rs index 7fecc8dfc..db75eb049 100644 --- a/prost-build/src/context.rs +++ b/prost-build/src/context.rs @@ -182,6 +182,13 @@ impl<'a> Context<'a> { /// Returns `true` if this message can automatically derive Copy trait. pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool { assert_eq!(".", &fq_message_name[..1]); + // Unknown fields can potentially include an unbounded Bytes object, which + // cannot implement Copy + if let Some(unknown_fields) = &self.config().include_unknown_fields { + if unknown_fields.get_first(fq_message_name).is_some() { + return false; + } + }; self.message_graph .get_message(fq_message_name) .unwrap() diff --git a/prost-derive/src/field/mod.rs b/prost-derive/src/field/mod.rs index d3922b1b4..e93263327 100644 --- a/prost-derive/src/field/mod.rs +++ b/prost-derive/src/field/mod.rs @@ -3,6 +3,7 @@ mod map; mod message; mod oneof; mod scalar; +mod unknown; use std::fmt; use std::slice; @@ -26,6 +27,8 @@ pub enum Field { Oneof(oneof::Field), /// A group field. Group(group::Field), + /// A set of unknown message fields. + Unknown(unknown::Field), } impl Field { @@ -48,6 +51,8 @@ impl Field { Field::Oneof(field) } else if let Some(field) = group::Field::new(&attrs, inferred_tag)? { Field::Group(field) + } else if let Some(field) = unknown::Field::new(&attrs)? { + Field::Unknown(field) } else { bail!("no type attribute"); }; @@ -86,6 +91,7 @@ impl Field { Field::Map(ref map) => vec![map.tag], Field::Oneof(ref oneof) => oneof.tags.clone(), Field::Group(ref group) => vec![group.tag], + Field::Unknown(_) => vec![], } } @@ -97,6 +103,7 @@ impl Field { Field::Map(ref map) => map.encode(prost_path, ident), Field::Oneof(ref oneof) => oneof.encode(ident), Field::Group(ref group) => group.encode(prost_path, ident), + Field::Unknown(ref unknown) => unknown.encode(ident), } } @@ -109,6 +116,7 @@ impl Field { Field::Map(ref map) => map.merge(prost_path, ident), Field::Oneof(ref oneof) => oneof.merge(ident), Field::Group(ref group) => group.merge(prost_path, ident), + Field::Unknown(ref unknown) => unknown.merge(ident), } } @@ -120,6 +128,7 @@ impl Field { Field::Message(ref msg) => msg.encoded_len(prost_path, ident), Field::Oneof(ref oneof) => oneof.encoded_len(ident), Field::Group(ref group) => group.encoded_len(prost_path, ident), + Field::Unknown(ref unknown) => unknown.encoded_len(ident), } } @@ -131,6 +140,7 @@ impl Field { Field::Map(ref map) => map.clear(ident), Field::Oneof(ref oneof) => oneof.clear(ident), Field::Group(ref group) => group.clear(ident), + Field::Unknown(ref unknown) => unknown.clear(ident), } } @@ -173,6 +183,10 @@ impl Field { _ => None, } } + + pub fn is_unknown(&self) -> bool { + matches!(self, Field::Unknown(_)) + } } #[derive(Clone, Copy, PartialEq, Eq)] diff --git a/prost-derive/src/field/unknown.rs b/prost-derive/src/field/unknown.rs new file mode 100644 index 000000000..b117f79a3 --- /dev/null +++ b/prost-derive/src/field/unknown.rs @@ -0,0 +1,66 @@ +use anyhow::{bail, Error}; +use proc_macro2::TokenStream; +use quote::quote; +use syn::Meta; + +use crate::field::{set_bool, word_attr}; + +#[derive(Clone)] +pub struct Field {} + +impl Field { + pub fn new(attrs: &[Meta]) -> Result, Error> { + let mut unknown = false; + let mut unknown_attrs = Vec::new(); + + for attr in attrs { + if word_attr("unknown_fields", attr) { + set_bool(&mut unknown, "duplicate message attribute")?; + } else { + unknown_attrs.push(attr); + } + } + + if !unknown { + return Ok(None); + } + + match unknown_attrs.len() { + 0 => (), + 1 => bail!( + "unknown attribute for unknown field set: {:?}", + unknown_attrs[0] + ), + _ => bail!( + "unknown attributes for unknown field set: {:?}", + unknown_attrs + ), + } + + Ok(Some(Field {})) + } + + pub fn encode(&self, ident: TokenStream) -> TokenStream { + quote! { + #ident.encode_raw(buf) + } + } + + pub fn merge(&self, ident: TokenStream) -> TokenStream { + quote! { + #ident.merge_field(tag, wire_type, buf, ctx) + } + } + + pub fn encoded_len(&self, ident: TokenStream) -> TokenStream { + quote! { + #ident.encoded_len() + } + } + + pub fn clear(&self, ident: TokenStream) -> TokenStream { + quote! { + #ident.clear() + } + } +} diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 2804ddfbe..195f6065d 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -84,11 +84,17 @@ fn try_message(input: TokenStream) -> Result { // We want Debug to be in declaration order let unsorted_fields = fields.clone(); - // Sort the fields by tag number so that fields will be encoded in tag order. + // Sort the fields by tag number so that fields will be encoded in tag order, + // and unknown fields are encoded last. // TODO: This encodes oneof fields in the position of their lowest tag, // regardless of the currently occupied variant, is that consequential? // See: https://protobuf.dev/programming-guides/encoding/#order - fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap()); + fields.sort_by_key(|(_, field)| { + ( + field.is_unknown(), + field.tags().into_iter().min().unwrap_or(0), + ) + }); let fields = fields; if let Some(duplicate_tag) = fields @@ -113,6 +119,9 @@ fn try_message(input: TokenStream) -> Result { .map(|(field_ident, field)| field.encode(&prost_path, quote!(self.#field_ident))); let merge = fields.iter().map(|(field_ident, field)| { + if field.is_unknown() { + return quote!(); + } let merge = field.merge(&prost_path, quote!(value)); let tags = field.tags().into_iter().map(|tag| quote!(#tag)); let tags = Itertools::intersperse(tags, quote!(|)); @@ -127,6 +136,23 @@ fn try_message(input: TokenStream) -> Result { }, } }); + let merge_fallback = match fields.iter().find(|&(_, f)| f.is_unknown()) { + Some((field_ident, field)) => { + let merge = field.merge(&prost_path, quote!(value)); + quote! { + _ => { + let mut value = &mut self.#field_ident; + #merge.map_err(|mut error| { + error.push(STRUCT_NAME, stringify!(#field_ident)); + error + }) + }, + } + } + None => quote! { + _ => #prost_path::encoding::skip_field(wire_type, tag, buf, ctx), + }, + }; let struct_name = if fields.is_empty() { quote!() @@ -192,7 +218,7 @@ fn try_message(input: TokenStream) -> Result { #struct_name match tag { #(#merge)* - _ => #prost_path::encoding::skip_field(wire_type, tag, buf, ctx), + #merge_fallback } } diff --git a/prost/src/lib.rs b/prost/src/lib.rs index 7526b8459..794941961 100644 --- a/prost/src/lib.rs +++ b/prost/src/lib.rs @@ -13,6 +13,7 @@ mod error; mod message; mod name; mod types; +mod unknown; #[doc(hidden)] pub mod encoding; @@ -23,6 +24,7 @@ pub use crate::encoding::length_delimiter::{ pub use crate::error::{DecodeError, EncodeError, UnknownEnumValue}; pub use crate::message::Message; pub use crate::name::Name; +pub use crate::unknown::{UnknownField, UnknownFieldIter, UnknownFieldList}; // See `encoding::DecodeContext` for more info. // 100 is the default recursion limit in the C++ implementation. diff --git a/prost/src/unknown.rs b/prost/src/unknown.rs new file mode 100644 index 000000000..db7817526 --- /dev/null +++ b/prost/src/unknown.rs @@ -0,0 +1,182 @@ +use alloc::collections::btree_map::{self, BTreeMap}; +use alloc::vec::Vec; +use core::slice; + +use bytes::{Buf, BufMut, Bytes}; + +use crate::encoding::{self, DecodeContext, WireType}; +use crate::{DecodeError, Message}; + +/// A set of unknown fields in a protobuf message. +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +pub struct UnknownFieldList { + fields: BTreeMap>, +} + +/// An unknown field in a protobuf message. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum UnknownField { + /// An unknown field with the `Varint` wire type. + Varint(u64), + /// An unknown field with the `SixtyFourBit` wire type. + SixtyFourBit(u64), + /// An unknown field with the `LengthDelimited` wire type. + LengthDelimited(Bytes), + /// An unknown field with the group wire type. + Group(UnknownFieldList), + /// An unknown field with the `ThirtyTwoBit` wire type. + ThirtyTwoBit(u32), +} + +/// An iterator over the fields of an [UnknownFieldList]. +#[derive(Debug)] +pub struct UnknownFieldIter<'a> { + tags_iter: btree_map::Iter<'a, u32, Vec>, + current_tag: Option<(u32, slice::Iter<'a, UnknownField>)>, +} + +impl UnknownFieldList { + /// Creates an empty [UnknownFieldList]. + pub fn new() -> Self { + Default::default() + } + + /// Gets an iterator over the fields contained in this set. + pub fn iter(&self) -> UnknownFieldIter<'_> { + UnknownFieldIter { + tags_iter: self.fields.iter(), + current_tag: None, + } + } +} + +impl<'a> Iterator for UnknownFieldIter<'a> { + type Item = (u32, &'a UnknownField); + + fn next(&mut self) -> Option { + loop { + if let Some((tag, iter)) = &mut self.current_tag { + if let Some(value) = iter.next() { + return Some((*tag, value)); + } else { + self.current_tag = None; + } + } + if let Some((tag, values)) = self.tags_iter.next() { + self.current_tag = Some((*tag, values.iter())); + } else { + return None; + } + } + } +} + +impl Message for UnknownFieldList { + fn encode_raw(&self, buf: &mut impl BufMut) + where + Self: Sized, + { + for (&tag, fields) in &self.fields { + for field in fields { + match field { + UnknownField::Varint(value) => { + encoding::encode_key(tag, WireType::Varint, buf); + encoding::encode_varint(*value, buf); + } + UnknownField::SixtyFourBit(value) => { + encoding::encode_key(tag, WireType::SixtyFourBit, buf); + buf.put_u64_le(*value); + } + UnknownField::LengthDelimited(value) => { + encoding::bytes::encode(tag, value, buf); + } + UnknownField::Group(value) => { + encoding::group::encode(tag, value, buf); + } + UnknownField::ThirtyTwoBit(value) => { + encoding::encode_key(tag, WireType::ThirtyTwoBit, buf); + buf.put_u32_le(*value); + } + } + } + } + } + + fn merge_field( + &mut self, + tag: u32, + wire_type: WireType, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> + where + Self: Sized, + { + let field = match wire_type { + WireType::Varint => { + let value = encoding::decode_varint(buf)?; + UnknownField::Varint(value) + } + WireType::SixtyFourBit => { + let mut value = [0; 8]; + if buf.remaining() < value.len() { + return Err(DecodeError::new("buffer underflow")); + } + buf.copy_to_slice(&mut value); + //https://protobuf.dev/programming-guides/encoding/ + let return_val = u64::from_le_bytes(value); + UnknownField::SixtyFourBit(return_val) + } + WireType::LengthDelimited => { + let mut value = Bytes::default(); + encoding::bytes::merge(wire_type, &mut value, buf, ctx)?; + UnknownField::LengthDelimited(value) + } + WireType::StartGroup => { + let mut value = UnknownFieldList::default(); + encoding::group::merge(tag, wire_type, &mut value, buf, ctx)?; + UnknownField::Group(value) + } + WireType::EndGroup => { + return Err(DecodeError::new("unexpected end group tag")); + } + WireType::ThirtyTwoBit => { + let mut value = [0; 4]; + if buf.remaining() < value.len() { + return Err(DecodeError::new("buffer underflow")); + } + buf.copy_to_slice(&mut value); + //https://protobuf.dev/programming-guides/encoding/ + let return_val = u32::from_le_bytes(value); + UnknownField::ThirtyTwoBit(return_val) + } + }; + + self.fields.entry(tag).or_default().push(field); + Ok(()) + } + + fn encoded_len(&self) -> usize { + let mut len = 0; + for (&tag, fields) in &self.fields { + for field in fields { + len += match field { + UnknownField::Varint(value) => { + encoding::key_len(tag) + encoding::encoded_len_varint(*value) + } + UnknownField::SixtyFourBit(_) => encoding::key_len(tag) + 8, + UnknownField::LengthDelimited(value) => { + encoding::bytes::encoded_len(tag, value) + } + UnknownField::Group(value) => encoding::group::encoded_len(tag, value), + UnknownField::ThirtyTwoBit(_) => encoding::key_len(tag) + 4, + }; + } + } + len + } + + fn clear(&mut self) { + self.fields.clear(); + } +} diff --git a/protobuf/build.rs b/protobuf/build.rs index 431e4cae7..87af107d4 100644 --- a/protobuf/build.rs +++ b/protobuf/build.rs @@ -54,6 +54,7 @@ fn main() -> Result<()> { prost_build::Config::new() .protoc_executable(&protoc_executable) .btree_map(["."]) + .include_unknown_fields(".", "unknown_fields") .compile_protos( &[ proto_dir.join("google/protobuf/test_messages_proto2.proto"), diff --git a/tests/build.rs b/tests/build.rs index 278e0d7f6..77d16d3dc 100644 --- a/tests/build.rs +++ b/tests/build.rs @@ -27,6 +27,8 @@ fn main() { // values. let mut config = prost_build::Config::new(); config.btree_map(["."]); + config.include_unknown_fields(".unknown_fields.BlankMessage", "unknown_fields"); + config.include_unknown_fields(".unknown_fields.MessageWithData", "unknown_fields"); config.type_attribute( "Foo.Custom.OneOfAttrs.Msg.field", "#[derive(PartialOrd, Ord)]", @@ -186,6 +188,10 @@ fn main() { .compile_protos(&[src.join("no_package.proto")], includes) .unwrap(); + config + .compile_protos(&[src.join("unknown_fields.proto")], includes) + .unwrap(); + let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR environment variable not set")); // Check that attempting to compile a .proto without a package declaration succeeds. diff --git a/tests/src/decode_error.rs b/tests/src/decode_error.rs index 16fb0c8d8..6af3b580c 100644 --- a/tests/src/decode_error.rs +++ b/tests/src/decode_error.rs @@ -27,7 +27,10 @@ fn test_decode_error_multiple_levels() { use protobuf::test_messages::proto3::ForeignMessage; let msg = TestAllTypesProto3 { recursive_message: Some(Box::new(TestAllTypesProto3 { - optional_foreign_message: Some(ForeignMessage { c: -1 }), + optional_foreign_message: Some(ForeignMessage { + c: -1, + ..Default::default() + }), ..Default::default() })), ..Default::default() diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 8321187a7..743b2a947 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -48,6 +48,8 @@ mod no_unused_results; mod submessage_without_package; #[cfg(test)] mod type_names; +#[cfg(test)] +mod unknown_fields; #[cfg(test)] mod boxed_field; diff --git a/tests/src/unknown_fields.proto b/tests/src/unknown_fields.proto new file mode 100644 index 000000000..9b371285e --- /dev/null +++ b/tests/src/unknown_fields.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package unknown_fields; + +message BlankMessage { +} + +message MessageWithData { + int32 a = 1; + fixed32 b = 2; + fixed64 c = 3; + string d = 4; +} diff --git a/tests/src/unknown_fields.rs b/tests/src/unknown_fields.rs new file mode 100644 index 000000000..cb1d9d3a7 --- /dev/null +++ b/tests/src/unknown_fields.rs @@ -0,0 +1,57 @@ +include!(concat!(env!("OUT_DIR"), "/unknown_fields.rs")); + +#[cfg(feature = "std")] +#[test] +fn test_iter_unknown_fields() { + use prost::bytes::Bytes; + use prost::{Message, UnknownField}; + + let v2 = MessageWithData { + a: 12345, + b: 6, + c: 7, + d: "hello".to_owned(), + unknown_fields: Default::default(), + }; + + let bytes = v2.encode_to_vec(); + let v1 = BlankMessage::decode(&*bytes).unwrap(); + + let mut fields = v1.unknown_fields.iter(); + assert_eq!(fields.next(), Some((1, &UnknownField::Varint(12345)))); + assert_eq!(fields.next(), Some((2, &UnknownField::ThirtyTwoBit(6)))); + assert_eq!(fields.next(), Some((3, &UnknownField::SixtyFourBit(7)))); + assert_eq!( + fields.next(), + Some(( + 4, + &UnknownField::LengthDelimited(Bytes::from(&b"hello"[..])) + )) + ); + assert_eq!(fields.next(), None); + + assert_eq!(v2.unknown_fields.iter().count(), 0); +} + +#[cfg(feature = "std")] +#[test] +fn test_roundtrip_unknown_fields() { + use prost::Message; + + let original = MessageWithData { + a: 12345, + b: 6, + c: 7, + d: "hello".to_owned(), + unknown_fields: Default::default(), + }; + + let original_bytes = original.encode_to_vec(); + let roundtripped_bytes = BlankMessage::decode(&*original_bytes) + .unwrap() + .encode_to_vec(); + + let roundtripped = MessageWithData::decode(&*roundtripped_bytes).unwrap(); + assert_eq!(original, roundtripped); + assert_eq!(roundtripped.unknown_fields.iter().count(), 0); +}