Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/benches/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn benchmark_dataset<M>(criterion: &mut Criterion, name: &str, dataset: &'static
where
M: prost::Message + Default + 'static,
{
let mut group = criterion.benchmark_group(&format!("dataset/{}", name));
let mut group = criterion.benchmark_group(format!("dataset/{}", name));

group.bench_function("merge", move |b| {
let dataset = load_dataset(dataset).unwrap();
Expand Down
65 changes: 47 additions & 18 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ use prost_types::{

use crate::ast::{Comments, Method, Service};
use crate::context::Context;
use crate::enums::EnumRepr;
use crate::ident::{strip_enum_prefix, to_snake, to_upper_camel};
use crate::syntax::Syntax;
use crate::Config;

mod c_escaping;
use c_escaping::unescape_c_escape_string;

mod syntax;
use syntax::Syntax;

/// State object for the code generation process on a single input file.
pub struct CodeGenerator<'a, 'b> {
context: &'a mut Context<'b>,
Expand Down Expand Up @@ -388,7 +387,7 @@ impl<'b> CodeGenerator<'_, 'b> {

fn append_field(&mut self, fq_message_name: &str, field: &Field) {
let type_ = field.descriptor.r#type();
let repeated = field.descriptor.label() == Label::Repeated;
let repeated = field.descriptor.label == Some(Label::Repeated as i32);
let deprecated = self.deprecated(&field.descriptor);
let optional = self.optional(&field.descriptor);
let boxed = self
Expand All @@ -415,12 +414,16 @@ impl<'b> CodeGenerator<'_, 'b> {
let type_tag = self.field_type_tag(&field.descriptor);
self.buf.push_str(&type_tag);

if type_ == Type::Bytes {
let bytes_type = self
.context
.bytes_type(fq_message_name, field.descriptor.name());
self.buf
.push_str(&format!("={:?}", bytes_type.annotation()));
match type_ {
Type::Bytes => {
let bytes_type = self
.context
.bytes_type(fq_message_name, field.descriptor.name());
self.buf
.push_str(&format!("={:?}", bytes_type.annotation()));
}
Type::Enum => self.push_enum_type_annotation(fq_message_name, &field.descriptor),
_ => {}
}

match field.descriptor.label() {
Expand Down Expand Up @@ -537,12 +540,16 @@ impl<'b> CodeGenerator<'_, 'b> {
let value_tag = self.map_value_type_tag(value);

self.buf.push_str(&format!(
"#[prost({}=\"{}, {}\", tag=\"{}\")]\n",
"#[prost({}=\"{}, {}\"",
map_type.annotation(),
key_tag,
value_tag,
field.descriptor.number()
));
if value.r#type() == Type::Enum {
self.push_enum_type_annotation(fq_message_name, &field.descriptor);
}
self.buf
.push_str(&format!(", tag=\"{}\")]\n", field.descriptor.number()));
self.append_field_attributes(fq_message_name, field.descriptor.name());
self.push_indent();
self.buf.push_str(&format!(
Expand Down Expand Up @@ -630,11 +637,12 @@ impl<'b> CodeGenerator<'_, 'b> {

self.push_indent();
let ty_tag = self.field_type_tag(&field.descriptor);
self.buf.push_str(&format!(
"#[prost({}, tag=\"{}\")]\n",
ty_tag,
field.descriptor.number()
));
self.buf.push_str(&format!("#[prost({}", ty_tag,));
if field.descriptor.r#type() == Type::Enum {
self.push_enum_type_annotation(&oneof_name, &field.descriptor);
}
self.buf
.push_str(&format!(", tag=\"{}\")]\n", field.descriptor.number()));
self.append_field_attributes(&oneof_name, field.descriptor.name());

self.push_indent();
Expand Down Expand Up @@ -930,13 +938,21 @@ impl<'b> CodeGenerator<'_, 'b> {
self.buf.push_str("}\n");
}

fn push_enum_type_annotation(&mut self, fq_message_name: &str, field: &FieldDescriptorProto) {
match self.enum_field_repr(fq_message_name, field) {
EnumRepr::Int => {}
EnumRepr::Open => self.buf.push_str(", enum_type=\"open\""),
EnumRepr::Closed => self.buf.push_str(", enum_type=\"closed\""),
}
}

fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
match field.r#type() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Type::Uint32 | Type::Fixed32 => String::from("u32"),
Type::Uint64 | Type::Fixed64 => String::from("u64"),
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
Type::Int32 | Type::Sfixed32 | Type::Sint32 => String::from("i32"),
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String => format!("{}::alloc::string::String", self.context.prost_path()),
Expand All @@ -946,6 +962,15 @@ impl<'b> CodeGenerator<'_, 'b> {
.rust_type()
.to_owned(),
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
Type::Enum => match self.enum_field_repr(fq_message_name, field) {
EnumRepr::Int => String::from("i32"),
EnumRepr::Open => format!(
"{}::OpenEnum<{}>",
self.context.prost_path(),
self.resolve_ident(field.type_name())
),
EnumRepr::Closed => self.resolve_ident(field.type_name()),
},
}
}

Expand Down Expand Up @@ -987,6 +1012,10 @@ impl<'b> CodeGenerator<'_, 'b> {
.join("::")
}

fn enum_field_repr(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> EnumRepr {
self.context.enum_field_repr(fq_message_name, field)
}

fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
match field.r#type() {
Type::Float => Cow::Borrowed("float"),
Expand Down
30 changes: 29 additions & 1 deletion prost-build/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use prost_types::{FileDescriptorProto, FileDescriptorSet};

use crate::code_generator::CodeGenerator;
use crate::context::Context;
use crate::enums::EnumFeatures;
use crate::extern_paths::ExternPaths;
use crate::message_graph::MessageGraph;
use crate::path::PathMap;
Expand All @@ -37,6 +38,7 @@ pub struct Config {
pub(crate) enum_attributes: PathMap<String>,
pub(crate) field_attributes: PathMap<String>,
pub(crate) boxed: PathMap<()>,
pub(crate) typed_enum_fields: PathMap<()>,
pub(crate) prost_types: bool,
pub(crate) strip_enum_prefix: bool,
pub(crate) out_dir: Option<PathBuf>,
Expand Down Expand Up @@ -374,6 +376,30 @@ impl Config {
self
}

/// Represent Protobuf enum types encountered in matched fields with types
/// bound to their corresponding Rust enum types, rather than the default `i32`.
///
/// Depending on the proto file syntax, the representation type can be:
/// * For closed enums (in proto2), the corresponding Rust enum type.
/// * For open enums (in proto3), the Rust enum type wrapped in [`OpenEnum`](prost::OpenEnum).
///
/// # Arguments
///
/// **`path`** - a path matching any number of fields. These fields will get the type-checked
/// enum representation.
/// For details about matching fields see [`btree_map`](#method.btree_map).
///
/// # Examples
///
/// ```rust
/// # let mut config = prost_build::Config::new();
/// config.typed_enum_fields(".my_messages");
/// ```
pub fn typed_enum_fields(&mut self, path: impl AsRef<str>) -> &mut Self {
self.typed_enum_fields.insert(path.as_ref().to_owned(), ());
self
}

/// Configures the code generator to use the provided service generator.
pub fn service_generator(&mut self, service_generator: Box<dyn ServiceGenerator>) -> &mut Self {
self.service_generator = Some(service_generator);
Expand Down Expand Up @@ -1099,9 +1125,10 @@ impl Config {
let mut packages = HashMap::new();

let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1));
let enum_features = EnumFeatures::new(requests.iter().map(|x| &x.1));
let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;
let mut context = Context::new(self, message_graph, extern_paths);
let mut context = Context::new(self, message_graph, enum_features, extern_paths);

for (request_module, request_fd) in requests {
// Only record packages that have services
Expand Down Expand Up @@ -1179,6 +1206,7 @@ impl default::Default for Config {
enum_attributes: PathMap::default(),
field_attributes: PathMap::default(),
boxed: PathMap::default(),
typed_enum_fields: PathMap::default(),
prost_types: true,
strip_enum_prefix: true,
out_dir: None,
Expand Down
22 changes: 22 additions & 0 deletions prost-build/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use prost_types::{
FieldDescriptorProto,
};

use crate::enums::{EnumFeatures, EnumRepr};
use crate::extern_paths::ExternPaths;
use crate::message_graph::MessageGraph;
use crate::{BytesType, Config, MapType, ServiceGenerator};
Expand All @@ -18,18 +19,21 @@ use crate::{BytesType, Config, MapType, ServiceGenerator};
pub struct Context<'a> {
config: &'a mut Config,
message_graph: MessageGraph,
enum_features: EnumFeatures,
extern_paths: ExternPaths,
}

impl<'a> Context<'a> {
pub fn new(
config: &'a mut Config,
message_graph: MessageGraph,
enum_features: EnumFeatures,
extern_paths: ExternPaths,
) -> Self {
Self {
config,
message_graph,
enum_features,
extern_paths,
}
}
Expand Down Expand Up @@ -168,6 +172,24 @@ impl<'a> Context<'a> {
false
}

/// Returns the enum value representation for the named message field.
pub fn enum_field_repr(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> EnumRepr {
if self
.config
.typed_enum_fields
.get_first_field(fq_message_name, field.name())
.is_some()
{
if self.enum_features.is_closed(field.type_name()) {
EnumRepr::Closed
} else {
EnumRepr::Open
}
} else {
EnumRepr::Int
}
}

/// Returns `true` if this message can automatically derive Copy trait.
pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);
Expand Down
70 changes: 70 additions & 0 deletions prost-build/src/enums.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use std::collections::HashSet;

use crate::path::fq_package_path;
use crate::syntax::Syntax;
use prost_types::{DescriptorProto, EnumDescriptorProto, FileDescriptorProto};

/// Determines how to represent a Proto enum field value in Rust.
pub enum EnumRepr {
/// Open enumeration, represented with `OpenEnum` wrapping the Rust enum type.
Open,
/// Closed enumeration, represented with the Rust enum type directly.
Closed,
/// i32 representation.
Int,
}

/// Registry of enum type features tracked across Protobuf files in the input.
pub struct EnumFeatures {
// Set of fully qualified names of enums that shall be represented as closed.
// As enums are open in proto3 (and, by default, in edition 2023 and later),
// tracking only closed enums should make a smaller set on inputs
// predominantly using proto3 and, in the future, editions.
closed_enums: HashSet<String>,
}

impl EnumFeatures {
pub(crate) fn new<'a>(files: impl Iterator<Item = &'a FileDescriptorProto>) -> Self {
let mut enum_type_map = EnumFeatures {
closed_enums: HashSet::new(),
};
for file in files {
let syntax = Syntax::from(file.syntax.as_deref());
// Until support for editions is added, we only need to look into
// proto2 files to collect closed enums.
// With edition syntax, the enum_type feature will be available to
// override per file or individual enum.
match syntax {
Syntax::Proto2 => {
let package = fq_package_path(file);
enum_type_map.add_enum_types(&package, &file.enum_type);
for msg in &file.message_type {
enum_type_map.visit_message_type(&package, msg);
}
}
Syntax::Proto3 => {} // Proto3 does not have closed enums.
}
}
enum_type_map
}

fn add_enum_types(&mut self, fq_path: &str, enum_types: &[EnumDescriptorProto]) {
for enum_type in enum_types {
let enum_path = format!("{}.{}", fq_path, enum_type.name());
self.closed_enums.insert(enum_path);
}
}

fn visit_message_type(&mut self, fq_path: &str, msg: &DescriptorProto) {
let message_path = format!("{}.{}", fq_path, msg.name());
self.add_enum_types(&message_path, &msg.enum_type);
for msg in &msg.nested_type {
self.visit_message_type(&message_path, msg);
}
}

/// Returns true if the enum with the given fully qualified path is closed.
pub fn is_closed(&self, fq_path: &str) -> bool {
self.closed_enums.contains(fq_path)
}
}
2 changes: 2 additions & 0 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,12 @@ pub(crate) use collections::{BytesType, MapType};

mod code_generator;
mod context;
mod enums;
mod extern_paths;
mod ident;
mod message_graph;
mod path;
mod syntax;

mod config;
pub use config::{
Expand Down
8 changes: 3 additions & 5 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use prost_types::{
DescriptorProto, FileDescriptorProto,
};

use crate::path::fq_package_path;

/// `MessageGraph` builds a graph of messages whose edges correspond to nesting.
/// The goal is to recognize when message types are recursively nested, so
/// that fields can be boxed when necessary.
Expand All @@ -27,11 +29,7 @@ impl MessageGraph {
};

for file in files {
let package = format!(
"{}{}",
if file.package.is_some() { "." } else { "" },
file.package.as_ref().map(String::as_str).unwrap_or("")
);
let package = fq_package_path(file);
for msg in &file.message_type {
msg_graph.add_message(&package, msg);
}
Expand Down
10 changes: 10 additions & 0 deletions prost-build/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@

use std::iter;

use prost_types::FileDescriptorProto;

/// Returns the fully-qualified package path for a given Protobuf file descriptor.
/// If the file has no package, returns an empty string.
pub fn fq_package_path(file: &FileDescriptorProto) -> String {
file.package
.as_ref()
.map_or_else(String::new, |pkg| format!(".{}", pkg))
}

/// Maps a fully-qualified Protobuf path to a value using path matchers.
#[derive(Clone, Debug, Default)]
pub(crate) struct PathMap<T> {
Expand Down
File renamed without changes.
Loading
Loading