Skip to content

Commit c7e9a00

Browse files
authored
feat: derive Eq and Hash trait for messages where possible (#1175)
Integer and bytes types can be compared using trait Eq. Some generated Rust structs can also have this property by deriving the Eq trait. Automatically derive Eq and Hash for: - messages that only have fields with integer or bytes types - messages where all field types also implement Eq and Hash - the Rust enum for one-of fields, where all fields implement Eq and Hash Generated code for Protobuf enums already derives Eq and Hash. BREAKING CHANGE: `prost-build` will automatically derive `trait Eq` and `trait Hash` for types where all field support those as well. If you manually `impl Eq` and/or `impl Hash` for generated types, then you need to remove the manual implementation. If you use `type_attribute` to `derive(Eq)` and/or `derive(Hash)`, then you need to remove those.
1 parent 87c22b4 commit c7e9a00

File tree

12 files changed

+90
-52
lines changed

12 files changed

+90
-52
lines changed

prost-build/src/code_generator.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,17 @@ impl<'b> CodeGenerator<'_, 'b> {
223223
self.append_message_attributes(&fq_message_name);
224224
self.push_indent();
225225
self.buf.push_str(&format!(
226-
"#[derive(Clone, {}PartialEq, {}::Message)]\n",
226+
"#[derive(Clone, {}PartialEq, {}{}::Message)]\n",
227227
if self.context.can_message_derive_copy(&fq_message_name) {
228228
"Copy, "
229229
} else {
230230
""
231231
},
232+
if self.context.can_message_derive_eq(&fq_message_name) {
233+
"Eq, Hash, "
234+
} else {
235+
""
236+
},
232237
self.context.prost_path()
233238
));
234239
self.append_skip_debug(&fq_message_name);
@@ -596,9 +601,18 @@ impl<'b> CodeGenerator<'_, 'b> {
596601
self.context
597602
.can_field_derive_copy(fq_message_name, &field.descriptor)
598603
});
604+
let can_oneof_derive_eq = oneof.fields.iter().all(|field| {
605+
self.context
606+
.can_field_derive_eq(fq_message_name, &field.descriptor)
607+
});
599608
self.buf.push_str(&format!(
600-
"#[derive(Clone, {}PartialEq, {}::Oneof)]\n",
609+
"#[derive(Clone, {}PartialEq, {}{}::Oneof)]\n",
601610
if can_oneof_derive_copy { "Copy, " } else { "" },
611+
if can_oneof_derive_eq {
612+
"Eq, Hash, "
613+
} else {
614+
""
615+
},
602616
self.context.prost_path()
603617
));
604618
self.append_skip_debug(fq_message_name);

prost-build/src/context.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,51 @@ impl<'a> Context<'a> {
234234
}
235235
}
236236

237+
/// Returns `true` if this message can automatically derive Eq trait.
238+
pub fn can_message_derive_eq(&self, fq_message_name: &str) -> bool {
239+
assert_eq!(".", &fq_message_name[..1]);
240+
241+
let msg = self.message_graph.get_message(fq_message_name).unwrap();
242+
msg.field
243+
.iter()
244+
.all(|field| self.can_field_derive_eq(fq_message_name, field))
245+
}
246+
247+
/// Returns `true` if the type of this field allows deriving the Eq trait.
248+
pub fn can_field_derive_eq(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool {
249+
assert_eq!(".", &fq_message_name[..1]);
250+
251+
if field.r#type() == Type::Message {
252+
if field.label() == Label::Repeated
253+
|| self
254+
.message_graph
255+
.is_nested(field.type_name(), fq_message_name)
256+
{
257+
false
258+
} else {
259+
self.can_message_derive_eq(field.type_name())
260+
}
261+
} else {
262+
matches!(
263+
field.r#type(),
264+
Type::Int32
265+
| Type::Int64
266+
| Type::Uint32
267+
| Type::Uint64
268+
| Type::Sint32
269+
| Type::Sint64
270+
| Type::Fixed32
271+
| Type::Fixed64
272+
| Type::Sfixed32
273+
| Type::Sfixed64
274+
| Type::Bool
275+
| Type::Enum
276+
| Type::String
277+
| Type::Bytes
278+
)
279+
}
280+
}
281+
237282
pub fn should_disable_comments(&self, fq_message_name: &str, field_name: Option<&str>) -> bool {
238283
if let Some(field_name) = field_name {
239284
self.config
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,29 @@
11
// This file is @generated by prost-build.
2-
#[derive(Clone, PartialEq, ::prost::Message)]
2+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
33
pub struct Container {
44
#[prost(oneof="container::Data", tags="1, 2")]
55
pub data: ::core::option::Option<container::Data>,
66
}
77
/// Nested message and enum types in `Container`.
88
pub mod container {
9-
#[derive(Clone, PartialEq, ::prost::Oneof)]
9+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
1010
pub enum Data {
1111
#[prost(message, tag="1")]
1212
Foo(::prost::alloc::boxed::Box<super::Foo>),
1313
#[prost(message, tag="2")]
1414
Bar(super::Bar),
1515
}
1616
}
17-
#[derive(Clone, PartialEq, ::prost::Message)]
17+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
1818
pub struct Foo {
1919
#[prost(string, tag="1")]
2020
pub foo: ::prost::alloc::string::String,
2121
}
22-
#[derive(Clone, PartialEq, ::prost::Message)]
22+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
2323
pub struct Bar {
2424
#[prost(message, optional, boxed, tag="1")]
2525
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
2626
}
27-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
27+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
2828
pub struct Qux {
2929
}
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,28 @@
11
// This file is @generated by prost-build.
2-
#[derive(Clone, PartialEq, ::prost::Message)]
2+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
33
pub struct Container {
44
#[prost(oneof = "container::Data", tags = "1, 2")]
55
pub data: ::core::option::Option<container::Data>,
66
}
77
/// Nested message and enum types in `Container`.
88
pub mod container {
9-
#[derive(Clone, PartialEq, ::prost::Oneof)]
9+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
1010
pub enum Data {
1111
#[prost(message, tag = "1")]
1212
Foo(::prost::alloc::boxed::Box<super::Foo>),
1313
#[prost(message, tag = "2")]
1414
Bar(super::Bar),
1515
}
1616
}
17-
#[derive(Clone, PartialEq, ::prost::Message)]
17+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
1818
pub struct Foo {
1919
#[prost(string, tag = "1")]
2020
pub foo: ::prost::alloc::string::String,
2121
}
22-
#[derive(Clone, PartialEq, ::prost::Message)]
22+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
2323
pub struct Bar {
2424
#[prost(message, optional, boxed, tag = "1")]
2525
pub qux: ::core::option::Option<::prost::alloc::boxed::Box<Qux>>,
2626
}
27-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
27+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
2828
pub struct Qux {}

prost-build/src/fixtures/helloworld/_expected_helloworld.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
// This file is @generated by prost-build.
22
#[derive(derive_builder::Builder)]
33
#[derive(custom_proto::Input)]
4-
#[derive(Clone, PartialEq, ::prost::Message)]
4+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
55
pub struct Message {
66
#[prost(string, tag="1")]
77
pub say: ::prost::alloc::string::String,
88
}
99
#[derive(derive_builder::Builder)]
1010
#[derive(custom_proto::Output)]
11-
#[derive(Clone, PartialEq, ::prost::Message)]
11+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
1212
pub struct Response {
1313
#[prost(string, tag="1")]
1414
pub say: ::prost::alloc::string::String,

prost-build/src/fixtures/helloworld/_expected_helloworld_formatted.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
// This file is @generated by prost-build.
22
#[derive(derive_builder::Builder)]
33
#[derive(custom_proto::Input)]
4-
#[derive(Clone, PartialEq, ::prost::Message)]
4+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
55
pub struct Message {
66
#[prost(string, tag = "1")]
77
pub say: ::prost::alloc::string::String,
88
}
99
#[derive(derive_builder::Builder)]
1010
#[derive(custom_proto::Output)]
11-
#[derive(Clone, PartialEq, ::prost::Message)]
11+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
1212
pub struct Response {
1313
#[prost(string, tag = "1")]
1414
pub say: ::prost::alloc::string::String,

prost-types/src/compiler.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// This file is @generated by prost-build.
22
/// The version number of protocol compiler.
33
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
4-
#[derive(Clone, PartialEq, ::prost::Message)]
4+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
55
pub struct Version {
66
#[prost(int32, optional, tag = "1")]
77
pub major: ::core::option::Option<i32>,

prost-types/src/duration.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
11
use super::*;
22

3-
#[cfg(feature = "std")]
4-
impl std::hash::Hash for Duration {
5-
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
6-
self.seconds.hash(state);
7-
self.nanos.hash(state);
8-
}
9-
}
10-
113
impl Duration {
124
/// Normalizes the duration to a canonical format.
135
///

prost-types/src/protobuf.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ pub mod descriptor_proto {
9494
/// fields or extension ranges in the same message. Reserved ranges may
9595
/// not overlap.
9696
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
97-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
97+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
9898
pub struct ReservedRange {
9999
/// Inclusive.
100100
#[prost(int32, optional, tag = "1")]
@@ -362,7 +362,7 @@ pub mod enum_descriptor_proto {
362362
/// is inclusive such that it can appropriately represent the entire int32
363363
/// domain.
364364
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
365-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
365+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
366366
pub struct EnumReservedRange {
367367
/// Inclusive.
368368
#[prost(int32, optional, tag = "1")]
@@ -990,7 +990,7 @@ pub mod uninterpreted_option {
990990
/// E.g.,{ \["foo", false\], \["bar.baz", true\], \["qux", false\] } represents
991991
/// "foo.(bar.baz).qux".
992992
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
993-
#[derive(Clone, PartialEq, ::prost::Message)]
993+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
994994
pub struct NamePart {
995995
#[prost(string, required, tag = "1")]
996996
pub name_part: ::prost::alloc::string::String,
@@ -1053,7 +1053,7 @@ pub struct SourceCodeInfo {
10531053
/// Nested message and enum types in `SourceCodeInfo`.
10541054
pub mod source_code_info {
10551055
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1056-
#[derive(Clone, PartialEq, ::prost::Message)]
1056+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
10571057
pub struct Location {
10581058
/// Identifies which part of the FileDescriptorProto was defined at this
10591059
/// location.
@@ -1158,7 +1158,7 @@ pub struct GeneratedCodeInfo {
11581158
/// Nested message and enum types in `GeneratedCodeInfo`.
11591159
pub mod generated_code_info {
11601160
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1161-
#[derive(Clone, PartialEq, ::prost::Message)]
1161+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
11621162
pub struct Annotation {
11631163
/// Identifies the element in the original source .proto file. This field
11641164
/// is formatted the same as SourceCodeInfo.Location.path.
@@ -1272,7 +1272,7 @@ pub mod generated_code_info {
12721272
/// }
12731273
/// ```
12741274
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1275-
#[derive(Clone, PartialEq, ::prost::Message)]
1275+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
12761276
pub struct Any {
12771277
/// A URL/resource name that uniquely identifies the type of the serialized
12781278
/// protocol buffer message. This string must contain at least
@@ -1310,7 +1310,7 @@ pub struct Any {
13101310
/// `SourceContext` represents information about the source of a
13111311
/// protobuf element, like the file in which it is defined.
13121312
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1313-
#[derive(Clone, PartialEq, ::prost::Message)]
1313+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
13141314
pub struct SourceContext {
13151315
/// The path-qualified name of the .proto file that contained the associated
13161316
/// protobuf element. For example: `"google/protobuf/source_context.proto"`.
@@ -1573,7 +1573,7 @@ pub struct EnumValue {
15731573
/// A protocol buffer option, which can be attached to a message, field,
15741574
/// enumeration, etc.
15751575
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1576-
#[derive(Clone, PartialEq, ::prost::Message)]
1576+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
15771577
pub struct Option {
15781578
/// The option's name. For protobuf built-in options (options defined in
15791579
/// descriptor.proto), this is the short name. For example, `"map_entry"`.
@@ -1787,7 +1787,7 @@ pub struct Method {
17871787
/// }
17881788
/// ```
17891789
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1790-
#[derive(Clone, PartialEq, ::prost::Message)]
1790+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
17911791
pub struct Mixin {
17921792
/// The fully qualified name of the interface which is included.
17931793
#[prost(string, tag = "1")]
@@ -1862,7 +1862,7 @@ pub struct Mixin {
18621862
/// be expressed in JSON format as "3.000000001s", and 3 seconds and 1
18631863
/// microsecond should be expressed in JSON format as "3.000001s".
18641864
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
1865-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
1865+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
18661866
pub struct Duration {
18671867
/// Signed seconds of the span of time. Must be from -315,576,000,000
18681868
/// to +315,576,000,000 inclusive. Note: these bounds are computed from:
@@ -2101,7 +2101,7 @@ pub struct Duration {
21012101
/// request should verify the included field paths, and return an
21022102
/// `INVALID_ARGUMENT` error if any path is unmappable.
21032103
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
2104-
#[derive(Clone, PartialEq, ::prost::Message)]
2104+
#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
21052105
pub struct FieldMask {
21062106
/// The set of field mask paths.
21072107
#[prost(string, repeated, tag = "1")]
@@ -2303,7 +2303,7 @@ impl NullValue {
23032303
/// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use
23042304
/// the Joda Time's [`ISODateTimeFormat.dateTime()`](<http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime%2D%2D>) to obtain a formatter capable of generating timestamps in this format.
23052305
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
2306-
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
2306+
#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)]
23072307
pub struct Timestamp {
23082308
/// Represents seconds of UTC time since Unix epoch
23092309
/// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to

prost-types/src/timestamp.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -123,19 +123,6 @@ impl Name for Timestamp {
123123
}
124124
}
125125

126-
/// Implements the unstable/naive version of `Eq`: a basic equality check on the internal fields of the `Timestamp`.
127-
/// This implies that `normalized_ts != non_normalized_ts` even if `normalized_ts == non_normalized_ts.normalized()`.
128-
#[cfg(feature = "std")]
129-
impl Eq for Timestamp {}
130-
131-
#[cfg(feature = "std")]
132-
impl std::hash::Hash for Timestamp {
133-
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
134-
self.seconds.hash(state);
135-
self.nanos.hash(state);
136-
}
137-
}
138-
139126
#[cfg(feature = "std")]
140127
impl From<std::time::SystemTime> for Timestamp {
141128
fn from(system_time: std::time::SystemTime) -> Timestamp {

0 commit comments

Comments
 (0)