1- use std:: collections:: HashMap ;
1+ use std:: collections:: { HashMap , HashSet } ;
22
3- use petgraph:: algo:: has_path_connecting;
43use petgraph:: graph:: NodeIndex ;
5- use petgraph:: Graph ;
4+ use petgraph:: visit:: { EdgeRef , VisitMap } ;
5+ use petgraph:: { Direction , Graph } ;
66
77use prost_types:: {
88 field_descriptor_proto:: { Label , Type } ,
@@ -15,9 +15,13 @@ use crate::path::PathMap;
1515/// The goal is to recognize when message types are recursively nested, so
1616/// that fields can be boxed when necessary.
1717pub struct MessageGraph {
18+ /// Map<fq type name, graph node index>
1819 index : HashMap < String , NodeIndex > ,
19- graph : Graph < String , ( ) > ,
20+ /// Graph with fq type name as node, field name as edge
21+ graph : Graph < String , String > ,
22+ /// Map<fq type name, DescriptorProto>
2023 messages : HashMap < String , DescriptorProto > ,
24+ /// Manually boxed fields
2125 boxed : PathMap < ( ) > ,
2226}
2327
@@ -71,7 +75,8 @@ impl MessageGraph {
7175 for field in & msg. field {
7276 if field. r#type ( ) == Type :: Message && field. label ( ) != Label :: Repeated {
7377 let field_index = self . get_or_insert_index ( field. type_name . clone ( ) . unwrap ( ) ) ;
74- self . graph . add_edge ( msg_index, field_index, ( ) ) ;
78+ self . graph
79+ . add_edge ( msg_index, field_index, field. name . clone ( ) . unwrap ( ) ) ;
7580 }
7681 }
7782 self . messages . insert ( msg_name. clone ( ) , msg. clone ( ) ) ;
@@ -86,8 +91,9 @@ impl MessageGraph {
8691 self . messages . get ( message)
8792 }
8893
89- /// Returns true if message type `inner` is nested in message type `outer`.
90- pub fn is_nested ( & self , outer : & str , inner : & str ) -> bool {
94+ /// Returns true if message type `inner` is nested in message type `outer`,
95+ /// and no field edge in the chain of dependencies is manually boxed.
96+ pub fn is_directly_nested ( & self , outer : & str , inner : & str ) -> bool {
9197 let outer = match self . index . get ( outer) {
9298 Some ( outer) => * outer,
9399 None => return false ,
@@ -97,7 +103,12 @@ impl MessageGraph {
97103 None => return false ,
98104 } ;
99105
100- has_path_connecting ( & self . graph , outer, inner, None )
106+ // Check if `inner` is nested in `outer` and ensure that all edge fields are not boxed manually.
107+ is_connected_with_edge_filter ( & self . graph , outer, inner, |node, field_name| {
108+ self . boxed
109+ . get_first_field ( & self . graph [ node] , field_name)
110+ . is_none ( )
111+ } )
101112 }
102113
103114 /// Returns `true` if this message can automatically derive Copy trait.
@@ -123,11 +134,11 @@ impl MessageGraph {
123134 false
124135 } else if field. r#type ( ) == Type :: Message {
125136 // nested and boxed messages cannot derive Copy
126- if self . is_nested ( field . type_name ( ) , fq_message_name )
127- || self
128- . boxed
129- . get_first_field ( fq_message_name , field . name ( ) )
130- . is_some ( )
137+ if self
138+ . boxed
139+ . get_first_field ( fq_message_name , field . name ( ) )
140+ . is_some ( )
141+ || self . is_directly_nested ( field . type_name ( ) , fq_message_name )
131142 {
132143 false
133144 } else {
@@ -154,3 +165,123 @@ impl MessageGraph {
154165 }
155166 }
156167}
168+
169+ /// Check two nodes is connected with edge filter
170+ fn is_connected_with_edge_filter < F , N , E > (
171+ graph : & Graph < N , E > ,
172+ start : NodeIndex ,
173+ end : NodeIndex ,
174+ mut is_good_edge : F ,
175+ ) -> bool
176+ where
177+ F : FnMut ( NodeIndex , & E ) -> bool ,
178+ {
179+ fn visitor < F , N , E > (
180+ graph : & Graph < N , E > ,
181+ start : NodeIndex ,
182+ end : NodeIndex ,
183+ is_good_edge : & mut F ,
184+ visited : & mut HashSet < NodeIndex > ,
185+ ) -> bool
186+ where
187+ F : FnMut ( NodeIndex , & E ) -> bool ,
188+ {
189+ if start == end {
190+ return true ;
191+ }
192+ visited. visit ( start) ;
193+ for edge in graph. edges_directed ( start, Direction :: Outgoing ) {
194+ // if the edge doesn't pass the filter, skip it
195+ if !is_good_edge ( start, edge. weight ( ) ) {
196+ continue ;
197+ }
198+ let target = edge. target ( ) ;
199+ if visited. is_visited ( & target) {
200+ continue ;
201+ }
202+ if visitor ( graph, target, end, is_good_edge, visited) {
203+ return true ;
204+ }
205+ }
206+ false
207+ }
208+ let mut visited = HashSet :: new ( ) ;
209+ visitor ( graph, start, end, & mut is_good_edge, & mut visited)
210+ }
211+
212+ #[ cfg( test) ]
213+ mod tests {
214+ use super :: * ;
215+
216+ #[ test]
217+ fn test_connected ( ) {
218+ let mut graph = Graph :: new ( ) ;
219+ let n1 = graph. add_node ( 1 ) ;
220+ let n2 = graph. add_node ( 2 ) ;
221+ let n3 = graph. add_node ( 3 ) ;
222+ let n4 = graph. add_node ( 4 ) ;
223+ let n5 = graph. add_node ( 5 ) ;
224+ let n6 = graph. add_node ( 6 ) ;
225+ let n7 = graph. add_node ( 7 ) ;
226+ let n8 = graph. add_node ( 8 ) ;
227+ graph. add_edge ( n1, n2, 1. ) ;
228+ graph. add_edge ( n2, n3, 2. ) ;
229+ graph. add_edge ( n3, n4, 3. ) ;
230+ graph. add_edge ( n4, n5, 4. ) ;
231+ graph. add_edge ( n5, n6, 5. ) ;
232+ graph. add_edge ( n6, n7, 6. ) ;
233+ graph. add_edge ( n7, n8, 7. ) ;
234+ graph. add_edge ( n8, n1, 8. ) ;
235+ assert_eq ! (
236+ is_connected_with_edge_filter( & graph, n2, n1, |_, edge| {
237+ dbg!( edge) ;
238+ true
239+ } ) ,
240+ true
241+ ) ;
242+ assert_eq ! (
243+ is_connected_with_edge_filter( & graph, n2, n1, |_, edge| {
244+ dbg!( edge) ;
245+ edge < & 8.5
246+ } ) ,
247+ true ,
248+ ) ;
249+ assert_eq ! (
250+ is_connected_with_edge_filter( & graph, n2, n1, |_, edge| {
251+ dbg!( edge) ;
252+ edge < & 7.5
253+ } ) ,
254+ false ,
255+ ) ;
256+ }
257+
258+ #[ test]
259+ fn test_connected_multi_circle ( ) {
260+ let mut graph = Graph :: new ( ) ;
261+ let n0 = graph. add_node ( 0 ) ;
262+ let n1 = graph. add_node ( 1 ) ;
263+ let n2 = graph. add_node ( 2 ) ;
264+ let n3 = graph. add_node ( 3 ) ;
265+ let n4 = graph. add_node ( 4 ) ;
266+ graph. add_edge ( n0, n1, 0. ) ;
267+ graph. add_edge ( n1, n2, 1. ) ;
268+ graph. add_edge ( n2, n3, 2. ) ;
269+ graph. add_edge ( n3, n0, 3. ) ;
270+ graph. add_edge ( n1, n4, 1.5 ) ;
271+ graph. add_edge ( n4, n0, 2.5 ) ;
272+ assert_eq ! (
273+ is_connected_with_edge_filter( & graph, n1, n0, |_, edge| {
274+ dbg!( edge) ;
275+ edge < & 2.8
276+ } ) ,
277+ true ,
278+ ) ;
279+ assert_eq ! (
280+ is_connected_with_edge_filter( & graph, n1, n0, |_, edge| {
281+ dbg!( edge) ;
282+ edge < & 2.1
283+ } ) ,
284+ false ,
285+ ) ;
286+ }
287+ }
0 commit comments