11from typing import Any , cast
22
33from graphql import (
4+ DocumentNode ,
45 GraphQLEnumType ,
56 GraphQLField ,
67 GraphQLInterfaceType ,
2627from s2dm import log
2728from s2dm .exporters .protobuf .models import ProtoEnum , ProtoEnumValue , ProtoField , ProtoMessage , ProtoSchema , ProtoUnion
2829from s2dm .exporters .utils .directive import get_directive_arguments , has_given_directive
29- from s2dm .exporters .utils .extraction import get_all_named_types
30+ from s2dm .exporters .utils .extraction import get_all_named_types , get_root_level_types_from_query
3031from s2dm .exporters .utils .field import get_cardinality
3132from s2dm .exporters .utils .instance_tag import expand_instance_tag , get_instance_tag_object , is_instance_tag_field
3233from s2dm .exporters .utils .naming import convert_name , get_target_case_for_element
@@ -84,13 +85,15 @@ def __init__(
8485 package_name : str | None = None ,
8586 naming_config : dict [str , Any ] | None = None ,
8687 expanded_instances : bool = False ,
88+ selection_query : DocumentNode | None = None ,
8789 ):
8890 self .graphql_schema = graphql_schema
8991 self .root_type = root_type
9092 self .flatten_naming = flatten_naming
9193 self .package_name = package_name
9294 self .naming_config = naming_config
9395 self .expanded_instances = expanded_instances
96+ self .selection_query = selection_query
9497
9598 self .env = Environment (
9699 loader = PackageLoader ("s2dm.exporters.protobuf" , "templates" ),
@@ -137,17 +140,23 @@ def transform(self) -> str:
137140 proto_schema = ProtoSchema (
138141 package = self .package_name ,
139142 enums = [],
140- flatten_mode = self .flatten_naming and self . root_type is not None ,
143+ flatten_mode = self .flatten_naming ,
141144 )
142145
143- if self .flatten_naming and self . root_type :
146+ if self .flatten_naming :
144147 # In flatten mode, we need a second filtering pass to remove types that were completely flattened.
145148 # When object fields are flattened, they become prefixed fields in the parent (e.g., parent_child_field).
146149 # If no fields reference that object type directly (non-flattened), the type definition is no longer needed.
147150 # However, unions and enums cannot be flattened and must remain as separate type definitions.
148- proto_schema .flattened_fields , referenced_type_names = self ._build_flattened_fields (message_types )
151+ (
152+ proto_schema .flattened_fields ,
153+ referenced_type_names ,
154+ flattened_root_types ,
155+ ) = self ._build_flattened_fields (message_types )
149156 message_types = [
150- message_type for message_type in message_types if message_type .name in referenced_type_names
157+ message_type
158+ for message_type in message_types
159+ if message_type .name in referenced_type_names and message_type .name not in flattened_root_types
151160 ]
152161 union_types = [union_type for union_type in union_types if union_type .name in referenced_type_names ]
153162 enum_types = [enum_type for enum_type in enum_types if enum_type .name in referenced_type_names ]
@@ -156,7 +165,7 @@ def transform(self) -> str:
156165 proto_schema .unions = self ._build_unions (union_types )
157166 proto_schema .messages = self ._build_messages (message_types )
158167
159- template_name = "proto_flattened.j2" if self .flatten_naming and self . root_type else "proto_standard.j2"
168+ template_name = "proto_flattened.j2" if self .flatten_naming else "proto_standard.j2"
160169 template = self .env .get_template (template_name )
161170
162171 template_vars = self ._build_template_vars (proto_schema )
@@ -312,20 +321,47 @@ def _build_unions(self, union_types: list[GraphQLUnionType]) -> list[ProtoUnion]
312321
313322 def _build_flattened_fields (
314323 self , message_types : list [GraphQLObjectType | GraphQLInterfaceType ]
315- ) -> tuple [list [ProtoField ], set [str ]]:
316- """Build flattened fields for flatten_naming mode."""
317- root_object = None
318- for message_type in message_types :
319- if message_type .name == self .root_type :
320- root_object = message_type
321- break
324+ ) -> tuple [list [ProtoField ], set [str ], set [str ]]:
325+ """Build flattened fields for flatten_naming mode.
326+
327+ Returns:
328+ tuple: (flattened_fields, referenced_types, flattened_root_types)
329+ """
330+ type_cache = {type_def .name : type_def for type_def in message_types }
331+
332+ if self .root_type :
333+ root_object = type_cache .get (self .root_type )
334+ if not root_object :
335+ log .warning (f"Root type '{ self .root_type } ' not found, creating empty message" )
336+ return [], set (), set ()
337+
338+ fields , referenced_types , _ = self ._flatten_fields (root_object , root_object .name , message_types , 1 )
339+ return fields , referenced_types , {self .root_type }
340+
341+ root_level_type_names = get_root_level_types_from_query (self .graphql_schema , self .selection_query )
342+ if not root_level_type_names :
343+ log .warning ("No root-level types found in selection query, creating empty message" )
344+ return [], set (), set ()
345+
346+ all_fields : list [ProtoField ] = []
347+ all_referenced_types : set [str ] = set ()
348+ flattened_root_types : set [str ] = set ()
349+ field_counter = 1
322350
323- if not root_object :
324- log .warning (f"Root type '{ self .root_type } ' not found, creating empty message" )
325- return [], set ()
351+ for type_name in root_level_type_names :
352+ root_object = type_cache .get (type_name )
353+ if not root_object :
354+ log .warning (f"Root-level type '{ type_name } ' not found in message types" )
355+ continue
356+
357+ flattened_root_types .add (type_name )
358+ fields , referenced_types , field_counter = self ._flatten_fields (
359+ root_object , type_name , message_types , field_counter , type_cache
360+ )
361+ all_fields .extend (fields )
362+ all_referenced_types .update (referenced_types )
326363
327- fields , referenced_types , _ = self ._flatten_fields (root_object , root_object .name , message_types , 1 )
328- return fields , referenced_types
364+ return all_fields , all_referenced_types , flattened_root_types
329365
330366 def _create_proto_field_with_validation (
331367 self , field : GraphQLField , field_name : str , proto_type : str , field_number : int
0 commit comments