2828from s2dm .exporters .utils .directive import get_directive_arguments , has_given_directive
2929from s2dm .exporters .utils .extraction import get_all_named_types
3030from s2dm .exporters .utils .field import get_cardinality
31- from s2dm .exporters .utils .instance_tag import is_instance_tag_field
31+ from s2dm .exporters .utils .instance_tag import expand_instance_tag , get_instance_tag_object , is_instance_tag_field
32+ from s2dm .exporters .utils .naming import convert_name , get_target_case_for_element
3233from s2dm .exporters .utils .schema_loader import get_referenced_types
3334
3435GRAPHQL_SCALAR_TO_PROTOBUF = {
@@ -108,7 +109,7 @@ def transform(self) -> str:
108109 log .info ("Starting GraphQL to Protobuf transformation" )
109110
110111 if self .root_type :
111- referenced_types = get_referenced_types (self .graphql_schema , self .root_type , True )
112+ referenced_types = get_referenced_types (self .graphql_schema , self .root_type , not self . expanded_instances )
112113 user_defined_types : list [GraphQLNamedType ] = [
113114 referenced_type for referenced_type in referenced_types if isinstance (referenced_type , GraphQLNamedType )
114115 ]
@@ -181,30 +182,53 @@ def _build_messages(self, message_types: list[GraphQLObjectType | GraphQLInterfa
181182 """Build Pydantic models for message types."""
182183 messages = []
183184 for message_type in message_types :
184- fields = self ._build_message_fields (message_type )
185+ fields , nested_messages = self ._build_message_fields (message_type )
185186
186187 messages .append (
187188 ProtoMessage (
188189 name = message_type .name ,
189190 fields = fields ,
190191 description = message_type .description ,
191192 source = message_type .name ,
193+ nested_messages = nested_messages ,
192194 )
193195 )
194196 return messages
195197
196- def _build_message_fields (self , message_type : GraphQLObjectType | GraphQLInterfaceType ) -> list [ProtoField ]:
198+ def _build_message_fields (
199+ self , message_type : GraphQLObjectType | GraphQLInterfaceType
200+ ) -> tuple [list [ProtoField ], list [ProtoMessage ]]:
197201 """Build Pydantic models for fields in a message."""
198202 fields = []
203+ nested_messages = []
199204 field_number = 1
200205
201206 for field_name , field in message_type .fields .items ():
202207 if is_instance_tag_field (field_name ) and self .expanded_instances :
203208 continue
204209
205- proto_field_type = self ._get_field_proto_type (field .type )
206- proto_field_name = self ._escape_field_name (field_name )
207- validation_rules = self .process_directives (field , proto_field_type )
210+ field_type = field .type
211+ unwrapped_type = get_named_type (field_type )
212+
213+ proto_field_name = field_name
214+ proto_field_type = None
215+ expanded_message_name = None
216+
217+ if is_object_type (unwrapped_type ):
218+ object_type = cast (GraphQLObjectType , unwrapped_type )
219+ expanded_instances = self ._get_expanded_instances (object_type )
220+ if expanded_instances :
221+ proto_field_name , proto_field_type , nested_message = self ._handle_expanded_instance_field (
222+ object_type , message_type , expanded_instances
223+ )
224+ nested_messages .append (nested_message )
225+ expanded_message_name = proto_field_type
226+
227+ if proto_field_type is None :
228+ proto_field_type = self ._get_field_proto_type (field .type )
229+ proto_field_name = self ._escape_field_name (field_name )
230+
231+ validation_rules = None if expanded_message_name else self .process_directives (field , proto_field_type )
208232
209233 fields .append (
210234 ProtoField (
@@ -217,7 +241,7 @@ def _build_message_fields(self, message_type: GraphQLObjectType | GraphQLInterfa
217241 )
218242 field_number += 1
219243
220- return fields
244+ return fields , nested_messages
221245
222246 def _build_unions (self , union_types : list [GraphQLUnionType ]) -> list [ProtoUnion ]:
223247 """Build Pydantic models for union types."""
@@ -277,12 +301,14 @@ def _flatten_fields(
277301 prefix : str ,
278302 all_types : list [GraphQLObjectType | GraphQLInterfaceType ],
279303 field_counter : int ,
304+ type_cache : dict [str , GraphQLObjectType | GraphQLInterfaceType ] | None = None ,
280305 ) -> tuple [list [ProtoField ], set [str ], int ]:
281306 """Recursively flatten fields with prefix."""
307+ if type_cache is None :
308+ type_cache = {type_def .name : type_def for type_def in all_types }
309+
282310 fields : list [ProtoField ] = []
283311 referenced_types : set [str ] = set ()
284- if not hasattr (object_type , "fields" ):
285- return fields , referenced_types , field_counter
286312
287313 for field_name , field in object_type .fields .items ():
288314 if is_instance_tag_field (field_name ) and self .expanded_instances :
@@ -291,6 +317,21 @@ def _flatten_fields(
291317 field_type = field .type
292318 unwrapped_type = get_named_type (field_type )
293319
320+ if is_object_type (unwrapped_type ):
321+ object_type = cast (GraphQLObjectType , unwrapped_type )
322+ expanded_instances = self ._get_expanded_instances (object_type )
323+ if expanded_instances :
324+ nested_type = type_cache .get (object_type .name )
325+ if nested_type :
326+ for expanded_instance in expanded_instances :
327+ expanded_prefix = f"{ prefix } _{ field_name } _{ expanded_instance .replace ('.' , '_' )} "
328+ nested_fields , nested_referenced , field_counter = self ._flatten_fields (
329+ nested_type , expanded_prefix , all_types , field_counter , type_cache
330+ )
331+ fields .extend (nested_fields )
332+ referenced_types .update (nested_referenced )
333+ continue
334+
294335 inner = field_type .of_type if is_non_null_type (field_type ) else field_type
295336 is_list = is_list_type (inner )
296337
@@ -317,11 +358,11 @@ def _flatten_fields(
317358 continue
318359
319360 named_unwrapped_type = cast (GraphQLObjectType | GraphQLInterfaceType , unwrapped_type )
320- nested_type = self . _get_type (named_unwrapped_type .name , all_types )
361+ nested_type = type_cache . get (named_unwrapped_type .name )
321362
322363 if nested_type :
323364 nested_fields , nested_referenced , field_counter = self ._flatten_fields (
324- nested_type , flattened_name , all_types , field_counter
365+ nested_type , flattened_name , all_types , field_counter , type_cache
325366 )
326367 fields .extend (nested_fields )
327368 referenced_types .update (nested_referenced )
@@ -330,15 +371,6 @@ def _flatten_fields(
330371
331372 return fields , referenced_types , field_counter
332373
333- def _get_type (
334- self , type_name : str , all_types : list [GraphQLObjectType | GraphQLInterfaceType ]
335- ) -> GraphQLObjectType | GraphQLInterfaceType | None :
336- """Get a GraphQL type by name from a list of types."""
337- for type_def in all_types :
338- if type_def .name == type_name :
339- return type_def
340- return None
341-
342374 def _get_field_proto_type (self , field_type : GraphQLType ) -> str :
343375 """Get the Protobuf type string for a GraphQL field type."""
344376 if is_non_null_type (field_type ):
@@ -414,3 +446,64 @@ def _get_validation_type(self, proto_type: str) -> str | None:
414446 """Get the protovalidate scalar type from protobuf type."""
415447 validation_type = proto_type .replace ("repeated " , "" )
416448 return validation_type if validation_type in PROTOBUF_DATA_TYPES else None
449+
450+ def _handle_expanded_instance_field (
451+ self ,
452+ object_type : GraphQLObjectType ,
453+ message_type : GraphQLObjectType | GraphQLInterfaceType ,
454+ expanded_instances : list [str ],
455+ ) -> tuple [str , str , ProtoMessage ]:
456+ """Handle expanded instance fields, returning field name, type, and nested message."""
457+ prefixed_message_name = f"{ message_type .name } _{ object_type .name } "
458+ nested_message = self ._build_nested_message_structure (
459+ prefixed_message_name , expanded_instances , object_type .name
460+ )
461+
462+ field_name_to_use = object_type .name
463+ if self .naming_config :
464+ target_case = get_target_case_for_element ("field" , "object" , self .naming_config )
465+ if target_case :
466+ field_name_to_use = convert_name (object_type .name , target_case )
467+
468+ return (self ._escape_field_name (field_name_to_use ), nested_message .name , nested_message )
469+
470+ def _build_nested_message_structure (
471+ self ,
472+ message_name : str ,
473+ instance_paths : list [str ],
474+ target_type : str ,
475+ ) -> ProtoMessage :
476+ """Create nested message structure for expanded instance tags."""
477+ message = ProtoMessage (name = message_name , fields = [], nested_messages = [], source = None )
478+ child_paths_by_level : dict [str , list [str ]] = {}
479+ field_counter = 1
480+
481+ for instance_path in instance_paths :
482+ instance_path_parts = instance_path .split ("." )
483+ if len (instance_path_parts ) > 1 :
484+ root_level_name = instance_path_parts [0 ]
485+ remaining_path = "." .join (instance_path_parts [1 :])
486+ child_paths_by_level .setdefault (root_level_name , []).append (remaining_path )
487+ else :
488+ message .fields .append (ProtoField (name = instance_path_parts [0 ], type = target_type , number = field_counter ))
489+ field_counter += 1
490+
491+ for root_level_name , child_paths in child_paths_by_level .items ():
492+ child_message_name = f"{ message_name } _{ root_level_name } "
493+ child_message = self ._build_nested_message_structure (child_message_name , child_paths , target_type )
494+ message .nested_messages .append (child_message )
495+ message .fields .append (ProtoField (name = root_level_name , type = child_message .name , number = field_counter ))
496+ field_counter += 1
497+
498+ return message
499+
500+ def _get_expanded_instances (self , object_type : GraphQLObjectType ) -> list [str ] | None :
501+ """Get expanded instances if the type has a valid instance tag."""
502+ if not self .expanded_instances :
503+ return None
504+
505+ instance_tag_object = get_instance_tag_object (object_type , self .graphql_schema )
506+ if not instance_tag_object :
507+ return None
508+
509+ return expand_instance_tag (instance_tag_object , self .naming_config )
0 commit comments