Skip to content

Commit f8da8b8

Browse files
committed
feat(protobuf): flatten all root-level types from selection query
Signed-off-by: Ahmed Mohamed <[email protected]>
1 parent c338f12 commit f8da8b8

File tree

5 files changed

+169
-29
lines changed

5 files changed

+169
-29
lines changed

src/s2dm/cli.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -574,7 +574,7 @@ def jsonschema(
574574
"-f",
575575
is_flag=True,
576576
default=False,
577-
help="Flatten nested field names. Requires --root-type to be set.",
577+
help="Flatten nested field names.",
578578
)
579579
@click.option(
580580
"--package-name",
@@ -598,12 +598,13 @@ def protobuf(
598598
naming_config = ctx.obj.get("naming_config")
599599
graphql_schema = load_schema_with_naming(schemas, naming_config)
600600

601+
query_document = None
601602
if selection_query:
602603
query_document = parse(selection_query.read_text())
603604
graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document)
604605

605606
result = translate_to_protobuf(
606-
graphql_schema, root_type, flatten_naming, package_name, naming_config, expanded_instances
607+
graphql_schema, root_type, flatten_naming, package_name, naming_config, expanded_instances, query_document
607608
)
608609
_ = output.write_text(result)
609610

src/s2dm/exporters/protobuf/protobuf.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
from typing import Any
22

3-
from graphql import GraphQLSchema
3+
from graphql import DocumentNode, GraphQLSchema
44

55
from s2dm import log
66

@@ -14,6 +14,7 @@ def transform(
1414
package_name: str | None = None,
1515
naming_config: dict[str, Any] | None = None,
1616
expanded_instances: bool = False,
17+
selection_query: DocumentNode | None = None,
1718
) -> str:
1819
"""
1920
Transform a GraphQL schema object to Protocol Buffers format.
@@ -25,6 +26,7 @@ def transform(
2526
package_name: Optional package name for the .proto file
2627
naming_config: Optional naming configuration
2728
expanded_instances: If True, expand instance tags into nested structures
29+
selection_query: Optional selection query document to determine root-level types
2830
2931
Returns:
3032
str: Protocol Buffers representation as a string
@@ -36,12 +38,8 @@ def transform(
3638
raise ValueError(f"Root type '{root_type}' not found in schema")
3739
log.info(f"Using root type: {root_type}")
3840

39-
if flatten_naming and not root_type:
40-
log.warning("Flatten naming mode requires a root type, falling back to standard mode")
41-
flatten_naming = False
42-
4341
transformer = ProtobufTransformer(
44-
graphql_schema, root_type, flatten_naming, package_name, naming_config, expanded_instances
42+
graphql_schema, root_type, flatten_naming, package_name, naming_config, expanded_instances, selection_query
4543
)
4644
proto_content = transformer.transform()
4745

@@ -57,6 +55,7 @@ def translate_to_protobuf(
5755
package_name: str | None = None,
5856
naming_config: dict[str, Any] | None = None,
5957
expanded_instances: bool = False,
58+
selection_query: DocumentNode | None = None,
6059
) -> str:
6160
"""
6261
Translate a GraphQL schema to Protocol Buffers format.
@@ -68,8 +67,11 @@ def translate_to_protobuf(
6867
package_name: Optional package name for the .proto file
6968
naming_config: Optional naming configuration
7069
expanded_instances: If True, expand instance tags into nested structures
70+
selection_query: Optional selection query document to determine root-level types
7171
7272
Returns:
7373
str: Protocol Buffers (.proto) representation as a string
7474
"""
75-
return transform(schema, root_type, flatten_naming, package_name, naming_config, expanded_instances)
75+
return transform(
76+
schema, root_type, flatten_naming, package_name, naming_config, expanded_instances, selection_query
77+
)

src/s2dm/exporters/protobuf/transformer.py

Lines changed: 54 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from typing import Any, cast
22

33
from graphql import (
4+
DocumentNode,
45
GraphQLEnumType,
56
GraphQLField,
67
GraphQLInterfaceType,
@@ -26,7 +27,7 @@
2627
from s2dm import log
2728
from s2dm.exporters.protobuf.models import ProtoEnum, ProtoEnumValue, ProtoField, ProtoMessage, ProtoSchema, ProtoUnion
2829
from s2dm.exporters.utils.directive import get_directive_arguments, has_given_directive
29-
from s2dm.exporters.utils.extraction import get_all_named_types
30+
from s2dm.exporters.utils.extraction import get_all_named_types, get_root_level_types_from_query
3031
from s2dm.exporters.utils.field import get_cardinality
3132
from s2dm.exporters.utils.instance_tag import expand_instance_tag, get_instance_tag_object, is_instance_tag_field
3233
from s2dm.exporters.utils.naming import convert_name, get_target_case_for_element
@@ -84,13 +85,15 @@ def __init__(
8485
package_name: str | None = None,
8586
naming_config: dict[str, Any] | None = None,
8687
expanded_instances: bool = False,
88+
selection_query: DocumentNode | None = None,
8789
):
8890
self.graphql_schema = graphql_schema
8991
self.root_type = root_type
9092
self.flatten_naming = flatten_naming
9193
self.package_name = package_name
9294
self.naming_config = naming_config
9395
self.expanded_instances = expanded_instances
96+
self.selection_query = selection_query
9497

9598
self.env = Environment(
9699
loader=PackageLoader("s2dm.exporters.protobuf", "templates"),
@@ -137,17 +140,23 @@ def transform(self) -> str:
137140
proto_schema = ProtoSchema(
138141
package=self.package_name,
139142
enums=[],
140-
flatten_mode=self.flatten_naming and self.root_type is not None,
143+
flatten_mode=self.flatten_naming,
141144
)
142145

143-
if self.flatten_naming and self.root_type:
146+
if self.flatten_naming:
144147
# In flatten mode, we need a second filtering pass to remove types that were completely flattened.
145148
# When object fields are flattened, they become prefixed fields in the parent (e.g., parent_child_field).
146149
# If no fields reference that object type directly (non-flattened), the type definition is no longer needed.
147150
# However, unions and enums cannot be flattened and must remain as separate type definitions.
148-
proto_schema.flattened_fields, referenced_type_names = self._build_flattened_fields(message_types)
151+
(
152+
proto_schema.flattened_fields,
153+
referenced_type_names,
154+
flattened_root_types,
155+
) = self._build_flattened_fields(message_types)
149156
message_types = [
150-
message_type for message_type in message_types if message_type.name in referenced_type_names
157+
message_type
158+
for message_type in message_types
159+
if message_type.name in referenced_type_names and message_type.name not in flattened_root_types
151160
]
152161
union_types = [union_type for union_type in union_types if union_type.name in referenced_type_names]
153162
enum_types = [enum_type for enum_type in enum_types if enum_type.name in referenced_type_names]
@@ -156,7 +165,7 @@ def transform(self) -> str:
156165
proto_schema.unions = self._build_unions(union_types)
157166
proto_schema.messages = self._build_messages(message_types)
158167

159-
template_name = "proto_flattened.j2" if self.flatten_naming and self.root_type else "proto_standard.j2"
168+
template_name = "proto_flattened.j2" if self.flatten_naming else "proto_standard.j2"
160169
template = self.env.get_template(template_name)
161170

162171
template_vars = self._build_template_vars(proto_schema)
@@ -312,20 +321,47 @@ def _build_unions(self, union_types: list[GraphQLUnionType]) -> list[ProtoUnion]
312321

313322
def _build_flattened_fields(
314323
self, message_types: list[GraphQLObjectType | GraphQLInterfaceType]
315-
) -> tuple[list[ProtoField], set[str]]:
316-
"""Build flattened fields for flatten_naming mode."""
317-
root_object = None
318-
for message_type in message_types:
319-
if message_type.name == self.root_type:
320-
root_object = message_type
321-
break
324+
) -> tuple[list[ProtoField], set[str], set[str]]:
325+
"""Build flattened fields for flatten_naming mode.
326+
327+
Returns:
328+
tuple: (flattened_fields, referenced_types, flattened_root_types)
329+
"""
330+
type_cache = {type_def.name: type_def for type_def in message_types}
331+
332+
if self.root_type:
333+
root_object = type_cache.get(self.root_type)
334+
if not root_object:
335+
log.warning(f"Root type '{self.root_type}' not found, creating empty message")
336+
return [], set(), set()
337+
338+
fields, referenced_types, _ = self._flatten_fields(root_object, root_object.name, message_types, 1)
339+
return fields, referenced_types, {self.root_type}
340+
341+
root_level_type_names = get_root_level_types_from_query(self.graphql_schema, self.selection_query)
342+
if not root_level_type_names:
343+
log.warning("No root-level types found in selection query, creating empty message")
344+
return [], set(), set()
345+
346+
all_fields: list[ProtoField] = []
347+
all_referenced_types: set[str] = set()
348+
flattened_root_types: set[str] = set()
349+
field_counter = 1
322350

323-
if not root_object:
324-
log.warning(f"Root type '{self.root_type}' not found, creating empty message")
325-
return [], set()
351+
for type_name in root_level_type_names:
352+
root_object = type_cache.get(type_name)
353+
if not root_object:
354+
log.warning(f"Root-level type '{type_name}' not found in message types")
355+
continue
356+
357+
flattened_root_types.add(type_name)
358+
fields, referenced_types, field_counter = self._flatten_fields(
359+
root_object, type_name, message_types, field_counter, type_cache
360+
)
361+
all_fields.extend(fields)
362+
all_referenced_types.update(referenced_types)
326363

327-
fields, referenced_types, _ = self._flatten_fields(root_object, root_object.name, message_types, 1)
328-
return fields, referenced_types
364+
return all_fields, all_referenced_types, flattened_root_types
329365

330366
def _create_proto_field_with_validation(
331367
self, field: GraphQLField, field_name: str, proto_type: str, field_number: int

src/s2dm/exporters/utils/extraction.py

Lines changed: 49 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,15 @@
1-
from graphql import GraphQLNamedType, GraphQLObjectType, GraphQLSchema
1+
from graphql import (
2+
DocumentNode,
3+
FieldNode,
4+
GraphQLNamedType,
5+
GraphQLObjectType,
6+
GraphQLSchema,
7+
OperationDefinitionNode,
8+
OperationType,
9+
get_named_type,
10+
is_interface_type,
11+
is_object_type,
12+
)
213

314
from s2dm.exporters.utils.directive import has_given_directive
415
from s2dm.exporters.utils.graphql_type import is_introspection_type
@@ -34,3 +45,40 @@ def get_all_object_types(
3445
def get_all_objects_with_directive(objects: list[GraphQLObjectType], directive_name: str) -> list[GraphQLObjectType]:
3546
# TODO: Extend this function to return all objects that have any directive if directive_name is None
3647
return [o for o in objects if has_given_directive(o, directive_name)]
48+
49+
50+
def get_root_level_types_from_query(schema: GraphQLSchema, selection_query: DocumentNode | None) -> list[str]:
51+
"""Extract root-level type names from the selection query.
52+
53+
Args:
54+
schema: The GraphQL schema
55+
selection_query: The selection query document
56+
57+
Returns:
58+
List of type names that are selected at the root level of the query
59+
"""
60+
query_type = schema.query_type
61+
if not selection_query or not query_type:
62+
return []
63+
64+
root_type_names: list[str] = []
65+
66+
for definition in selection_query.definitions:
67+
if not isinstance(definition, OperationDefinitionNode) or definition.operation != OperationType.QUERY:
68+
continue
69+
70+
for selection in definition.selection_set.selections:
71+
if not isinstance(selection, FieldNode):
72+
continue
73+
74+
field_name = selection.name.value
75+
if field_name not in query_type.fields:
76+
continue
77+
78+
field = query_type.fields[field_name]
79+
field_type = get_named_type(field.type)
80+
81+
if is_object_type(field_type) or is_interface_type(field_type):
82+
root_type_names.append(field_type.name)
83+
84+
return root_type_names

tests/test_protobuf.py

Lines changed: 54 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
from pathlib import Path
33

44
import pytest
5-
from graphql import build_schema
5+
from graphql import build_schema, parse
66

77
from s2dm.exporters.protobuf import translate_to_protobuf
88
from s2dm.exporters.utils.schema import load_schema_with_naming
@@ -1040,3 +1040,56 @@ def test_complete_proto_file(self) -> None:
10401040
result,
10411041
re.DOTALL,
10421042
), "Transmission message with enum field and validation"
1043+
1044+
def test_flatten_naming_multiple_root_types(self, test_schema_path: list[Path]) -> None:
1045+
"""Test that flatten mode without root_type flattens all root-level types."""
1046+
graphql_schema = load_schema_with_naming(test_schema_path, None)
1047+
1048+
# Selection query that selects vehicle, cabin, and door at the top level
1049+
query_str = """
1050+
query {
1051+
vehicle {
1052+
doors { isLocked }
1053+
model
1054+
}
1055+
cabin {
1056+
seats { isOccupied }
1057+
temperature
1058+
}
1059+
door {
1060+
isLocked
1061+
position
1062+
instanceTag { row side }
1063+
}
1064+
}
1065+
"""
1066+
selection_query = parse(query_str)
1067+
1068+
result = translate_to_protobuf(
1069+
graphql_schema, flatten_naming=True, expanded_instances=False, selection_query=selection_query
1070+
)
1071+
1072+
assert re.search(
1073+
r"message Message \{.*?"
1074+
r"optional repeated Door Vehicle_doors = 1;.*?"
1075+
r"optional string Vehicle_model = 2;.*?"
1076+
r"optional int32 Vehicle_year = 3;.*?"
1077+
r"optional repeated string Vehicle_features = 4 "
1078+
r"\[\(buf\.validate\.field\)\.repeated = \{unique: true, min_items: 1, max_items: 10\}\];.*?"
1079+
r"optional repeated Seat Cabin_seats = 5;.*?"
1080+
r"optional repeated Door Cabin_doors = 6;.*?"
1081+
r"optional float Cabin_temperature = 7 \[\(buf\.validate\.field\)\.float = \{gte: -100, lte: 100\}\];.*?"
1082+
r"optional bool Door_isLocked = 8;.*?"
1083+
r"optional int32 Door_position = 9 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?"
1084+
r"RowEnum\.Enum Door_instanceTag_row = 10 \[\(buf\.validate\.field\)\.required = true\];.*?"
1085+
r"SideEnum\.Enum Door_instanceTag_side = 11 \[\(buf\.validate\.field\)\.required = true\];.*?"
1086+
r"\}",
1087+
result,
1088+
re.DOTALL,
1089+
), "Message with flattened fields from all root-level types (Vehicle, Cabin, Door)"
1090+
1091+
assert "message Seat {" in result, "Seat message should be included as it's referenced by arrays"
1092+
1093+
assert "message Vehicle {" not in result, "Vehicle should be completely flattened"
1094+
assert "message Cabin {" not in result, "Cabin should be completely flattened"
1095+
assert "message Door {" not in result, "Door should be completely flattened"

0 commit comments

Comments
 (0)