Skip to content

Commit 653a8f4

Browse files
committed
feat(protobuf): implement instance expansion
Signed-off-by: Ahmed Mohamed <[email protected]>
1 parent 28967af commit 653a8f4

File tree

4 files changed

+467
-27
lines changed

4 files changed

+467
-27
lines changed

src/s2dm/exporters/protobuf/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class ProtoMessage(BaseModel):
4545
fields: list[ProtoField]
4646
description: str | None = None
4747
source: str | None = None
48+
nested_messages: list["ProtoMessage"] = Field(default_factory=list)
4849

4950
@field_validator("fields")
5051
@classmethod

src/s2dm/exporters/protobuf/templates/proto.j2

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,28 @@ message {{ enum.name }} {
3030
}
3131

3232
{% endfor %}
33-
{% for message in messages %}
33+
{% macro render_message(message, indent=0) %}
34+
{% set indent_str = " " * indent %}
3435
{% if message.description %}
35-
// {{ message.description }}
36+
{{ indent_str }}// {{ message.description }}
3637
{% endif %}
37-
message {{ message.name }} {
38+
{{ indent_str }}message {{ message.name }} {
3839
{% if message.source %}
39-
option (source) = "{{ message.source }}";
40+
{{ indent_str }} option (source) = "{{ message.source }}";
4041

4142
{% endif %}
43+
{% for nested in message.nested_messages %}
44+
{{ render_message(nested, indent + 1) }}
45+
{% endfor %}
4246
{% for field in message.fields %}
43-
{{ field.type }} {{ field.name }} = {{ field.number }}{% if field.validation_rules %} {{ field.validation_rules }}{% endif %};{% if field.description %} // {{ field.description }}{% endif %}
47+
{{ indent_str }} {{ field.type }} {{ field.name }} = {{ field.number }}{% if field.validation_rules %} {{ field.validation_rules }}{% endif %};{% if field.description %} // {{ field.description }}{% endif %}
4448

4549
{% endfor %}
46-
}
50+
{{ indent_str }}}
4751

52+
{% endmacro %}
53+
{% for message in messages %}
54+
{{ render_message(message) }}
4855
{% endfor %}
4956
{% if flatten_mode %}
5057
message Message {

src/s2dm/exporters/protobuf/transformer.py

Lines changed: 114 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from s2dm.exporters.utils.directive import get_directive_arguments, has_given_directive
2929
from s2dm.exporters.utils.extraction import get_all_named_types
3030
from s2dm.exporters.utils.field import get_cardinality
31-
from s2dm.exporters.utils.instance_tag import is_instance_tag_field
31+
from s2dm.exporters.utils.instance_tag import expand_instance_tag, get_instance_tag_object, is_instance_tag_field
32+
from s2dm.exporters.utils.naming import convert_name, get_target_case_for_element
3233
from s2dm.exporters.utils.schema_loader import get_referenced_types
3334

3435
GRAPHQL_SCALAR_TO_PROTOBUF = {
@@ -108,7 +109,7 @@ def transform(self) -> str:
108109
log.info("Starting GraphQL to Protobuf transformation")
109110

110111
if self.root_type:
111-
referenced_types = get_referenced_types(self.graphql_schema, self.root_type, True)
112+
referenced_types = get_referenced_types(self.graphql_schema, self.root_type, not self.expanded_instances)
112113
user_defined_types: list[GraphQLNamedType] = [
113114
referenced_type for referenced_type in referenced_types if isinstance(referenced_type, GraphQLNamedType)
114115
]
@@ -181,30 +182,53 @@ def _build_messages(self, message_types: list[GraphQLObjectType | GraphQLInterfa
181182
"""Build Pydantic models for message types."""
182183
messages = []
183184
for message_type in message_types:
184-
fields = self._build_message_fields(message_type)
185+
fields, nested_messages = self._build_message_fields(message_type)
185186

186187
messages.append(
187188
ProtoMessage(
188189
name=message_type.name,
189190
fields=fields,
190191
description=message_type.description,
191192
source=message_type.name,
193+
nested_messages=nested_messages,
192194
)
193195
)
194196
return messages
195197

196-
def _build_message_fields(self, message_type: GraphQLObjectType | GraphQLInterfaceType) -> list[ProtoField]:
198+
def _build_message_fields(
199+
self, message_type: GraphQLObjectType | GraphQLInterfaceType
200+
) -> tuple[list[ProtoField], list[ProtoMessage]]:
197201
"""Build Pydantic models for fields in a message."""
198202
fields = []
203+
nested_messages = []
199204
field_number = 1
200205

201206
for field_name, field in message_type.fields.items():
202207
if is_instance_tag_field(field_name) and self.expanded_instances:
203208
continue
204209

205-
proto_field_type = self._get_field_proto_type(field.type)
206-
proto_field_name = self._escape_field_name(field_name)
207-
validation_rules = self.process_directives(field, proto_field_type)
210+
field_type = field.type
211+
unwrapped_type = get_named_type(field_type)
212+
213+
proto_field_name = field_name
214+
proto_field_type = None
215+
expanded_message_name = None
216+
217+
if is_object_type(unwrapped_type):
218+
object_type = cast(GraphQLObjectType, unwrapped_type)
219+
expanded_instances = self._get_expanded_instances(object_type)
220+
if expanded_instances:
221+
proto_field_name, proto_field_type, nested_message = self._handle_expanded_instance_field(
222+
object_type, message_type, expanded_instances
223+
)
224+
nested_messages.append(nested_message)
225+
expanded_message_name = proto_field_type
226+
227+
if proto_field_type is None:
228+
proto_field_type = self._get_field_proto_type(field.type)
229+
proto_field_name = self._escape_field_name(field_name)
230+
231+
validation_rules = None if expanded_message_name else self.process_directives(field, proto_field_type)
208232

209233
fields.append(
210234
ProtoField(
@@ -217,7 +241,7 @@ def _build_message_fields(self, message_type: GraphQLObjectType | GraphQLInterfa
217241
)
218242
field_number += 1
219243

220-
return fields
244+
return fields, nested_messages
221245

222246
def _build_unions(self, union_types: list[GraphQLUnionType]) -> list[ProtoUnion]:
223247
"""Build Pydantic models for union types."""
@@ -277,12 +301,14 @@ def _flatten_fields(
277301
prefix: str,
278302
all_types: list[GraphQLObjectType | GraphQLInterfaceType],
279303
field_counter: int,
304+
type_cache: dict[str, GraphQLObjectType | GraphQLInterfaceType] | None = None,
280305
) -> tuple[list[ProtoField], set[str], int]:
281306
"""Recursively flatten fields with prefix."""
307+
if type_cache is None:
308+
type_cache = {type_def.name: type_def for type_def in all_types}
309+
282310
fields: list[ProtoField] = []
283311
referenced_types: set[str] = set()
284-
if not hasattr(object_type, "fields"):
285-
return fields, referenced_types, field_counter
286312

287313
for field_name, field in object_type.fields.items():
288314
if is_instance_tag_field(field_name) and self.expanded_instances:
@@ -291,6 +317,21 @@ def _flatten_fields(
291317
field_type = field.type
292318
unwrapped_type = get_named_type(field_type)
293319

320+
if is_object_type(unwrapped_type):
321+
object_type = cast(GraphQLObjectType, unwrapped_type)
322+
expanded_instances = self._get_expanded_instances(object_type)
323+
if expanded_instances:
324+
nested_type = type_cache.get(object_type.name)
325+
if nested_type:
326+
for expanded_instance in expanded_instances:
327+
expanded_prefix = f"{prefix}_{field_name}_{expanded_instance.replace('.', '_')}"
328+
nested_fields, nested_referenced, field_counter = self._flatten_fields(
329+
nested_type, expanded_prefix, all_types, field_counter, type_cache
330+
)
331+
fields.extend(nested_fields)
332+
referenced_types.update(nested_referenced)
333+
continue
334+
294335
inner = field_type.of_type if is_non_null_type(field_type) else field_type
295336
is_list = is_list_type(inner)
296337

@@ -317,11 +358,11 @@ def _flatten_fields(
317358
continue
318359

319360
named_unwrapped_type = cast(GraphQLObjectType | GraphQLInterfaceType, unwrapped_type)
320-
nested_type = self._get_type(named_unwrapped_type.name, all_types)
361+
nested_type = type_cache.get(named_unwrapped_type.name)
321362

322363
if nested_type:
323364
nested_fields, nested_referenced, field_counter = self._flatten_fields(
324-
nested_type, flattened_name, all_types, field_counter
365+
nested_type, flattened_name, all_types, field_counter, type_cache
325366
)
326367
fields.extend(nested_fields)
327368
referenced_types.update(nested_referenced)
@@ -330,15 +371,6 @@ def _flatten_fields(
330371

331372
return fields, referenced_types, field_counter
332373

333-
def _get_type(
334-
self, type_name: str, all_types: list[GraphQLObjectType | GraphQLInterfaceType]
335-
) -> GraphQLObjectType | GraphQLInterfaceType | None:
336-
"""Get a GraphQL type by name from a list of types."""
337-
for type_def in all_types:
338-
if type_def.name == type_name:
339-
return type_def
340-
return None
341-
342374
def _get_field_proto_type(self, field_type: GraphQLType) -> str:
343375
"""Get the Protobuf type string for a GraphQL field type."""
344376
if is_non_null_type(field_type):
@@ -414,3 +446,64 @@ def _get_validation_type(self, proto_type: str) -> str | None:
414446
"""Get the protovalidate scalar type from protobuf type."""
415447
validation_type = proto_type.replace("repeated ", "")
416448
return validation_type if validation_type in PROTOBUF_DATA_TYPES else None
449+
450+
def _handle_expanded_instance_field(
451+
self,
452+
object_type: GraphQLObjectType,
453+
message_type: GraphQLObjectType | GraphQLInterfaceType,
454+
expanded_instances: list[str],
455+
) -> tuple[str, str, ProtoMessage]:
456+
"""Handle expanded instance fields, returning field name, type, and nested message."""
457+
prefixed_message_name = f"{message_type.name}_{object_type.name}"
458+
nested_message = self._build_nested_message_structure(
459+
prefixed_message_name, expanded_instances, object_type.name
460+
)
461+
462+
field_name_to_use = object_type.name
463+
if self.naming_config:
464+
target_case = get_target_case_for_element("field", "object", self.naming_config)
465+
if target_case:
466+
field_name_to_use = convert_name(object_type.name, target_case)
467+
468+
return (self._escape_field_name(field_name_to_use), nested_message.name, nested_message)
469+
470+
def _build_nested_message_structure(
471+
self,
472+
message_name: str,
473+
instance_paths: list[str],
474+
target_type: str,
475+
) -> ProtoMessage:
476+
"""Create nested message structure for expanded instance tags."""
477+
message = ProtoMessage(name=message_name, fields=[], nested_messages=[], source=None)
478+
child_paths_by_level: dict[str, list[str]] = {}
479+
field_counter = 1
480+
481+
for instance_path in instance_paths:
482+
instance_path_parts = instance_path.split(".")
483+
if len(instance_path_parts) > 1:
484+
root_level_name = instance_path_parts[0]
485+
remaining_path = ".".join(instance_path_parts[1:])
486+
child_paths_by_level.setdefault(root_level_name, []).append(remaining_path)
487+
else:
488+
message.fields.append(ProtoField(name=instance_path_parts[0], type=target_type, number=field_counter))
489+
field_counter += 1
490+
491+
for root_level_name, child_paths in child_paths_by_level.items():
492+
child_message_name = f"{message_name}_{root_level_name}"
493+
child_message = self._build_nested_message_structure(child_message_name, child_paths, target_type)
494+
message.nested_messages.append(child_message)
495+
message.fields.append(ProtoField(name=root_level_name, type=child_message.name, number=field_counter))
496+
field_counter += 1
497+
498+
return message
499+
500+
def _get_expanded_instances(self, object_type: GraphQLObjectType) -> list[str] | None:
501+
"""Get expanded instances if the type has a valid instance tag."""
502+
if not self.expanded_instances:
503+
return None
504+
505+
instance_tag_object = get_instance_tag_object(object_type, self.graphql_schema)
506+
if not instance_tag_object:
507+
return None
508+
509+
return expand_instance_tag(instance_tag_object, self.naming_config)

0 commit comments

Comments
 (0)