diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 55f5d185..b02d3859 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,10 +30,10 @@ repos: entry: bash -c 'uv run mypy --install-types --non-interactive "$@"' -- language: system types: [python] - files: ^src/ + files: ^(src|tests)/ require_serial: true pass_filenames: false - args: [src/s2dm] + args: [src/, tests/] - repo: https://github.com/jorisroovers/gitlint rev: v0.19.1 hooks: diff --git a/docs-gen/content/docs/tools/cli.md b/docs-gen/content/docs/tools/cli.md index 4a1bb4e7..8c3454a1 100644 --- a/docs-gen/content/docs/tools/cli.md +++ b/docs-gen/content/docs/tools/cli.md @@ -4,8 +4,677 @@ weight: 1 chapter: false --- +## Compose Command + +The `compose` command merges multiple GraphQL schema files into a single unified schema file. It automatically adds `@reference` directives to track which file each type was obtained from. + +### Basic Usage + +```bash +s2dm compose -s PATH -s PATH -o FILE +``` + +### Options + +- `-s, --schema PATH`: GraphQL schema file or directory (required, can be specified multiple times) +- `-r, --root-type TEXT`: Root type name for filtering the schema (optional) +- `-q, --selection-query PATH`: GraphQL query file for filtering schema based on selected fields (optional) +- `-o, --output FILE`: Output file path (required) + +### Examples + +#### Compose Multiple Schema Files + +Merge multiple GraphQL schema files into a single output: + +```bash +s2dm compose -s schema1.graphql -s schema2.graphql -o composed.graphql +``` + +#### Compose from Directories + +Merge all `.graphql` files from multiple directories: + +```bash +s2dm compose -s ./schemas/vehicle -s ./schemas/person -o composed.graphql +``` + +#### Filter by Root Type + +Compose only types reachable from a specific root type: + +```bash +s2dm compose -s schema1.graphql -s schema2.graphql -o composed.graphql -r Vehicle +``` + +This will include only the `Vehicle` type and all types transitively
referenced by it, filtering out unreferenced types like `Person` if they're not connected to `Vehicle`. + +#### Filter by Selection Query + +Compose only types and fields selected in a GraphQL query: + +```bash +s2dm compose -s schema1.graphql -s schema2.graphql -q query.graphql -o composed.graphql +``` + +Given a query file `query.graphql`: + +```graphql +query Selection { + vehicle(instance: "id") { + averageSpeed + adas { + abs { + isEngaged + } + } + } +} +``` + +The composed schema will include: + +- Only the types reached through the selected fields `vehicle`, `adas`, and `abs` +- Only the selected fields within each type +- Types referenced by field arguments (e.g., enums used in field arguments) +- Only directive definitions that are actually used in the filtered schema + +**Note:** The query must be valid against the composed schema. Root fields in the query (e.g., `vehicle`) must exist in the `Query` type of the schema. + +### Reference Directives + +The compose command automatically adds `@reference(source: String!)` directives to all types to track their source: + +```graphql +type Vehicle @reference(source: "schema1.graphql") { + id: ID! + name: String +} + +type Person @reference(source: "schema2.graphql") { + id: ID! + name: String +} +``` + +Types from the S2DM specification (common types, scalars, directives) are marked with: + +```graphql +type InCabinArea2x2 @instanceTag @reference(source: "S2DM Spec") { + row: TwoRowsInCabinEnum + column: TwoColumnsInCabinEnum +} +``` + +**Note:** If a type already has a `@reference` directive in the source schema, it will be preserved and not overwritten. + ## Export Commands +### Protocol Buffers (Protobuf) + +This exporter translates the given GraphQL schema to [Protocol Buffers](https://protobuf.dev/) (`.proto`) format.
+ +#### Key Features + +- **Complete GraphQL Type Support**: Handles all GraphQL types including scalars, objects, enums, unions, interfaces, and lists +- **Selection Query (Required)**: Use the `--selection-query` flag to specify which types and fields to export via a GraphQL query +- **Root Type Filtering**: Use the `--root-type` flag to export only a specific type and its dependencies +- **Flatten Naming Mode**: Use the `--flatten-naming` flag to flatten nested structures into a single message with prefixed field names +- **Expanded Instance Tags**: Use the `--expanded-instances` flag to transform instance tag arrays into nested message structures +- **Field Nullability**: Properly handles nullable vs non-nullable fields from GraphQL schema +- **Directive Support**: Converts S2DM directives like `@cardinality`, `@range`, and `@noDuplicates` to protovalidate constraints +- **Package Name Support**: Use the `--package-name` flag to specify a protobuf package namespace + +#### Example Transformation + +Consider the following GraphQL schema and selection query: + +GraphQL Schema: + +```graphql +type Cabin { + doors: [Door] + temperature: Float +} + +type Door { + isLocked: Boolean + instanceTag: DoorPosition +} + +type DoorPosition @instanceTag { + row: RowEnum + side: SideEnum +} + +enum RowEnum { + ROW1 + ROW2 +} + +enum SideEnum { + DRIVERSIDE + PASSENGERSIDE +} + +type Query { + cabin: Cabin +} +``` + +Selection Query: + +```graphql +query Selection { + cabin { + doors { + isLocked + instanceTag { + row + side + } + } + temperature + } +} +``` + +The Protobuf exporter produces: + +> See [Selection Query](#selection-query-required) for more details on the command. 
+ +```protobuf +syntax = "proto3"; + +import "google/protobuf/descriptor.proto"; +import "buf/validate/validate.proto"; + +extend google.protobuf.MessageOptions { + string source = 50001; +} + +message RowEnum { + option (source) = "RowEnum"; + + enum Enum { + ROWENUM_UNSPECIFIED = 0; + ROW1 = 1; + ROW2 = 2; + } +} + +message SideEnum { + option (source) = "SideEnum"; + + enum Enum { + SIDEENUM_UNSPECIFIED = 0; + DRIVERSIDE = 1; + PASSENGERSIDE = 2; + } +} + +message DoorPosition { + option (source) = "DoorPosition"; + + RowEnum.Enum row = 1; + SideEnum.Enum side = 2; +} + +message Cabin { + option (source) = "Cabin"; + + repeated Door doors = 1; + float temperature = 2; +} + +message Door { + option (source) = "Door"; + + bool isLocked = 1; + DoorPosition instanceTag = 2; +} + +message Selection { + option (source) = "Query"; + + optional Cabin cabin = 1; +} +``` + +> The `Query` type from the GraphQL schema is renamed to match the selection query operation name (`Selection` in this example). + +#### Selection Query (Required) + +The protobuf exporter requires a selection query to determine which types and fields to export: + +```bash +s2dm export protobuf --schema schema.graphql --selection-query query.graphql --output cabin.proto +``` + +Given a query file `query.graphql` (presented above), the exporter will include only the selected types and fields from the schema. + +#### Root Type Filtering + +Use the `--root-type` flag in combination with the selection query to further filter the export: + +```bash +s2dm export protobuf --schema schema.graphql --selection-query query.graphql --output cabin.proto --root-type Cabin +``` + +This will include only the `Cabin` type and all types transitively referenced by it from the selection query. + +#### Flatten Naming Mode + +Use the `--flatten-naming` flag to flatten nested object structures into a single message with prefixed field names. 
This mode works with the selection query to flatten all root-level types selected in the query: + +```bash +s2dm export protobuf --schema schema.graphql --selection-query query.graphql --output vehicle.proto --flatten-naming +``` + +You can optionally combine it with `--root-type` to flatten only a specific root type: + +```bash +s2dm export protobuf --schema schema.graphql --selection-query query.graphql --output vehicle.proto --root-type Vehicle --flatten-naming +``` + +**Example transformation:** + +Given a GraphQL schema and the selection query: + +GraphQL Schema: + +```graphql +type Vehicle { + adas: ADAS +} + +type ADAS { + abs: ABS +} + +type ABS { + isEngaged: Boolean +} + +type Query { + vehicle: Vehicle +} +``` + +Selection Query: + +```graphql +query Selection { + vehicle { + adas { + abs { + isEngaged + } + } + } +} +``` + +Flatten mode produces: + +```protobuf +syntax = "proto3"; + +import "google/protobuf/descriptor.proto"; +import "buf/validate/validate.proto"; + +extend google.protobuf.MessageOptions { + string source = 50001; +} + +message Selection { + bool Vehicle_adas_abs_isEngaged = 1; +} + +``` + +> The output message name is derived from the selection query operation name (`Selection` in this example). + +#### Expanded Instance Tags + +The `--expanded-instances` flag transforms instance tag objects into nested message structures instead of repeated fields. This provides compile-time type safety for accessing specific instances. 
+ +```bash +s2dm export protobuf --schema schema.graphql --selection-query query.graphql --output cabin.proto --expanded-instances +``` + +**Default behavior (without flag):** + +Given a GraphQL schema with instance tags and a selection query: + +GraphQL Schema: + +```graphql +type Cabin { + doors: [Door] +} + +type Door { + isLocked: Boolean + instanceTag: DoorPosition +} + +type DoorPosition @instanceTag { + row: RowEnum + side: SideEnum +} + +enum RowEnum { + ROW1 + ROW2 +} + +enum SideEnum { + DRIVERSIDE + PASSENGERSIDE +} + +type Query { + cabin: Cabin +} +``` + +Selection Query: + +```graphql +query Selection { + cabin { + doors { + isLocked + instanceTag { + row + side + } + } + } +} +``` + +Default output uses repeated fields and includes the instanceTag field: + +```protobuf +syntax = "proto3"; + +import "google/protobuf/descriptor.proto"; +import "buf/validate/validate.proto"; + +extend google.protobuf.MessageOptions { + string source = 50001; +} + +message RowEnum { + option (source) = "RowEnum"; + + enum Enum { + ROWENUM_UNSPECIFIED = 0; + ROW1 = 1; + ROW2 = 2; + } +} + +message SideEnum { + option (source) = "SideEnum"; + + enum Enum { + SIDEENUM_UNSPECIFIED = 0; + DRIVERSIDE = 1; + PASSENGERSIDE = 2; + } +} + +message Door { + option (source) = "Door"; + + bool isLocked = 1; + DoorPosition instanceTag = 2; +} + + +message Cabin { + option (source) = "Cabin"; + + repeated Door doors = 1; +} + + +message DoorPosition { + option (source) = "DoorPosition"; + + RowEnum.Enum row = 1; + SideEnum.Enum side = 2; +} + +message Selection { + option (source) = "Query"; + + optional Cabin cabin = 1; +} +``` + +**With `--expanded-instances` flag:** + +The same schema and selection query produce nested messages representing the cartesian product of instance tag values: + +```protobuf +syntax = "proto3"; + +import "google/protobuf/descriptor.proto"; +import "buf/validate/validate.proto"; + +extend google.protobuf.MessageOptions { + string source = 50001; +} + 
+message Door { + option (source) = "Door"; + + bool isLocked = 1; +} + + +message Cabin { + option (source) = "Cabin"; + + message Cabin_Door { + message Cabin_Door_ROW1 { + Door DRIVERSIDE = 1; + Door PASSENGERSIDE = 2; + } + + message Cabin_Door_ROW2 { + Door DRIVERSIDE = 1; + Door PASSENGERSIDE = 2; + } + + Cabin_Door_ROW1 ROW1 = 1; + Cabin_Door_ROW2 ROW2 = 2; + } + + Cabin_Door Door = 1; +} + +message Selection { + option (source) = "Query"; + + optional Cabin cabin = 1; +} +``` + +**Key differences:** + +- Instance tag enums (`RowEnum`, `SideEnum`) are excluded from the output when using expanded instances +- Types with `@instanceTag` directive (`DoorPosition`) are excluded from the output +- The `instanceTag` field is excluded from the Door message +- Nested messages are created inside the parent message +- Field names use the GraphQL type name (`Door` not `doors`) + +#### Directive Support + +S2DM directives are converted to [protovalidate](https://github.com/bufbuild/protovalidate) constraints: + +- `@range(min: 0, max: 100)` → `[(buf.validate.field).int32 = {gte: 0, lte: 100}]` +- `@noDuplicates` → `[(buf.validate.field).repeated = {unique: true}]` +- `@cardinality(min: 1, max: 5)` → `[(buf.validate.field).repeated = {min_items: 1, max_items: 5}]` + +GraphQL Schema: + +```graphql +type Vehicle { + speed: Int @range(min: 0, max: 300) + tags: [String] @noDuplicates @cardinality(min: 1, max: 10) +} + +type Query { + vehicle: Vehicle +} +``` + +Selection Query: + +```graphql +query Selection { + vehicle { + speed + tags + } +} +``` + +Produces: + +```protobuf +syntax = "proto3"; + +import "google/protobuf/descriptor.proto"; +import "buf/validate/validate.proto"; + +extend google.protobuf.MessageOptions { + string source = 50001; +} + +message Vehicle { + option (source) = "Vehicle"; + + int32 speed = 1 [(buf.validate.field).int32 = {gte: 0, lte: 300}]; + repeated string tags = 2 [(buf.validate.field).repeated = {unique: true, min_items: 1, max_items: 10}]; +} 
+ +message Selection { + option (source) = "Query"; + + optional Vehicle vehicle = 1; +} +``` + +#### Type Mappings + +GraphQL types are mapped to protobuf types as follows: + +| GraphQL Type | Protobuf Type | +|--------------|---------------| +| `String` | `string` | +| `Int` | `int32` | +| `Float` | `float` | +| `Boolean` | `bool` | +| `ID` | `string` | +| `Int8` | `int32` | +| `UInt8` | `uint32` | +| `Int16` | `int32` | +| `UInt16` | `uint32` | +| `UInt32` | `uint32` | +| `Int64` | `int64` | +| `UInt64` | `uint64` | + +**List types** are converted to `repeated` fields: + +- `[String]` → `repeated string` +- `[Int]` → `repeated int32` + +**Enums** are converted to protobuf enums wrapped in a message: + +- Each GraphQL enum becomes a protobuf message with the same name +- Inside the message, an `Enum` nested enum is created +- An `UNSPECIFIED` value, named after the upper-cased enum name (e.g., `ROWENUM_UNSPECIFIED` for `RowEnum`), is added at position 0 +- References use the `.Enum` suffix (e.g., `LockStatus.Enum`) + +**Field Nullability:** + +GraphQL field nullability is preserved in protobuf using the `optional` keyword and protovalidate constraints: + +- **Nullable fields** (e.g., `name: String`) → `optional` proto3 fields +- **Non-nullable fields** (e.g., `id: ID!`) → fields with `[(buf.validate.field).required = true]` + +Example: + +```graphql +type User { + id: ID! # Non-nullable + name: String # Nullable +} +``` + +Produces: + +```protobuf +message User { + option (source) = "User"; + + string id = 1 [(buf.validate.field).required = true]; + optional string name = 2; +} +``` + +You can call the help for usage reference: + +```bash +s2dm export protobuf --help +``` + +#### Field Number Stability + +**Important Limitation**: Field numbers in generated protobuf files are **not stable** across schema regenerations when the GraphQL schema changes. + +**How Field Numbers Are Assigned:** + +Field numbers are assigned sequentially (starting from 1) based on: + +1. The iteration order of fields in the GraphQL schema +2.
Which types/fields are included (affected by `--root-type` filtering) +3. The flattening logic (when using `--flatten-naming`) + +**Impact on Schema Evolution:** + +Any change to the GraphQL schema can cause field number reassignments: + +```graphql +# Version 1 +type Door { + isLocked: Boolean # becomes field number 1 + position: Int # becomes field number 2 +} + +# Version 2 - Adding a new field +type Door { + id: ID # becomes field number 1 + isLocked: Boolean # becomes field number 2 (was 1!) + position: Int # becomes field number 3 (was 2!) +} +``` + +**When Field Number Stability Matters:** + +Field number changes break compatibility if you have: + +- **Persistent protobuf data**: Data stored in databases, files, or caches will deserialize incorrectly after regeneration +- **Rolling deployments**: Services using different schema versions cannot communicate during deployment +- **Message queues**: Messages enqueued before regeneration will fail to deserialize correctly +- **Archived data**: Historical protobuf-encoded logs or backups become unreadable + ### Naming Configuration All export commands support a global naming configuration feature that allows you to transform element names during the export process using the `[--naming-config | -n]` flag. 
diff --git a/pyproject.toml b/pyproject.toml index 7c34f66a..5d74f1d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "graphql-core>=3.2.6", "pyyaml>=6.0.2", "pydantic>=2.10.6", + "jinja2>=3.1.0", "rdflib>=7.1.3", "pyshacl>=0.30.0", "ariadne>=0.24.0", diff --git a/src/s2dm/cli.py b/src/s2dm/cli.py index d150035a..2d246d01 100644 --- a/src/s2dm/cli.py +++ b/src/s2dm/cli.py @@ -1,24 +1,25 @@ import json import logging import sys +from collections.abc import Callable from pathlib import Path from typing import Any import rich_click as click import yaml from graphql import build_schema, parse -from graphql import validate as graphql_validate from rich.traceback import install from s2dm import __version__, log from s2dm.concept.services import create_concept_uri_model, iter_all_concepts from s2dm.exporters.id import IDExporter from s2dm.exporters.jsonschema import translate_to_jsonschema +from s2dm.exporters.protobuf import translate_to_protobuf from s2dm.exporters.shacl import translate_to_shacl from s2dm.exporters.spec_history import SpecHistoryExporter from s2dm.exporters.utils.extraction import get_all_named_types, get_all_object_types from s2dm.exporters.utils.graphql_type import is_builtin_scalar_type, is_introspection_type -from s2dm.exporters.utils.schema import search_schema +from s2dm.exporters.utils.schema import load_schema_with_naming, search_schema from s2dm.exporters.utils.schema_loader import ( create_tempfile_to_composed_schema, load_schema, @@ -74,6 +75,25 @@ def process_value(self, ctx: click.Context, value: Any) -> list[Path] | None: help="The GraphQL schema file or directory containing schema files. 
Can be specified multiple times.", ) + +def selection_query_option(required: bool = False) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + return click.option( + "--selection-query", + "-q", + type=click.Path(exists=True, dir_okay=False, path_type=Path), + required=required, + help="GraphQL query file to filter the passed schema", + ) + + +root_type_option = click.option( + "--root-type", + "-r", + type=str, + help="Root type name for filtering/scoping the schema", +) + + output_option = click.option( "--output", "-o", @@ -91,6 +111,15 @@ def process_value(self, ctx: click.Context, value: Any) -> list[Path] | None: ) +expanded_instances_option = click.option( + "--expanded-instances", + "-e", + is_flag=True, + default=False, + help="Expand instance tags into nested structure instead of arrays/repeated fields", +) + + def pretty_print_dict_json(result: dict[str, Any]) -> dict[str, Any]: """ Recursively pretty-print a dict for JSON output: @@ -357,18 +386,8 @@ def validate() -> None: @click.command() @schema_option -@click.option( - "--root-type", - "-r", - type=str, - help="Root type name for filtering the schema", -) -@click.option( - "--selection-query", - "-q", - type=click.Path(exists=True, dir_okay=False, path_type=Path), - help="GraphQL query file to filter the composed schema", -) +@selection_query_option() +@root_type_option @output_option def compose(schemas: list[Path], root_type: str | None, selection_query: Path | None, output: Path) -> None: """Compose GraphQL schema files into a single output file.""" @@ -381,25 +400,9 @@ def compose(schemas: list[Path], root_type: str | None, selection_query: Path | graphql_schema = build_schema(composed_schema_str) if selection_query: - log.debug("Validating selection query against composed schema...") - query_document = parse(selection_query.read_text()) - errors = graphql_validate(graphql_schema, query_document) - - if errors: - log.error("Selection query validation failed:") - for error in errors: - 
log.error(f" - {error.message}") - sys.exit(1) - - log.debug("Selection query validated successfully") - - log.debug("Filtering schema based on selection query...") - graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document) - log.debug("Filtered schema based on query selections") - composed_schema_str = print_schema_with_directives_preserved(graphql_schema) output.write_text(composed_schema_str) @@ -426,6 +429,7 @@ def compose(schemas: list[Path], root_type: str | None, selection_query: Path | # ---------- @export.command @schema_option +@selection_query_option() @output_option @click.option( "--serialization-format", @@ -471,6 +475,7 @@ def compose(schemas: list[Path], root_type: str | None, selection_query: Path | def shacl( ctx: click.Context, schemas: list[Path], + selection_query: Path | None, output: Path, serialization_format: str, shapes_namespace: str, @@ -480,8 +485,15 @@ def shacl( ) -> None: """Generate SHACL shapes from a given GraphQL schema.""" naming_config = ctx.obj.get("naming_config") + + graphql_schema = load_schema_with_naming(schemas, naming_config) + + if selection_query: + query_document = parse(selection_query.read_text()) + graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document) + result = translate_to_shacl( - schemas, + graphql_schema, shapes_namespace, shapes_namespace_prefix, model_namespace, @@ -496,12 +508,19 @@ def shacl( # ---------- @export.command @schema_option +@selection_query_option() @output_option @click.pass_context -def vspec(ctx: click.Context, schemas: list[Path], output: Path) -> None: +def vspec(ctx: click.Context, schemas: list[Path], selection_query: Path | None, output: Path) -> None: """Generate VSPEC from a given GraphQL schema.""" naming_config = ctx.obj.get("naming_config") - result = translate_to_vspec(schemas, naming_config) + graphql_schema = load_schema_with_naming(schemas, naming_config) + + if selection_query: + query_document = 
parse(selection_query.read_text()) + graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document) + + result = translate_to_vspec(graphql_schema, naming_config) output.parent.mkdir(parents=True, exist_ok=True) _ = output.write_text(result) @@ -510,13 +529,9 @@ def vspec(ctx: click.Context, schemas: list[Path], output: Path) -> None: # ---------- @export.command @schema_option +@selection_query_option() @output_option -@click.option( - "--root-type", - "-r", - type=str, - help="Root type name for the JSON schema", -) +@root_type_option @click.option( "--strict", "-S", @@ -524,20 +539,71 @@ def vspec(ctx: click.Context, schemas: list[Path], output: Path) -> None: default=False, help="Enforce strict field nullability translation from GraphQL to JSON Schema", ) +@expanded_instances_option +@click.pass_context +def jsonschema( + ctx: click.Context, + schemas: list[Path], + selection_query: Path | None, + output: Path, + root_type: str | None, + strict: bool, + expanded_instances: bool, +) -> None: + """Generate JSON Schema from a given GraphQL schema.""" + naming_config = ctx.obj.get("naming_config") + graphql_schema = load_schema_with_naming(schemas, naming_config) + + if selection_query: + query_document = parse(selection_query.read_text()) + graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document) + + result = translate_to_jsonschema(graphql_schema, root_type, strict, expanded_instances, naming_config) + _ = output.write_text(result) + + +# Export -> protobuf +# ---------- +@export.command +@schema_option +@selection_query_option(required=True) +@output_option +@root_type_option @click.option( - "--expanded-instances", - "-e", + "--flatten-naming", + "-f", is_flag=True, default=False, - help="Expand instance tags into nested structure instead of arrays", + help="Flatten nested field names.", +) +@click.option( + "--package-name", + "-p", + type=str, + help="Protobuf package name", ) +@expanded_instances_option 
@click.pass_context -def jsonschema( - ctx: click.Context, schemas: list[Path], output: Path, root_type: str | None, strict: bool, expanded_instances: bool +def protobuf( + ctx: click.Context, + schemas: list[Path], + selection_query: Path, + output: Path, + root_type: str | None, + flatten_naming: bool, + package_name: str | None, + expanded_instances: bool, ) -> None: - """Generate JSON Schema from a given GraphQL schema.""" + """Generate Protocol Buffers (.proto) file from GraphQL schema.""" naming_config = ctx.obj.get("naming_config") - result = translate_to_jsonschema(schemas, root_type, strict, expanded_instances, naming_config) + graphql_schema = load_schema_with_naming(schemas, naming_config) + + query_document = parse(selection_query.read_text()) + graphql_schema = prune_schema_using_query_selection(graphql_schema, query_document) + + result = translate_to_protobuf( + graphql_schema, query_document, root_type, flatten_naming, package_name, naming_config, expanded_instances + ) _ = output.write_text(result) diff --git a/src/s2dm/exporters/jsonschema/jsonschema.py b/src/s2dm/exporters/jsonschema/jsonschema.py index f1b781e3..d2b0b54e 100644 --- a/src/s2dm/exporters/jsonschema/jsonschema.py +++ b/src/s2dm/exporters/jsonschema/jsonschema.py @@ -1,11 +1,9 @@ import json -from pathlib import Path from typing import Any from graphql import GraphQLSchema from s2dm import log -from s2dm.exporters.utils.schema import load_schema_with_naming from .transformer import JsonSchemaTransformer @@ -48,7 +46,7 @@ def transform( def translate_to_jsonschema( - schema_paths: list[Path], + schema: GraphQLSchema, root_type: str | None = None, strict: bool = False, expanded_instances: bool = False, @@ -66,9 +64,4 @@ def translate_to_jsonschema( Returns: str: JSON Schema representation as a string """ - log.info(f"Loading GraphQL schema from: {schema_paths}") - - graphql_schema = load_schema_with_naming(schema_paths, naming_config) - log.info(f"Successfully loaded GraphQL schema 
with {len(graphql_schema.type_map)} types") - - return transform(graphql_schema, root_type, strict, expanded_instances, naming_config) + return transform(schema, root_type, strict, expanded_instances, naming_config) diff --git a/src/s2dm/exporters/jsonschema/transformer.py b/src/s2dm/exporters/jsonschema/transformer.py index f65203c5..c23b0b93 100644 --- a/src/s2dm/exporters/jsonschema/transformer.py +++ b/src/s2dm/exporters/jsonschema/transformer.py @@ -25,7 +25,12 @@ from s2dm.exporters.utils.directive import get_directive_arguments, has_given_directive from s2dm.exporters.utils.extraction import get_all_named_types from s2dm.exporters.utils.field import get_cardinality -from s2dm.exporters.utils.instance_tag import expand_instance_tag, get_instance_tag_object, is_valid_instance_tag_field +from s2dm.exporters.utils.instance_tag import ( + expand_instance_tag, + get_instance_tag_object, + is_instance_tag_field, + is_valid_instance_tag_field, +) from s2dm.exporters.utils.schema_loader import get_referenced_types GRAPHQL_SCALAR_TO_JSON_SCHEMA = { @@ -187,8 +192,7 @@ def transform_object_type(self, object_type: GraphQLObjectType) -> dict[str, Any required_fields = [] for field_name, field in object_type.fields.items(): if is_valid_instance_tag_field(field, self.graphql_schema): - if field_name == "instanceTag": - # Skip instanceTag field as it is handled separately + if is_instance_tag_field(field_name): continue else: # Fields with an instanceTag object type should not be allowed since object diff --git a/src/s2dm/exporters/protobuf/__init__.py b/src/s2dm/exporters/protobuf/__init__.py new file mode 100644 index 00000000..3324e57d --- /dev/null +++ b/src/s2dm/exporters/protobuf/__init__.py @@ -0,0 +1,3 @@ +from .protobuf import translate_to_protobuf + +__all__ = ["translate_to_protobuf"] diff --git a/src/s2dm/exporters/protobuf/models.py b/src/s2dm/exporters/protobuf/models.py new file mode 100644 index 00000000..12c5f977 --- /dev/null +++ 
b/src/s2dm/exporters/protobuf/models.py @@ -0,0 +1,76 @@ +"""Pydantic models for Protobuf schema structures.""" + +from pydantic import BaseModel, Field, field_validator + + +class ProtoEnumValue(BaseModel): + """Represents a value in a Protocol Buffers enum.""" + + name: str + number: int = Field(ge=1) + description: str | None = None + + +class ProtoEnum(BaseModel): + """Represents a Protocol Buffers enum type.""" + + name: str + enum_values: list[ProtoEnumValue] + description: str | None = None + source: str | None = None + + @field_validator("enum_values") + @classmethod + def validate_unique_numbers(cls, enum_values: list[ProtoEnumValue]) -> list[ProtoEnumValue]: + numbers = [v.number for v in enum_values] + if len(numbers) != len(set(numbers)): + raise ValueError("Enum values must have unique numbers") + return enum_values + + +class ProtoField(BaseModel): + """Represents a field in a Protocol Buffers message.""" + + name: str + type: str + number: int = Field(ge=1) + description: str | None = None + validation_rules: str | None = None + + +class ProtoMessage(BaseModel): + """Represents a Protocol Buffers message type.""" + + name: str + fields: list[ProtoField] + description: str | None = None + source: str | None = None + nested_messages: list["ProtoMessage"] = Field(default_factory=list) + + @field_validator("fields") + @classmethod + def validate_unique_field_numbers(cls, fields: list[ProtoField]) -> list[ProtoField]: + numbers = [f.number for f in fields] + if len(numbers) != len(set(numbers)): + raise ValueError("Fields must have unique numbers") + return fields + + +class ProtoUnion(BaseModel): + """Represents a Protocol Buffers union (oneof).""" + + name: str + members: list[ProtoField] + description: str | None = None + source: str | None = None + + +class ProtoSchema(BaseModel): + """Represents a complete Protocol Buffers schema.""" + + syntax: str = "proto3" + package: str | None = None + enums: list[ProtoEnum] = Field(default_factory=list) + 
messages: list[ProtoMessage] = Field(default_factory=list) + unions: list[ProtoUnion] = Field(default_factory=list) + flatten_mode: bool = False diff --git a/src/s2dm/exporters/protobuf/protobuf.py b/src/s2dm/exporters/protobuf/protobuf.py new file mode 100644 index 00000000..743dd9a4 --- /dev/null +++ b/src/s2dm/exporters/protobuf/protobuf.py @@ -0,0 +1,83 @@ +from typing import Any + +from graphql import DocumentNode, GraphQLSchema + +from s2dm import log + +from .transformer import ProtobufTransformer + + +def transform( + graphql_schema: GraphQLSchema, + selection_query: DocumentNode, + root_type: str | None = None, + flatten_naming: bool = False, + package_name: str | None = None, + naming_config: dict[str, Any] | None = None, + expanded_instances: bool = False, +) -> str: + """ + Transform a GraphQL schema object to Protocol Buffers format. + + Args: + graphql_schema: The GraphQL schema object to transform + selection_query: Required selection query document to determine root-level types + root_type: Optional root type name for the protobuf schema + flatten_naming: If True, flatten nested field names + package_name: Optional package name for the .proto file + naming_config: Optional naming configuration + expanded_instances: If True, expand instance tags into nested structures + + Returns: + str: Protocol Buffers representation as a string + + Raises: + ValueError: If root_type is given but is not found in the schema's type map + """ + log.info(f"Transforming GraphQL schema to Protobuf with {len(graphql_schema.type_map)} types") + + if root_type: + if root_type not in graphql_schema.type_map: + raise ValueError(f"Root type '{root_type}' not found in schema") + log.info(f"Using root type: {root_type}") + + transformer = ProtobufTransformer( + graphql_schema, selection_query, root_type, flatten_naming, package_name, naming_config, expanded_instances + ) + proto_content = transformer.transform() + + log.info("Successfully converted GraphQL schema to Protobuf") + + return proto_content + +
+def translate_to_protobuf( + schema: GraphQLSchema, + selection_query: DocumentNode, + root_type: str | None = None, + flatten_naming: bool = False, + package_name: str | None = None, + naming_config: dict[str, Any] | None = None, + expanded_instances: bool = False, +) -> str: + """ + Translate a GraphQL schema to Protocol Buffers format. + + Args: + schema: The GraphQL schema object + selection_query: Required selection query document to determine root-level types + root_type: Optional root type name for the protobuf schema + flatten_naming: If True, flatten nested field names + package_name: Optional package name for the .proto file + naming_config: Optional naming configuration + expanded_instances: If True, expand instance tags into nested structures + + Returns: + str: Protocol Buffers (.proto) representation as a string + + Raises: + ValueError: If root_type is given but is not found in the schema (raised by transform) + """ + return transform( + schema, selection_query, root_type, flatten_naming, package_name, naming_config, expanded_instances + ) diff --git a/src/s2dm/exporters/protobuf/templates/proto_standard.j2 b/src/s2dm/exporters/protobuf/templates/proto_standard.j2 new file mode 100644 index 00000000..55152ba4 --- /dev/null +++ b/src/s2dm/exporters/protobuf/templates/proto_standard.j2 @@ -0,0 +1,81 @@ +syntax = "{{ syntax }}"; +{% if imports %}{# if imports #} + +{% for import in imports %}{# for import in imports #} +{{ import }} +{% endfor %}{# endfor import #} +{% endif %}{# endif imports #} +{% if has_source_option %}{# if has_source_option #} + +extend google.protobuf.MessageOptions { + string source = 50001; +} +{% endif %}{# endif has_source_option #} +{% if package %}{# if package #} + +package {{ package }}; +{% endif %}{# endif package #} +{% if enums %}{# if enums #} + +{% for enum in enums %}{# for enum in enums #} +{% if enum.description %}{# if enum.description #} +// {{ enum.description }} +{% endif %}{# endif enum.description #} +message {{ enum.name }} { +{% if enum.source
%}{# if enum.source #} + option (source) = "{{ enum.source }}"; + +{% endif %}{# endif enum.source #} + enum Enum { + {{ enum.name | upper }}_UNSPECIFIED = 0; +{% for value in enum.enum_values %}{# for value in enum.enum_values #} + {{ value.name }} = {{ value.number }};{% if value.description %} // {{ value.description }}{% endif %} + +{% endfor %}{# endfor value #} + } +} +{% endfor %}{# endfor enum #} +{% endif %}{# endif enums #} +{% macro render_message(message, indent=0) %}{# macro render_message #} +{% set indent_str = " " * indent %} +{% if message.description %}{# if message.description #} +{{ indent_str }}// {{ message.description }} +{% endif %}{# endif message.description #} +{{ indent_str }}message {{ message.name }} { +{% if message.source %}{# if message.source #} +{{ indent_str }} option (source) = "{{ message.source }}"; + +{% endif %}{# endif message.source #} +{% for nested in message.nested_messages %}{# for nested in message.nested_messages #} +{{ render_message(nested, indent + 1) }} +{% endfor %}{# endfor nested #} +{% for field in message.fields %}{# for field in message.fields #} +{{ indent_str }} {{ field.type }} {{ field.name }} = {{ field.number }}{% if field.validation_rules %} {{ field.validation_rules }}{% endif %};{% if field.description %} // {{ field.description }}{% endif %} + +{% endfor %}{# endfor field #} +{{ indent_str }}} +{% endmacro %}{# endmacro render_message #} +{% if messages %}{# if messages #} + +{% for message in messages %}{# for message in messages #} +{{ render_message(message) }} +{% endfor %}{# endfor message #} +{% endif %}{# endif messages #} +{% if unions %}{# if unions #} +{% for union in unions %}{# for union in unions #} +{% if union.description %}{# if union.description #} +// {{ union.description }} +{% endif %}{# endif union.description #} +message {{ union.name }} { +{% if union.source %}{# if union.source #} + option (source) = "{{ union.source }}"; + +{% endif %}{# endif union.source #} + oneof {{ 
union.name }} { +{% for member in union.members %}{# for member in union.members #} + {{ member.type }} {{ member.name }} = {{ member.number }}; +{% endfor %}{# endfor member #} + } +} +{% endfor %}{# endfor union #} +{% endif %}{# endif unions #} diff --git a/src/s2dm/exporters/protobuf/transformer.py b/src/s2dm/exporters/protobuf/transformer.py new file mode 100644 index 00000000..6254e93c --- /dev/null +++ b/src/s2dm/exporters/protobuf/transformer.py @@ -0,0 +1,647 @@ +from typing import Any, cast + +from graphql import ( + DocumentNode, + GraphQLEnumType, + GraphQLField, + GraphQLInterfaceType, + GraphQLList, + GraphQLNamedType, + GraphQLNonNull, + GraphQLObjectType, + GraphQLScalarType, + GraphQLSchema, + GraphQLType, + GraphQLUnionType, + OperationDefinitionNode, + OperationType, + get_named_type, + is_enum_type, + is_interface_type, + is_list_type, + is_non_null_type, + is_object_type, + is_scalar_type, + is_union_type, +) +from jinja2 import Environment, PackageLoader, select_autoescape + +from s2dm import log +from s2dm.exporters.protobuf.models import ProtoEnum, ProtoEnumValue, ProtoField, ProtoMessage, ProtoSchema, ProtoUnion +from s2dm.exporters.utils.directive import get_directive_arguments, has_given_directive +from s2dm.exporters.utils.extraction import get_all_named_types, get_root_level_types_from_query +from s2dm.exporters.utils.field import get_cardinality +from s2dm.exporters.utils.instance_tag import expand_instance_tag, get_instance_tag_object, is_instance_tag_field +from s2dm.exporters.utils.naming import convert_name, get_target_case_for_element +from s2dm.exporters.utils.schema_loader import get_referenced_types + +GRAPHQL_SCALAR_TO_PROTOBUF = { + # Built-in GraphQL scalars + "String": "string", + "Int": "int32", + "Float": "float", + "Boolean": "bool", + "ID": "string", + # Custom scalars + "Int8": "int32", + "UInt8": "uint32", + "Int16": "int32", + "UInt16": "uint32", + "UInt32": "uint32", + "Int64": "int64", + "UInt64": "uint64", +} + 
def __init__(
    self,
    graphql_schema: GraphQLSchema,
    selection_query: DocumentNode,
    root_type: str | None = None,
    flatten_naming: bool = False,
    package_name: str | None = None,
    naming_config: dict[str, Any] | None = None,
    expanded_instances: bool = False,
):
    """Store the translation options and prepare the Jinja2 environment.

    Args:
        graphql_schema: The GraphQL schema to translate
        selection_query: Selection query document; must not be None
        root_type: Optional root type name to restrict the translation to
        flatten_naming: If True, nested fields are flattened into prefixed fields
        package_name: Optional package name for the .proto file
        naming_config: Optional naming configuration
        expanded_instances: If True, instance tags are expanded into nested structures

    Raises:
        ValueError: If selection_query is None
    """
    if selection_query is None:
        raise ValueError("selection_query is required")

    # What to translate.
    self.graphql_schema = graphql_schema
    self.selection_query = selection_query
    self.root_type = root_type

    # How to translate it.
    self.flatten_naming = flatten_naming
    self.expanded_instances = expanded_instances
    self.package_name = package_name
    self.naming_config = naming_config

    # Templates are loaded from the "templates" directory of this package;
    # trim_blocks/lstrip_blocks keep block tags from leaking whitespace
    # into the rendered output.
    template_loader = PackageLoader("s2dm.exporters.protobuf", "templates")
    self.env = Environment(
        loader=template_loader,
        autoescape=select_autoescape(),
        trim_blocks=True,
        lstrip_blocks=True,
    )
+ """ + log.info("Starting GraphQL to Protobuf transformation") + + if self.root_type: + referenced_types = get_referenced_types(self.graphql_schema, self.root_type, not self.expanded_instances) + user_defined_types: list[GraphQLNamedType] = [ + referenced_type for referenced_type in referenced_types if isinstance(referenced_type, GraphQLNamedType) + ] + else: + user_defined_types = get_all_named_types(self.graphql_schema) + + log.debug(f"Found {len(user_defined_types)} user-defined types to transform") + + enum_types: list[GraphQLEnumType] = [] + message_types: list[GraphQLObjectType | GraphQLInterfaceType] = [] + union_types: list[GraphQLUnionType] = [] + + for type_def in user_defined_types: + if is_enum_type(type_def): + enum_types.append(cast(GraphQLEnumType, type_def)) + elif is_object_type(type_def): + object_type = cast(GraphQLObjectType, type_def) + if not (has_given_directive(object_type, "instanceTag") and self.expanded_instances): + message_types.append(object_type) + elif is_interface_type(type_def): + message_types.append(cast(GraphQLInterfaceType, type_def)) + elif is_union_type(type_def): + union_types.append(cast(GraphQLUnionType, type_def)) + + proto_schema = ProtoSchema( + package=self.package_name, + enums=[], + flatten_mode=self.flatten_naming, + ) + + if self.flatten_naming: + # In flatten mode, we need a second filtering pass to remove types that were completely flattened. + # When object fields are flattened, they become prefixed fields in the parent (e.g., parent_child_field). + # If no fields reference that object type directly (non-flattened), the type definition is no longer needed. + # However, unions and enums cannot be flattened and must remain as separate type definitions. 
+ ( + flattened_fields, + referenced_type_names, + flattened_root_types, + ) = self._build_flattened_fields(message_types) + message_types = [ + message_type + for message_type in message_types + if message_type.name in referenced_type_names and message_type.name not in flattened_root_types + ] + union_types = [union_type for union_type in union_types if union_type.name in referenced_type_names] + enum_types = [enum_type for enum_type in enum_types if enum_type.name in referenced_type_names] + + proto_schema.enums = self._build_enums(enum_types) + proto_schema.unions = self._build_unions(union_types) + proto_schema.messages = self._build_messages(message_types) + + if self.flatten_naming: + root_message_name = self._get_query_operation_name() + root_message_source = f"query: {root_message_name}" + root_message = ProtoMessage( + name=root_message_name, + fields=flattened_fields, + source=root_message_source, + ) + proto_schema.messages.append(root_message) + + template_name = "proto_standard.j2" + template = self.env.get_template(template_name) + + template_vars = self._build_template_vars(proto_schema) + + result = template.render(template_vars) + + log.info("Successfully transformed GraphQL schema to Protobuf") + return result + + def _has_validation_rules(self, proto_schema: ProtoSchema) -> bool: + """Check if any field in the schema has validation rules.""" + + def check_message(message: ProtoMessage) -> bool: + if any(field.validation_rules for field in message.fields): + return True + return any(check_message(nested) for nested in message.nested_messages) + + return any(check_message(message) for message in proto_schema.messages) + + def _has_source_option(self, proto_schema: ProtoSchema) -> bool: + """Check if any type in the schema has a source option.""" + return any(enum.source for enum in proto_schema.enums) or any( + message.source for message in proto_schema.messages + ) + + def _get_query_operation_name(self) -> str: + """Extract the operation name 
from the selection query, defaulting to appropriate fallback.""" + default_name = "Message" if self.flatten_naming else "Query" + + for definition in self.selection_query.definitions: + if not isinstance(definition, OperationDefinitionNode) or definition.operation != OperationType.QUERY: + continue + + if definition.name: + return definition.name.value + return default_name + + return default_name + + def _build_template_vars(self, proto_schema: ProtoSchema) -> dict[str, Any]: + """Build all template variables from proto schema.""" + has_source_option = self._has_source_option(proto_schema) + has_validation_rules = self._has_validation_rules(proto_schema) + + imports = [] + if has_source_option: + imports.append('import "google/protobuf/descriptor.proto";') + if has_validation_rules: + imports.append('import "buf/validate/validate.proto";') + + template_vars = proto_schema.model_dump() + template_vars["imports"] = imports + template_vars["has_source_option"] = has_source_option + + return template_vars + + def _build_enums(self, enum_types: list[GraphQLEnumType]) -> list[ProtoEnum]: + """Build Pydantic models for enum types.""" + enums = [] + for enum_type in enum_types: + enum_values = [ + ProtoEnumValue( + name=value_name, + number=index + 1, + description=enum_type.values[value_name].description, + ) + for index, value_name in enumerate(enum_type.values) + ] + enums.append( + ProtoEnum( + name=enum_type.name, + enum_values=enum_values, + description=enum_type.description, + source=enum_type.name, + ) + ) + return enums + + def _build_messages(self, message_types: list[GraphQLObjectType | GraphQLInterfaceType]) -> list[ProtoMessage]: + """Build Pydantic models for message types.""" + messages = [] + for message_type in message_types: + fields, nested_messages = self._build_message_fields(message_type) + + message_name = message_type.name + source = message_type.name + + if message_type.name == "Query": + message_name = self._get_query_operation_name() + source = 
f"query: {message_name}" + + messages.append( + ProtoMessage( + name=message_name, + fields=fields, + description=message_type.description, + source=source, + nested_messages=nested_messages, + ) + ) + return messages + + def _build_message_fields( + self, message_type: GraphQLObjectType | GraphQLInterfaceType + ) -> tuple[list[ProtoField], list[ProtoMessage]]: + """Build Pydantic models for fields in a message.""" + fields = [] + nested_messages = [] + field_number = 1 + + for field_name, field in message_type.fields.items(): + if is_instance_tag_field(field_name) and self.expanded_instances: + continue + + field_type = field.type + unwrapped_type = get_named_type(field_type) + + proto_field_name = field_name + proto_field_type = None + expanded_message_name = None + + if is_object_type(unwrapped_type): + object_type = cast(GraphQLObjectType, unwrapped_type) + expanded_instances = self._get_expanded_instances(object_type) + if expanded_instances: + proto_field_name, proto_field_type, nested_message = self._handle_expanded_instance_field( + object_type, message_type, expanded_instances + ) + nested_messages.append(nested_message) + expanded_message_name = proto_field_type + + if proto_field_type is None: + proto_field_type = self._get_field_proto_type(field.type) + proto_field_name = self._escape_field_name(field_name) + + validation_rules = None if expanded_message_name else self.process_directives(field, proto_field_type) + + fields.append( + ProtoField( + name=proto_field_name, + type=proto_field_type, + number=field_number, + description=field.description, + validation_rules=validation_rules, + ) + ) + field_number += 1 + + return fields, nested_messages + + def _build_unions(self, union_types: list[GraphQLUnionType]) -> list[ProtoUnion]: + """Build Pydantic models for union types.""" + unions = [] + for union_type in union_types: + members = [ + ProtoField( + name=member_type.name, + type=member_type.name, + number=index + 1, + ) + for index, member_type in 
enumerate(union_type.types) + ] + unions.append( + ProtoUnion( + name=union_type.name, + members=members, + description=union_type.description, + source=union_type.name, + ) + ) + return unions + + def _build_flattened_fields( + self, message_types: list[GraphQLObjectType | GraphQLInterfaceType] + ) -> tuple[list[ProtoField], set[str], set[str]]: + """Build flattened fields for flatten_naming mode. + + Returns: + tuple: (flattened_fields, referenced_types, flattened_root_types) + """ + type_cache = {type_def.name: type_def for type_def in message_types} + + if self.root_type: + root_object = type_cache.get(self.root_type) + if not root_object: + log.warning(f"Root type '{self.root_type}' not found, creating empty message") + return [], set(), set() + + fields, referenced_types, _ = self._flatten_fields(root_object, root_object.name, message_types, 1) + return fields, referenced_types, {self.root_type} + + root_level_type_names = get_root_level_types_from_query(self.graphql_schema, self.selection_query) + if not root_level_type_names: + log.warning("No root-level types found in selection query, creating empty message") + return [], set(), set() + + all_fields: list[ProtoField] = [] + all_referenced_types: set[str] = set() + flattened_root_types: set[str] = set() + field_counter = 1 + + for type_name in root_level_type_names: + root_object = type_cache.get(type_name) + if not root_object: + log.warning(f"Root-level type '{type_name}' not found in message types") + continue + + flattened_root_types.add(type_name) + fields, referenced_types, field_counter = self._flatten_fields( + root_object, type_name, message_types, field_counter, type_cache + ) + all_fields.extend(fields) + all_referenced_types.update(referenced_types) + + return all_fields, all_referenced_types, flattened_root_types + + def _create_proto_field_with_validation( + self, field: GraphQLField, field_name: str, proto_type: str, field_number: int + ) -> ProtoField: + """Create a ProtoField with 
validation rules from directives.""" + validation_rules = self.process_directives(field, proto_type) + return ProtoField( + name=field_name, + type=proto_type, + number=field_number, + description=field.description, + validation_rules=validation_rules, + ) + + def _should_flatten_field(self, unwrapped_type: GraphQLType, is_list: bool) -> bool: + """Check if a field should be recursively flattened into parent fields.""" + if is_list: + return False + if is_union_type(unwrapped_type): + return False + return is_object_type(unwrapped_type) or is_interface_type(unwrapped_type) + + def _add_type_with_dependencies(self, type_name: str, referenced_types: set[str]) -> None: + """Add a type and all its transitive dependencies to the referenced_types set.""" + dependencies = get_referenced_types(self.graphql_schema, type_name, include_instance_tag_fields=True) + for dependency in dependencies: + if isinstance(dependency, GraphQLNamedType): + referenced_types.add(dependency.name) + + def _flatten_fields( + self, + object_type: GraphQLObjectType | GraphQLInterfaceType, + prefix: str, + all_types: list[GraphQLObjectType | GraphQLInterfaceType], + field_counter: int, + type_cache: dict[str, GraphQLObjectType | GraphQLInterfaceType] | None = None, + ) -> tuple[list[ProtoField], set[str], int]: + """Recursively flatten fields with prefix.""" + if type_cache is None: + type_cache = {type_def.name: type_def for type_def in all_types} + + fields: list[ProtoField] = [] + referenced_types: set[str] = set() + + for field_name, field in object_type.fields.items(): + if is_instance_tag_field(field_name) and self.expanded_instances: + continue + + field_type = field.type + unwrapped_type = get_named_type(field_type) + + if is_object_type(unwrapped_type): + object_type_cast = cast(GraphQLObjectType, unwrapped_type) + expanded_instances = self._get_expanded_instances(object_type_cast) + if expanded_instances: + nested_type = type_cache.get(object_type_cast.name) + if nested_type: + for 
expanded_instance in expanded_instances: + expanded_prefix = f"{prefix}_{field_name}_{expanded_instance.replace('.', '_')}" + nested_fields, nested_referenced, field_counter = self._flatten_fields( + nested_type, expanded_prefix, all_types, field_counter, type_cache + ) + fields.extend(nested_fields) + referenced_types.update(nested_referenced) + continue + + inner = field_type.of_type if is_non_null_type(field_type) else field_type + is_list = is_list_type(inner) + should_flatten = self._should_flatten_field(unwrapped_type, is_list) + + flattened_name = f"{prefix}_{field_name}" + proto_type = self._get_field_proto_type(field_type) + + if should_flatten: + named_unwrapped_type = cast(GraphQLObjectType | GraphQLInterfaceType, unwrapped_type) + nested_type = type_cache.get(named_unwrapped_type.name) + + if nested_type: + nested_fields, nested_referenced, field_counter = self._flatten_fields( + nested_type, flattened_name, all_types, field_counter, type_cache + ) + fields.extend(nested_fields) + referenced_types.update(nested_referenced) + else: + raise ValueError(f"Type '{named_unwrapped_type.name}' not found in available types") + else: + if is_list and (is_object_type(unwrapped_type) or is_interface_type(unwrapped_type)): + named_type = cast(GraphQLObjectType | GraphQLInterfaceType, unwrapped_type) + self._add_type_with_dependencies(named_type.name, referenced_types) + elif is_union_type(unwrapped_type): + union_type_cast = cast(GraphQLUnionType, unwrapped_type) + self._add_type_with_dependencies(union_type_cast.name, referenced_types) + for member_type in union_type_cast.types: + self._add_type_with_dependencies(member_type.name, referenced_types) + + fields.append( + self._create_proto_field_with_validation(field, flattened_name, proto_type, field_counter) + ) + field_counter += 1 + + return fields, referenced_types, field_counter + + def _get_field_proto_type(self, field_type: GraphQLType) -> str: + """Get the Protobuf type string for a GraphQL field type.""" + 
proto_type = self._get_base_proto_type(field_type) + + if not is_non_null_type(field_type): + return f"optional {proto_type}" + return proto_type + + def _get_base_proto_type(self, field_type: GraphQLType) -> str: + """Get the base Protobuf type string without optional prefix.""" + if is_non_null_type(field_type): + return self._get_base_proto_type(cast(GraphQLNonNull[Any], field_type).of_type) + + if is_list_type(field_type): + list_type = cast(GraphQLList[Any], field_type) + item_type = self._get_base_proto_type(list_type.of_type) + return f"repeated {item_type}" + + if is_scalar_type(field_type): + scalar_type = cast(GraphQLScalarType, field_type) + return GRAPHQL_SCALAR_TO_PROTOBUF.get(scalar_type.name, "string") + + if is_enum_type(field_type): + enum_type = cast(GraphQLEnumType, field_type) + return f"{enum_type.name}.Enum" + + if is_object_type(field_type) or is_interface_type(field_type): + named_type = cast(GraphQLObjectType | GraphQLInterfaceType, field_type) + return named_type.name + + if is_union_type(field_type): + union_type = cast(GraphQLUnionType, field_type) + return union_type.name + + return "string" + + def _escape_field_name(self, name: str) -> str: + """Escape field names that conflict with Protobuf reserved keywords.""" + if name in PROTOBUF_RESERVED_KEYWORDS: + return f"_{name}_" + return name + + def process_directives(self, field: GraphQLField, proto_type: str) -> str | None: + """Process GraphQL directives and convert them to protovalidate constraints.""" + rules = [] + + if is_non_null_type(field.type): + rules.append("(buf.validate.field).required = true") + + repeated_rules = [] + + if has_given_directive(field, "noDuplicates"): + unwrapped_type = get_named_type(field.type) + if is_scalar_type(unwrapped_type) or is_enum_type(unwrapped_type): + repeated_rules.append("unique: true") + + cardinality = get_cardinality(field) + if cardinality: + if cardinality.min is not None: + repeated_rules.append(f"min_items: {cardinality.min}") + if 
cardinality.max is not None: + repeated_rules.append(f"max_items: {cardinality.max}") + + if repeated_rules: + rules.append(f"(buf.validate.field).repeated = {{{', '.join(repeated_rules)}}}") + + if has_given_directive(field, "range"): + args = get_directive_arguments(field, "range") + scalar_type = self._get_validation_type(proto_type) + if scalar_type: + range_rules = [] + if "min" in args: + range_rules.append(f"gte: {args['min']}") + if "max" in args: + range_rules.append(f"lte: {args['max']}") + if range_rules: + rules.append(f"(buf.validate.field).{scalar_type} = {{{', '.join(range_rules)}}}") + + if rules: + return f"[{', '.join(rules)}]" + return None + + def _get_validation_type(self, proto_type: str) -> str | None: + """Get the protovalidate scalar type from protobuf type.""" + validation_type = proto_type.replace("repeated ", "").replace("optional ", "") + return validation_type if validation_type in PROTOBUF_DATA_TYPES else None + + def _handle_expanded_instance_field( + self, + object_type: GraphQLObjectType, + message_type: GraphQLObjectType | GraphQLInterfaceType, + expanded_instances: list[str], + ) -> tuple[str, str, ProtoMessage]: + """Handle expanded instance fields, returning field name, type, and nested message.""" + prefixed_message_name = f"{message_type.name}_{object_type.name}" + nested_message = self._build_nested_message_structure( + prefixed_message_name, expanded_instances, object_type.name + ) + + field_name_to_use = object_type.name + if self.naming_config: + target_case = get_target_case_for_element("field", "object", self.naming_config) + if target_case: + field_name_to_use = convert_name(object_type.name, target_case) + + return (self._escape_field_name(field_name_to_use), nested_message.name, nested_message) + + def _build_nested_message_structure( + self, + message_name: str, + instance_paths: list[str], + target_type: str, + ) -> ProtoMessage: + """Create nested message structure for expanded instance tags.""" + message = 
ProtoMessage(name=message_name, fields=[], nested_messages=[], source=None) + child_paths_by_level: dict[str, list[str]] = {} + field_counter = 1 + + for instance_path in instance_paths: + instance_path_parts = instance_path.split(".") + if len(instance_path_parts) > 1: + root_level_name = instance_path_parts[0] + remaining_path = ".".join(instance_path_parts[1:]) + child_paths_by_level.setdefault(root_level_name, []).append(remaining_path) + else: + message.fields.append(ProtoField(name=instance_path_parts[0], type=target_type, number=field_counter)) + field_counter += 1 + + for root_level_name, child_paths in child_paths_by_level.items(): + child_message_name = f"{message_name}_{root_level_name}" + child_message = self._build_nested_message_structure(child_message_name, child_paths, target_type) + message.nested_messages.append(child_message) + message.fields.append(ProtoField(name=root_level_name, type=child_message.name, number=field_counter)) + field_counter += 1 + + return message + + def _get_expanded_instances(self, object_type: GraphQLObjectType) -> list[str] | None: + """Get expanded instances if the type has a valid instance tag.""" + if not self.expanded_instances: + return None + + instance_tag_object = get_instance_tag_object(object_type, self.graphql_schema) + if not instance_tag_object: + return None + + return expand_instance_tag(instance_tag_object, self.naming_config) diff --git a/src/s2dm/exporters/shacl.py b/src/s2dm/exporters/shacl.py index 5a16356f..ba6905fa 100644 --- a/src/s2dm/exporters/shacl.py +++ b/src/s2dm/exporters/shacl.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from pathlib import Path from typing import Any, cast from graphql import ( @@ -18,8 +17,12 @@ from s2dm.exporters.utils.extraction import get_all_object_types from s2dm.exporters.utils.field import Cardinality, FieldCase, get_cardinality, get_field_case_extended, print_field_sdl from s2dm.exporters.utils.graphql_type import is_introspection_or_root_type -from 
s2dm.exporters.utils.instance_tag import expand_instance_tag, get_instance_tag_object, has_valid_instance_tag_field -from s2dm.exporters.utils.schema import load_schema_with_naming +from s2dm.exporters.utils.instance_tag import ( + expand_instance_tag, + get_instance_tag_object, + has_valid_instance_tag_field, + is_instance_tag_field, +) SUPPORTED_FIELD_CASES = { FieldCase.DEFAULT, @@ -53,7 +56,7 @@ def add_comment_to_property_node(field: GraphQLField, property_node: BNode, grap def translate_to_shacl( - schema_paths: list[Path], + schema: GraphQLSchema, shapes_namespace: str, shapes_namespace_prefix: str, model_namespace: str, @@ -67,7 +70,6 @@ def translate_to_shacl( Namespace(model_namespace), Namespace(model_namespace_prefix), ) - schema = load_schema_with_naming(schema_paths, naming_config) graph = Graph() graph.bind(namespaces.shapes_prefix, namespaces.shapes) graph.bind(namespaces.model_prefix, namespaces.model) @@ -199,7 +201,7 @@ def process_field( Skipping field '{field_name}'.""" ) return None - elif field_name == "instanceTag": + elif is_instance_tag_field(field_name): log.debug( f"Skipping field '{field_name}'. 
def get_root_level_types_from_query(schema: GraphQLSchema, selection_query: DocumentNode | None) -> list[str]:
    """Extract root-level type names from the selection query.

    Only query operations are inspected. For each plain field selected at
    the top level that exists on the schema's query type, the field's named
    type is collected if it is an object or interface type.

    Each type name is reported once, in first-seen order, even when several
    root fields resolve to the same type.

    Args:
        schema: The GraphQL schema
        selection_query: The selection query document

    Returns:
        List of unique type names that are selected at the root level of the
        query; empty if there is no query document or no query type
    """
    query_type = schema.query_type
    if not selection_query or not query_type:
        return []

    # A dict is used as an insertion-ordered set.
    root_type_names: dict[str, None] = {}

    for definition in selection_query.definitions:
        if not isinstance(definition, OperationDefinitionNode) or definition.operation != OperationType.QUERY:
            continue

        for selection in definition.selection_set.selections:
            if not isinstance(selection, FieldNode):
                continue

            field = query_type.fields.get(selection.name.value)
            if field is None:
                # Field is not defined on the query type.
                continue

            field_type = get_named_type(field.type)
            if is_object_type(field_type) or is_interface_type(field_type):
                root_type_names.setdefault(field_type.name, None)

    return list(root_type_names)
def validate_schema(schema: GraphQLSchema, document: DocumentNode) -> GraphQLSchema | None:
    """Validate a document against a schema and log every reported error.

    Args:
        schema: The GraphQL schema to validate against
        document: The document to validate

    Returns:
        GraphQLSchema | None: The same schema object when no errors are
        reported, otherwise None
    """
    log.debug("Validating schema against the provided document")

    validation_errors = graphql_validate(schema, document)
    if not validation_errors:
        log.debug("Schema validation succeeded")
        return schema

    log.error("Schema validation failed:")
    for validation_error in validation_errors:
        log.error(f" - {validation_error}")
    return None
@@ -385,6 +406,9 @@ def prune_schema_using_query_selection(schema: GraphQLSchema, document: Document if not schema.query_type: raise ValueError("Schema has no query type defined") + if validate_schema(schema, document) is None: + raise ValueError("Schema validation failed") + fields_to_keep: dict[str, set[str]] = {} types_to_keep: set[str] = set() @@ -439,6 +463,8 @@ def collect_selections(type_name: str, selection_set: SelectionSetNode) -> None: if not query_operations: raise ValueError("No query operation found in selection document") + log.debug("Composing filtered schema based on query selections") + query_operation = query_operations[0] if hasattr(query_operation, "selection_set"): collect_selections(schema.query_type.name, query_operation.selection_set) @@ -479,6 +505,6 @@ def collect_selections(type_name: str, selection_set: SelectionSetNode) -> None: schema.directives = tuple(directive for directive in schema.directives if directive.name in directives_used) - log.info(f"Composed filtered schema with {len(fields_to_keep)} object types") + log.debug(f"Composed filtered schema with {len(fields_to_keep)} object types") return schema diff --git a/src/s2dm/exporters/vspec.py b/src/s2dm/exporters/vspec.py index 6f0641b1..090d49aa 100644 --- a/src/s2dm/exporters/vspec.py +++ b/src/s2dm/exporters/vspec.py @@ -22,6 +22,7 @@ get_all_expanded_instance_tags, get_instance_tag_dict, get_instance_tag_object, + is_instance_tag_field, ) from s2dm.exporters.utils.naming import apply_naming_to_instance_values from s2dm.exporters.utils.schema import load_schema_with_naming @@ -191,9 +192,8 @@ def represent_list(self, data: Iterable[Any]) -> yaml.SequenceNode: CustomDumper.add_representer(list, CustomDumper.represent_list) -def translate_to_vspec(schema_paths: list[Path], naming_config: dict[str, Any] | None = None) -> str: +def translate_to_vspec(schema: GraphQLSchema, naming_config: dict[str, Any] | None = None) -> str: """Translate a GraphQL schema to YAML.""" - schema = 
load_schema_with_naming(schema_paths, naming_config) all_object_types = get_all_object_types(schema) log.debug(f"Object types: {all_object_types}") @@ -343,7 +343,7 @@ def process_field( field_dict["type"] = vss_type return {concat_field_name: field_dict} - elif isinstance(output_type, GraphQLObjectType) and field_name != "instanceTag": + elif isinstance(output_type, GraphQLObjectType) and not is_instance_tag_field(field_name): # Collect nested structures # nested_types.append(f"{object_type.name}.{output_type}({field_name})") nested_types.append((object_type.name, output_type.name)) @@ -411,7 +411,9 @@ def main( schemas: list[Path], output: Path, ) -> None: - result = translate_to_vspec(schemas) + # TODO: deprecate + graphql_schema = load_schema_with_naming(schemas, None) + result = translate_to_vspec(graphql_schema) log.info(f"Result:\n{result}") with open(output, "w", encoding="utf-8") as output_file: log.info(f"Writing data to '{output}'") diff --git a/tests/conftest.py b/tests/conftest.py index c5ac39f5..bbe8901a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,6 +39,7 @@ class TestSchemaData: VALID_QUERY = TESTS_DATA_DIR / "valid_query.graphql" INVALID_QUERY = TESTS_DATA_DIR / "invalid_query.graphql" + SCHEMA1_QUERY = TESTS_DATA_DIR / "schema1_query.graphql" # Version bump test schemas BASE_SCHEMA = TESTS_DATA_DIR / "base.graphql" diff --git a/tests/data/schema1_query.graphql b/tests/data/schema1_query.graphql new file mode 100644 index 00000000..59d2fb1a --- /dev/null +++ b/tests/data/schema1_query.graphql @@ -0,0 +1,15 @@ +query Selection { + vehicle { + averageSpeed + lowVoltageSystemState + adas { + abs { + isEngaged + } + obstacleDetection_s { + isEnabled + warningType + } + } + } +} diff --git a/tests/test_e2e_cli.py b/tests/test_e2e_cli.py index 5cbff076..e29b8e57 100644 --- a/tests/test_e2e_cli.py +++ b/tests/test_e2e_cli.py @@ -77,6 +77,95 @@ def test_export_vspec(runner: CliRunner, tmp_outputs: Path) -> None: assert 
"Vehicle_ADAS_ObstacleDetection:" in content +def test_export_jsonschema(runner: CliRunner, tmp_outputs: Path) -> None: + out = tmp_outputs / "jsonschema.yaml" + result = runner.invoke( + cli, ["export", "jsonschema", "-s", str(TSD.SAMPLE1_1), "-s", str(TSD.SAMPLE1_2), "-o", str(out)] + ) + assert result.exit_code == 0, result.output + assert out.exists() + with open(out, encoding="utf-8") as f: + content = f.read() + + assert '"Vehicle"' in content + assert '"Vehicle_ADAS_ObstacleDetection"' in content + + +def test_export_protobuf(runner: CliRunner, tmp_outputs: Path) -> None: + out = tmp_outputs / "schema.proto" + result = runner.invoke( + cli, + [ + "export", + "protobuf", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-q", + str(TSD.SCHEMA1_QUERY), + "-o", + str(out), + "-r", + "Vehicle", + ], + ) + assert result.exit_code == 0, result.output + assert out.exists() + with open(out, encoding="utf-8") as f: + content = f.read() + + assert "package" not in content + + assert "message Vehicle" in content + assert "message Vehicle_ADAS" in content + assert "message Vehicle_ADAS_ObstacleDetection" in content + + assert "message Vehicle_ADAS_ObstacleDetection_WarningType_Enum" in content + assert "enum Enum" in content + + assert "optional float averageSpeed = 1;" in content + assert "optional bool isEngaged = 1;" in content + + +def test_export_protobuf_flattened_naming(runner: CliRunner, tmp_outputs: Path) -> None: + out = tmp_outputs / "schema.proto" + result = runner.invoke( + cli, + [ + "export", + "protobuf", + "-s", + str(TSD.SAMPLE1_1), + "-s", + str(TSD.SAMPLE1_2), + "-q", + str(TSD.SCHEMA1_QUERY), + "-o", + str(out), + "-r", + "Vehicle", + "-f", + "-p", + "package.name", + ], + ) + assert result.exit_code == 0, result.output + assert out.exists() + with open(out, encoding="utf-8") as f: + content = f.read() + + assert "package package.name;" in content + + assert "message Selection" in content + + assert "message 
Vehicle_ADAS_ObstacleDetection_WarningType_Enum" in content + assert "enum Enum" in content + + assert "optional float Vehicle_averageSpeed = 1;" in content + assert "optional bool Vehicle_adas_abs_isEngaged = 3;" in content + + def test_generate_skos_skeleton(runner: CliRunner, tmp_outputs: Path) -> None: out = tmp_outputs / "skos_skeleton.ttl" result = runner.invoke( diff --git a/tests/test_expanded_instances/test_expanded_instances.py b/tests/test_expanded_instances/test_expanded_instances.py index bfe1974f..eed2067a 100644 --- a/tests/test_expanded_instances/test_expanded_instances.py +++ b/tests/test_expanded_instances/test_expanded_instances.py @@ -5,6 +5,7 @@ import pytest from s2dm.exporters.jsonschema import translate_to_jsonschema +from s2dm.exporters.utils.schema import load_schema_with_naming class TestExpandedInstances: @@ -17,7 +18,8 @@ def test_schema_path(self) -> list[Path]: def test_default_behavior_creates_arrays(self, test_schema_path: list[Path]) -> None: """Test that the default behavior creates arrays for instance tagged objects.""" - result = translate_to_jsonschema(test_schema_path, root_type="Cabin") + graphql_schema = load_schema_with_naming(test_schema_path, None) + result = translate_to_jsonschema(graphql_schema, root_type="Cabin") schema = json.loads(result) # Check that doors is an array @@ -32,7 +34,8 @@ def test_default_behavior_creates_arrays(self, test_schema_path: list[Path]) -> def test_expanded_instances_creates_nested_objects(self, test_schema_path: list[Path]) -> None: """Test that expanded_instances=True creates nested object structures.""" - result = translate_to_jsonschema(test_schema_path, root_type="Cabin", expanded_instances=True) + graphql_schema = load_schema_with_naming(test_schema_path, None) + result = translate_to_jsonschema(graphql_schema, root_type="Cabin", expanded_instances=True) schema = json.loads(result) # Check that Doors becomes Door (singular) and is a nested object structure @@ -56,7 +59,8 @@ def 
test_expanded_instances_creates_nested_objects(self, test_schema_path: list[ def test_expanded_instances_for_seats(self, test_schema_path: list[Path]) -> None: """Test expanded instances for seats with 3-level nesting.""" - result = translate_to_jsonschema(test_schema_path, root_type="Cabin", expanded_instances=True) + graphql_schema = load_schema_with_naming(test_schema_path, None) + result = translate_to_jsonschema(graphql_schema, root_type="Cabin", expanded_instances=True) schema = json.loads(result) # Check that Seats becomes Seat (singular) and is a nested object structure @@ -120,7 +124,8 @@ def test_non_instance_tagged_objects_remain_arrays(self, test_schema_path: list[ temp_path = Path(f.name) try: - result = translate_to_jsonschema([temp_path], root_type="TestObject", expanded_instances=True) + graphql_schema = load_schema_with_naming([temp_path], None) + result = translate_to_jsonschema(graphql_schema, root_type="TestObject", expanded_instances=True) schema = json.loads(result) # Instance-tagged doors should be expanded and use singular name @@ -138,7 +143,8 @@ def test_non_instance_tagged_objects_remain_arrays(self, test_schema_path: list[ def test_expanded_instances_with_strict_mode(self, test_schema_path: list[Path]) -> None: """Test that expanded instances work correctly with strict mode.""" - result = translate_to_jsonschema(test_schema_path, root_type="Cabin", strict=True, expanded_instances=True) + graphql_schema = load_schema_with_naming(test_schema_path, None) + result = translate_to_jsonschema(graphql_schema, root_type="Cabin", strict=True, expanded_instances=True) schema = json.loads(result) # Should still create expanded structure with singular naming @@ -152,8 +158,9 @@ def test_expanded_instances_with_strict_mode(self, test_schema_path: list[Path]) def test_singular_naming_for_expanded_instances(self, test_schema_path: list[Path]) -> None: """Test that expanded instances use singular type names instead of field names.""" - result_normal = 
translate_to_jsonschema(test_schema_path, root_type="Cabin", expanded_instances=False) - result_expanded = translate_to_jsonschema(test_schema_path, root_type="Cabin", expanded_instances=True) + graphql_schema = load_schema_with_naming(test_schema_path, None) + result_normal = translate_to_jsonschema(graphql_schema, root_type="Cabin", expanded_instances=False) + result_expanded = translate_to_jsonschema(graphql_schema, root_type="Cabin", expanded_instances=True) schema_normal = json.loads(result_normal) schema_expanded = json.loads(result_expanded) @@ -175,7 +182,8 @@ def test_nested_instances_use_refs_not_inline_expansion(self) -> None: # Create a nested schema path nested_schema_path = Path(__file__).parent / "test_nested_schema.graphql" - result = translate_to_jsonschema([nested_schema_path], root_type="Chassis", expanded_instances=True) + graphql_schema = load_schema_with_naming([nested_schema_path], None) + result = translate_to_jsonschema(graphql_schema, root_type="Chassis", expanded_instances=True) schema = json.loads(result) # Check that Chassis -> Axle uses proper expansion with $ref diff --git a/tests/test_expanded_instances/test_schema.graphql b/tests/test_expanded_instances/test_schema.graphql index 9b8c392b..a322bc0c 100644 --- a/tests/test_expanded_instances/test_schema.graphql +++ b/tests/test_expanded_instances/test_schema.graphql @@ -15,7 +15,7 @@ type DoorPosition @instanceTag { type Door { isLocked: Boolean - position: Int + position: Int @range(min: 0, max: 100) instanceTag: DoorPosition } @@ -23,6 +23,7 @@ type Vehicle { doors: [Door] @noDuplicates model: String year: Int + features: [String] @noDuplicates @cardinality(min: 1, max: 10) } enum SeatRowEnum { @@ -44,12 +45,12 @@ type SeatPosition @instanceTag { type Seat { isOccupied: Boolean - height: Int + height: Int @range(min: 0, max: 100) instanceTag: SeatPosition } type Cabin { seats: [Seat] @noDuplicates doors: [Door] @noDuplicates - temperature: Float + temperature: Float @range(min: 
-100, max: 100) } diff --git a/tests/test_protobuf.py b/tests/test_protobuf.py new file mode 100644 index 00000000..84505e8d --- /dev/null +++ b/tests/test_protobuf.py @@ -0,0 +1,1342 @@ +import re +from pathlib import Path +from typing import cast + +import pytest +from graphql import GraphQLField, GraphQLObjectType, GraphQLSchema, build_schema, parse + +from s2dm.exporters.protobuf import translate_to_protobuf +from s2dm.exporters.utils.schema import load_schema_with_naming +from s2dm.exporters.utils.schema_loader import prune_schema_using_query_selection + + +class TestProtobufExporter: + """Test suite for the Protobuf exporter.""" + + @pytest.fixture + def test_schema_path(self) -> list[Path]: + """Fixture providing path to test schema.""" + return [Path("tests/test_expanded_instances/test_schema.graphql")] + + def test_basic_scalar_types(self) -> None: + """Test that basic scalar types are correctly mapped to Protobuf types.""" + schema_str = """ + type ScalarType { + stringField: String + intField: Int + floatField: Float + boolField: Boolean + idField: ID + } + + type Query { + scalarType: ScalarType + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { scalarType { stringField intField floatField boolField idField } }") + result = translate_to_protobuf(schema, root_type="ScalarType", selection_query=selection_query) + + assert 'syntax = "proto3";' in result + assert re.search( + r"message ScalarType \{.*?" + r'option \(source\) = "ScalarType".*?;.*?' + r"optional string stringField = 1.*?;.*?" + r"optional int32 intField = 2.*?;.*?" + r"optional float floatField = 3.*?;.*?" + r"optional bool boolField = 4.*?;.*?" + r"optional string idField = 5.*?;.*?" 
+ r"\}", + result, + re.DOTALL, + ), "ScalarType message with source option and optional fields in order" + + assert "message Message {" not in result + + def test_custom_scalars(self) -> None: + """Test that custom scalar types are mapped correctly.""" + schema_str = """ + scalar Int8 + scalar UInt8 + scalar Int16 + scalar UInt16 + scalar UInt32 + scalar Int64 + scalar UInt64 + + type CustomScalarType { + int8Field: Int8 + uint8Field: UInt8 + int16Field: Int16 + uint16Field: UInt16 + uint32Field: UInt32 + int64Field: Int64 + uint64Field: UInt64 + } + + type Query { + customScalarType: CustomScalarType + } + """ + schema = build_schema(schema_str) + selection_query = parse( + "query Selection { " + "customScalarType { int8Field uint8Field int16Field uint16Field uint32Field int64Field uint64Field } " + "}" + ) + result = translate_to_protobuf(schema, root_type="CustomScalarType", selection_query=selection_query) + + assert re.search( + r"message CustomScalarType \{.*?" + r'option \(source\) = "CustomScalarType".*?;.*?' + r"optional int32 int8Field = 1.*?;.*?" + r"optional uint32 uint8Field = 2.*?;.*?" + r"optional int32 int16Field = 3.*?;.*?" + r"optional uint32 uint16Field = 4.*?;.*?" + r"optional uint32 uint32Field = 5.*?;.*?" + r"optional int64 int64Field = 6.*?;.*?" + r"optional uint64 uint64Field = 7.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "CustomScalarType message with source option and optional custom scalar fields in order" + + def test_enum_type_with_unspecified(self) -> None: + """Test that enums include UNSPECIFIED default value.""" + schema_str = """ + enum LockStatus { + LOCKED + UNLOCKED + PARTIAL + } + + type Door { + lockStatus: LockStatus + } + + type Query { + door: Door + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { door { lockStatus } }") + result = translate_to_protobuf(schema, root_type="Door", selection_query=selection_query) + + assert re.search( + r"message LockStatus \{.*?" 
+ r'option \(source\) = "LockStatus".*?;.*?' + r"enum Enum \{.*?" + r"LOCKSTATUS_UNSPECIFIED = 0.*?;.*?" + r"LOCKED = 1.*?;.*?" + r"UNLOCKED = 2.*?;.*?" + r"PARTIAL = 3.*?;.*?" + r"\}.*?" + r"\}", + result, + re.DOTALL, + ), "LockStatus enum wrapped in message with source option and values in order" + + assert re.search( + r"message Door \{.*?" + r'option \(source\) = "Door".*?;.*?' + r"optional LockStatus.Enum lockStatus = 1.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "Door message with source option and optional LockStatus.Enum lockStatus = 1 field" + + def test_list_to_repeated(self) -> None: + """Test that GraphQL lists are converted to repeated fields and non-null handling.""" + schema_str = """ + type Vehicle { + features: [String] + requiredFeatures: [String!]! + model: String + vin: String! + } + + type Query { + vehicle: Vehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { features requiredFeatures model vin } }") + result = translate_to_protobuf(schema, root_type="Vehicle", selection_query=selection_query) + + assert re.search( + r"message Vehicle \{.*?" + r'option \(source\) = "Vehicle".*?;.*?' + r"optional repeated string features = 1.*?;.*?" + r"repeated string requiredFeatures = 2 \[\(buf\.validate\.field\)\.required = true\].*?;.*?" + r"optional string model = 3.*?;.*?" + r"string vin = 4 \[\(buf\.validate\.field\)\.required = true\].*?;.*?" 
+ r"\}", + result, + re.DOTALL, + ), "Vehicle message with optional for nullable fields and required for non-nullable" + + def test_nested_objects_standard_mode(self) -> None: + """Test nested object types in standard mode.""" + schema_str = """ + type Speed { + average: Float + current: Float + } + + type Vehicle { + speed: Speed + model: String + } + + type Query { + vehicle: Vehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { speed { average current } model } }") + result = translate_to_protobuf(schema, root_type="Vehicle", selection_query=selection_query) + + assert re.search( + r"message Speed \{.*?" + r'option \(source\) = "Speed".*?;.*?' + r"optional float average = 1.*?;.*?" + r"optional float current = 2.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "Speed message with source option and optional fields in order" + + assert re.search( + r"message Vehicle \{.*?" + r'option \(source\) = "Vehicle".*?;.*?' + r"optional Speed speed = 1.*?;.*?" + r"optional string model = 2.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "Vehicle message with source option and optional fields in order" + + def test_query_type_renamed_with_selection_query(self) -> None: + """Test that Query type is renamed to the selection query operation name in standard mode.""" + schema_str = """ + type Speed { + average: Float + current: Float + } + + type Vehicle { + speed: Speed + model: String + } + + type Query { + vehicle: Vehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { speed { average } model } }") + result = translate_to_protobuf(schema, selection_query=selection_query) + + assert re.search( + r"message Selection \{.*?" + r'option \(source\) = "query: Selection";.*?' + r"optional Vehicle vehicle = 1.*?;.*?" 
+ r"\}", + result, + re.DOTALL, + ), "Query type renamed to Selection with query source option and fields" + + assert "message Query {" not in result, "Original Query type name should not appear" + + def test_flattened_naming_mode(self) -> None: + """Test that flatten_naming mode creates flattened field names.""" + schema_str = """ + type Average { + value: Float + timestamp: Int + } + + type Speed { + average: Average + current: Float + } + + type Vehicle { + speed: Speed + model: String + } + + type Query { + vehicle: Vehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { speed { average { value timestamp } current } model } }") + result = translate_to_protobuf( + schema, root_type="Vehicle", flatten_naming=True, selection_query=selection_query + ) + + assert re.search( + r"message Selection \{.*?" + r'option \(source\) = "query: Selection";.*?' + r"float Vehicle_speed_average_value = 1.*?;.*?" + r"int32 Vehicle_speed_average_timestamp = 2.*?;.*?" + r"float Vehicle_speed_current = 3.*?;.*?" + r"string Vehicle_model = 4.*?;.*?" 
+ r"\}", + result, + re.DOTALL, + ), "Selection with source option and flattened fields in order" + + assert "message Speed {" not in result + assert "message Average {" not in result + assert "message Vehicle {" not in result + + def test_flattened_naming_with_arrays_and_unions(self) -> None: + """Test that flatten_naming mode keeps arrays and unions as references with their definitions.""" + schema_str = """ + type Feature { + name: String + enabled: Boolean + } + + type Car { + brand: String + } + + type Truck { + capacity: Int + } + + union VehicleType = Car | Truck + + type Vehicle { + id: String + features: [Feature] + vehicleType: VehicleType + } + + type Query { + vehicle: Vehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { id features { name enabled } } }") + result = translate_to_protobuf( + schema, root_type="Vehicle", flatten_naming=True, selection_query=selection_query + ) + + assert re.search( + r"message Feature \{.*?" + r'option \(source\) = "Feature";.*?' + r"string name = 1;.*?" + r"bool enabled = 2;.*?" + r"\}", + result, + re.DOTALL, + ), "Feature message should be included as it's referenced by array" + + assert re.search( + r"message Car \{.*?" r'option \(source\) = "Car";.*?' r"string brand = 1;.*?" r"\}", + result, + re.DOTALL, + ), "Car message should be included as it's part of union" + + assert re.search( + r"message Truck \{.*?" r'option \(source\) = "Truck";.*?' r"int32 capacity = 1;.*?" r"\}", + result, + re.DOTALL, + ), "Truck message should be included as it's part of union" + + assert re.search( + r"message VehicleType \{.*?" + r'option \(source\) = "VehicleType";.*?' + r"oneof VehicleType \{.*?" + r"Car Car = 1;.*?" + r"Truck Truck = 2;.*?" + r"\}.*?" + r"\}", + result, + re.DOTALL, + ), "VehicleType union should be included" + + assert re.search( + r"message Selection \{.*?" + r'option \(source\) = "query: Selection";.*?' + r"string Vehicle_id = 1;.*?" 
+ r"repeated Feature Vehicle_features = 2;.*?" + r"\}", + result, + re.DOTALL, + ), "Selection with source option, flattened scalar and array reference" + + assert "message Vehicle {" not in result + + def test_package_name(self) -> None: + """Test that package name is included when specified.""" + schema_str = """ + type Vehicle { + model: String + } + + type Query { + vehicle: Vehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { model } }") + result = translate_to_protobuf( + schema, root_type="Vehicle", package_name="package.name", selection_query=selection_query + ) + + assert "package package.name;" in result + + def test_descriptions_as_comments(self) -> None: + """Test that type descriptions are converted to comments.""" + schema_str = ''' + """Represents a motor vehicle""" + type Vehicle { + """Vehicle identification number""" + vin: String + } + + type Query { + vehicle: Vehicle + } + ''' + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { vin } }") + result = translate_to_protobuf(schema, root_type="Vehicle", selection_query=selection_query) + + assert re.search( + r"// Represents a motor vehicle\s*\n\s*" + r"message Vehicle \{.*?" + r'option \(source\) = "Vehicle".*?;.*?' + r"string vin = 1; // Vehicle identification number.*?" 
+ r"\}", + result, + re.DOTALL, + ), "Vehicle message with description comment, source option, and field with inline comment" + + def test_union_type_to_oneof(self) -> None: + """Test that union types are converted to oneof.""" + schema_str = """ + type Car { + brand: String + } + + type Truck { + capacity: Int + } + + union Vehicle = Car | Truck + + type TestType { + vehicle: Vehicle + } + + type Query { + testType: TestType + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { testType { vehicle } }") + result = translate_to_protobuf(schema, root_type="TestType", selection_query=selection_query) + + assert re.search( + r"message Car \{.*?" r'option \(source\) = "Car".*?;.*?' r"string brand = 1.*?;.*?" r"\}", result, re.DOTALL + ), "Car message with source option and field" + + assert re.search( + r"message Truck \{.*?" r'option \(source\) = "Truck".*?;.*?' r"int32 capacity = 1.*?;.*?" r"\}", + result, + re.DOTALL, + ), "Truck message with source option and field" + + assert re.search( + r"message Vehicle \{.*?" + r'option \(source\) = "Vehicle".*?;.*?' + r"oneof Vehicle \{.*?" + r"Car Car = 1.*?;.*?" + r"Truck Truck = 2.*?;.*?" + r"\}.*?" + r"\}", + result, + re.DOTALL, + ), "Vehicle message with source option and oneof containing Car = 1 and Truck = 2" + + assert re.search( + r"message TestType \{.*?" r'option \(source\) = "TestType".*?;.*?' r"Vehicle vehicle = 1.*?;.*?" r"\}", + result, + re.DOTALL, + ), "TestType message with source option and Vehicle field" + + def test_interface_type(self) -> None: + """Test that interface types are converted to messages.""" + schema_str = """ + interface Vehicle { + vin: String! + } + + type ElectricVehicle implements Vehicle { + vin: String! 
+ batteryCapacity: Int + } + + type Query { + electricVehicle: ElectricVehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { electricVehicle { vin batteryCapacity } }") + result = translate_to_protobuf(schema, root_type="ElectricVehicle", selection_query=selection_query) + + assert re.search( + r"message Vehicle \{.*?" r'option \(source\) = "Vehicle".*?;.*?' r"string vin = 1.*?;.*?" r"\}", + result, + re.DOTALL, + ), "Vehicle interface message with source option and string vin = 1" + + assert re.search( + r"message ElectricVehicle \{.*?" + r'option \(source\) = "ElectricVehicle".*?;.*?' + r"string vin = 1.*?;.*?" + r"int32 batteryCapacity = 2.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "ElectricVehicle message with source option and fields: string vin = 1, int32 batteryCapacity = 2" + + def test_include_instance_tag_types_without_expansion(self, test_schema_path: list[Path]) -> None: + """Test that types with @instanceTag directive are included when expansion is disabled.""" + graphql_schema = load_schema_with_naming(test_schema_path, None) + selection_query = parse("query Selection { cabin { seats { isOccupied } doors { isLocked } temperature } }") + result = translate_to_protobuf( + graphql_schema, root_type="Cabin", expanded_instances=False, selection_query=selection_query + ) + + assert re.search( + r"message DoorPosition \{.*?" + r'option \(source\) = "DoorPosition".*?;.*?' + r"RowEnum.Enum row = 1.*?;.*?" + r"SideEnum.Enum side = 2.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "DoorPosition message with source option and fields in order" + + assert re.search( + r"message SeatPosition \{.*?" + r'option \(source\) = "SeatPosition".*?;.*?' + r"SeatRowEnum.Enum row = 1.*?;.*?" + r"SeatPositionEnum.Enum position = 2.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "SeatPosition message with source option and fields in order" + + assert re.search( + r"message Seat \{.*?" + r'option \(source\) = "Seat".*?;.*?' 
+ r"bool isOccupied = 1.*?;.*?" + r"int32 height = 2.*?" + r"SeatPosition instanceTag = 3.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "Seat message with source option and fields in order" + + assert re.search( + r"message Door \{.*?" + r'option \(source\) = "Door".*?;.*?' + r"bool isLocked = 1.*?;.*?" + r"int32 position = 2.*?" + r"DoorPosition instanceTag = 3.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "Door message with source option and fields in order" + + assert re.search( + r"message Cabin \{.*?" + r'option \(source\) = "Cabin".*?;.*?' + r"repeated Seat seats = 1.*?;.*?" + r"repeated Door doors = 2.*?;.*?" + r"float temperature = 3.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "Cabin message with source option and fields in order" + + assert re.search( + r"message RowEnum \{.*?" + r'option \(source\) = "RowEnum";.*?' + r"enum Enum \{.*?" + r"ROWENUM_UNSPECIFIED = 0;.*?" + r"ROW1 = 1;.*?" + r"ROW2 = 2;.*?" + r"\}.*?" + r"\}", + result, + re.DOTALL, + ), "RowEnum present with all values" + + assert re.search( + r"message SideEnum \{.*?" + r'option \(source\) = "SideEnum";.*?' + r"enum Enum \{.*?" + r"SIDEENUM_UNSPECIFIED = 0;.*?" + r"DRIVERSIDE = 1;.*?" + r"PASSENGERSIDE = 2;.*?" + r"\}.*?" + r"\}", + result, + re.DOTALL, + ), "SideEnum present with all values" + + assert re.search( + r"message SeatRowEnum \{.*?" + r'option \(source\) = "SeatRowEnum";.*?' + r"enum Enum \{.*?" + r"SEATROWENUM_UNSPECIFIED = 0;.*?" + r"ROW1 = 1;.*?" + r"ROW2 = 2;.*?" + r"ROW3 = 3;.*?" + r"\}.*?" + r"\}", + result, + re.DOTALL, + ), "SeatRowEnum present with all values" + + assert re.search( + r"message SeatPositionEnum \{.*?" + r'option \(source\) = "SeatPositionEnum";.*?' + r"enum Enum \{.*?" + r"SEATPOSITIONENUM_UNSPECIFIED = 0;.*?" + r"LEFT = 1;.*?" + r"CENTER = 2;.*?" + r"RIGHT = 3;.*?" + r"\}.*?" 
+ r"\}", + result, + re.DOTALL, + ), "SeatPositionEnum present with all values" + + def test_reserved_keyword_escaping(self) -> None: + """Test that Protobuf reserved keywords are escaped.""" + schema_str = """ + type Vehicle { + message: String + enum: Int + service: Boolean + } + + type Query { + vehicle: Vehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { message enum service } }") + result = translate_to_protobuf(schema, root_type="Vehicle", selection_query=selection_query) + + assert re.search( + r"message Vehicle \{.*?" + r'option \(source\) = "Vehicle".*?;.*?' + r"string _message_ = 1.*?;.*?" + r"int32 _enum_ = 2.*?;.*?" + r"bool _service_ = 3.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "Vehicle message with source option and escaped fields in order" + + def test_validation_rules(self) -> None: + """Test that validation directives are converted to protovalidate constraints.""" + schema_str = """ + directive @range(min: Float, max: Float) on FIELD_DEFINITION + directive @cardinality(min: Int, max: Int) on FIELD_DEFINITION + directive @noDuplicates on FIELD_DEFINITION + + type Vehicle { + speed: Int @range(min: 0, max: 250) + engineTemp: Float @range(min: -40.0, max: 150.0) + sensors: [String] @noDuplicates + features: [String] @cardinality(min: 1, max: 10) + wheels: [Int] @noDuplicates @cardinality(min: 2, max: 6) + } + + type Query { + vehicle: Vehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { speed engineTemp sensors features wheels } }") + result = translate_to_protobuf(schema, root_type="Vehicle", selection_query=selection_query) + + assert re.search( + r"message Vehicle \{.*?" + r'option \(source\) = "Vehicle";.*?' + r"optional int32 speed = 1 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 250\}\];.*?" + r"optional float engineTemp = 2 \[\(buf\.validate\.field\)\.float = \{gte: -40\.0, lte: 150\.0\}\];.*?" 
+ r"optional repeated string sensors = 3 \[\(buf\.validate\.field\)\.repeated = \{unique: true\}\];.*?" + r"optional repeated string features = 4 " + r"\[\(buf\.validate\.field\)\.repeated = \{min_items: 1, max_items: 10\}\];.*?" + r"optional repeated int32 wheels = 5 \[\(buf\.validate\.field\)\.repeated = " + r"\{unique: true, min_items: 2, max_items: 6\}\];.*?" + r"\}", + result, + re.DOTALL, + ), "Vehicle message with optional and validation rules on all fields" + + def test_required_validation_with_other_rules(self) -> None: + """Test that required validation works together with other validation rules.""" + schema_str = """ + directive @range(min: Float, max: Float) on FIELD_DEFINITION + directive @cardinality(min: Int, max: Int) on FIELD_DEFINITION + directive @noDuplicates on FIELD_DEFINITION + + type Vehicle { + speed: Int! @range(min: 0, max: 300) + tags: [String!]! @noDuplicates @cardinality(min: 1, max: 10) + vin: String! + } + + type Query { + vehicle: Vehicle + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { vehicle { speed tags vin } }") + result = translate_to_protobuf(schema, root_type="Vehicle", selection_query=selection_query) + + assert re.search( + r"message Vehicle \{.*?" + r'option \(source\) = "Vehicle";.*?' + r"int32 speed = 1 \[\(buf\.validate\.field\)\.required = true, " + r"\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 300\}\];.*?" + r"repeated string tags = 2 \[\(buf\.validate\.field\)\.required = true, " + r"\(buf\.validate\.field\)\.repeated = \{unique: true, min_items: 1, max_items: 10\}\];.*?" + r"string vin = 3 \[\(buf\.validate\.field\)\.required = true\];.*?" 
+ r"\}", + result, + re.DOTALL, + ), "Vehicle message with required combined with other validation rules" + + def test_flatten_naming_without_expansion(self, test_schema_path: list[Path]) -> None: + """Test that flatten mode WITHOUT -e flag keeps arrays as repeated.""" + graphql_schema = load_schema_with_naming(test_schema_path, None) + selection_query = parse("query Selection { cabin { seats { isOccupied } doors { isLocked } temperature } }") + result = translate_to_protobuf( + graphql_schema, + root_type="Cabin", + flatten_naming=True, + expanded_instances=False, + selection_query=selection_query, + ) + + assert re.search( + r"message Selection \{.*?" + r'option \(source\) = "query: Selection";.*?' + r"repeated Seat Cabin_seats = 1.*?" + r"repeated Door Cabin_doors = 2.*?" + r"float Cabin_temperature = 3.*?" + r"\}", + result, + re.DOTALL, + ), ( + "Selection with source option and fields: repeated Seat Cabin_seats = 1, " + "repeated Door Cabin_doors = 2, float Cabin_temperature = 3" + ) + + def test_flatten_naming_includes_referenced_types_transitively(self, test_schema_path: list[Path]) -> None: + """Test that flatten mode includes types referenced by non-flattened types and their dependencies.""" + graphql_schema = load_schema_with_naming(test_schema_path, None) + selection_query = parse("query Selection { vehicle { doors { isLocked } model year features } }") + result = translate_to_protobuf( + graphql_schema, + root_type="Vehicle", + flatten_naming=True, + expanded_instances=False, + selection_query=selection_query, + ) + + assert re.search( + r"message Selection \{.*?" + r'option \(source\) = "query: Selection";.*?' + r"optional repeated Door Vehicle_doors = 1.*?" + r"optional string Vehicle_model = 2.*?" + r"optional int32 Vehicle_year = 3.*?" + r"optional repeated string Vehicle_features = 4.*?" 
+ r"\}", + result, + re.DOTALL, + ), "Selection with source option and flattened Vehicle fields including repeated Door reference" + + assert re.search( + r"message Door \{.*?" + r'option \(source\) = "Door";.*?' + r"optional bool isLocked = 1;.*?" + r"optional int32 position = 2 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional DoorPosition instanceTag = 3;.*?" + r"\}", + result, + re.DOTALL, + ), "Door message should be included with DoorPosition reference" + + assert re.search( + r"message DoorPosition \{.*?" + r'option \(source\) = "DoorPosition";.*?' + r"RowEnum.Enum row = 1 \[\(buf\.validate\.field\)\.required = true\];.*?" + r"SideEnum.Enum side = 2 \[\(buf\.validate\.field\)\.required = true\];.*?" + r"\}", + result, + re.DOTALL, + ), "DoorPosition message should be included as it's referenced by Door" + + assert re.search( + r"message RowEnum \{.*?" + r'option \(source\) = "RowEnum";.*?' + r"enum Enum \{.*?" + r"ROWENUM_UNSPECIFIED = 0;.*?" + r"ROW1 = 1;.*?" + r"ROW2 = 2;.*?" + r"\}.*?" + r"\}", + result, + re.DOTALL, + ), "RowEnum should be included as it's used by DoorPosition" + + assert re.search( + r"message SideEnum \{.*?" + r'option \(source\) = "SideEnum";.*?' + r"enum Enum \{.*?" + r"SIDEENUM_UNSPECIFIED = 0;.*?" + r"DRIVERSIDE = 1;.*?" + r"PASSENGERSIDE = 2;.*?" + r"\}.*?" 
+ r"\}", + result, + re.DOTALL, + ), "SideEnum should be included as it's used by DoorPosition" + + assert "SeatPosition" not in result, "SeatPosition should not be included as it's not referenced by Vehicle" + assert "SeatRowEnum" not in result, "SeatRowEnum should not be included as it's not referenced by Vehicle" + assert ( + "SeatPositionEnum" not in result + ), "SeatPositionEnum should not be included as it's not referenced by Vehicle" + + def test_expanded_instances_default(self, test_schema_path: list[Path]) -> None: + """Test that instance tags are NOT expanded by default (treated as regular types).""" + graphql_schema = load_schema_with_naming(test_schema_path, None) + selection_query = parse("query Selection { cabin { seats { isOccupied } doors { isLocked } temperature } }") + result = translate_to_protobuf( + graphql_schema, root_type="Cabin", expanded_instances=False, selection_query=selection_query + ) + + assert re.search( + r"message Cabin \{.*?" + r'option \(source\) = "Cabin".*?;.*?' + r"repeated Seat seats = 1.*?;.*?" + r"repeated Door doors = 2.*?;.*?" + r"float temperature = 3.*?;.*?" + r"\}", + result, + re.DOTALL, + ), "Cabin message with source option and repeated fields" + + assert re.search( + r"message Seat \{.*?" r'option \(source\) = "Seat";', result, re.DOTALL + ), "Seat message with source option" + + assert re.search( + r"message Door \{.*?" 
r'option \(source\) = "Door";', result, re.DOTALL + ), "Door message with source option" + + assert "message Cabin_seats" not in result + assert "message Cabin_doors" not in result + + def test_expanded_instances(self, test_schema_path: list[Path]) -> None: + """Test that instance tags are expanded into nested messages when enabled.""" + graphql_schema = load_schema_with_naming(test_schema_path, None) + selection_query = parse("query Selection { cabin { seats { isOccupied } doors { isLocked } temperature } }") + result = translate_to_protobuf( + graphql_schema, root_type="Cabin", expanded_instances=True, selection_query=selection_query + ) + + assert re.search( + r"message Cabin \{.*?" + r'option \(source\) = "Cabin";.*?' + r"message Cabin_Seat \{.*?" + r"message Cabin_Seat_ROW1 \{.*?" + r"Seat LEFT = 1;.*?" + r"Seat CENTER = 2;.*?" + r"Seat RIGHT = 3;.*?" + r"\}.*?" + r"message Cabin_Seat_ROW2 \{.*?" + r"Seat LEFT = 1;.*?" + r"Seat CENTER = 2;.*?" + r"Seat RIGHT = 3;.*?" + r"\}.*?" + r"message Cabin_Seat_ROW3 \{.*?" + r"Seat LEFT = 1;.*?" + r"Seat CENTER = 2;.*?" + r"Seat RIGHT = 3;.*?" + r"\}.*?" + r"Cabin_Seat_ROW1 ROW1 = 1;.*?" + r"Cabin_Seat_ROW2 ROW2 = 2;.*?" + r"Cabin_Seat_ROW3 ROW3 = 3;.*?" + r"\}.*?" + r"message Cabin_Door \{.*?" + r"message Cabin_Door_ROW1 \{.*?" + r"Door DRIVERSIDE = 1;.*?" + r"Door PASSENGERSIDE = 2;.*?" + r"\}.*?" + r"message Cabin_Door_ROW2 \{.*?" + r"Door DRIVERSIDE = 1;.*?" + r"Door PASSENGERSIDE = 2;.*?" + r"\}.*?" + r"Cabin_Door_ROW1 ROW1 = 1;.*?" + r"Cabin_Door_ROW2 ROW2 = 2;.*?" + r"\}.*?" + r"Cabin_Seat Seat = 1;.*?" + r"Cabin_Door Door = 2;.*?" + r"optional float temperature = 3 \[\(buf\.validate\.field\)\.float = \{gte: -100, lte: 100\}\];.*?" + r"\}", + result, + re.DOTALL, + ), "Cabin message with complete nested expanded instance structure" + + assert re.search( + r"message Door \{.*?" + r'option \(source\) = "Door";.*?' + r"optional bool isLocked = 1;.*?" 
+ r"optional int32 position = 2 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"\}", + result, + re.DOTALL, + ), "Door message with fields" + + assert re.search( + r"message Seat \{.*?" + r'option \(source\) = "Seat";.*?' + r"optional bool isOccupied = 1;.*?" + r"optional int32 height = 2 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"\}", + result, + re.DOTALL, + ), "Seat message with fields" + + assert "SeatRowEnum" not in result + assert "SeatPositionEnum" not in result + assert "RowEnum" not in result + assert "SideEnum" not in result + + def test_expanded_instances_with_flatten_naming(self, test_schema_path: list[Path]) -> None: + """Test that expanded instances only expand in flatten mode when flag is set.""" + graphql_schema = load_schema_with_naming(test_schema_path, None) + selection_query = parse( + "query Selection { cabin { seats { isOccupied height } doors { isLocked position } temperature } }" + ) + result = translate_to_protobuf( + graphql_schema, + root_type="Cabin", + flatten_naming=True, + expanded_instances=True, + selection_query=selection_query, + ) + + assert re.search( + r"message Selection \{.*?" + r'option \(source\) = "query: Selection";.*?' + r"optional bool Cabin_seats_ROW1_LEFT_isOccupied = 1;.*?" + r"optional int32 Cabin_seats_ROW1_LEFT_height = 2 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW1_CENTER_isOccupied = 3;.*?" + r"optional int32 Cabin_seats_ROW1_CENTER_height = 4 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW1_RIGHT_isOccupied = 5;.*?" + r"optional int32 Cabin_seats_ROW1_RIGHT_height = 6 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW2_LEFT_isOccupied = 7;.*?" + r"optional int32 Cabin_seats_ROW2_LEFT_height = 8 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" 
+ r"optional bool Cabin_seats_ROW2_CENTER_isOccupied = 9;.*?" + r"optional int32 Cabin_seats_ROW2_CENTER_height = 10 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW2_RIGHT_isOccupied = 11;.*?" + r"optional int32 Cabin_seats_ROW2_RIGHT_height = 12 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW3_LEFT_isOccupied = 13;.*?" + r"optional int32 Cabin_seats_ROW3_LEFT_height = 14 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW3_CENTER_isOccupied = 15;.*?" + r"optional int32 Cabin_seats_ROW3_CENTER_height = 16 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW3_RIGHT_isOccupied = 17;.*?" + r"optional int32 Cabin_seats_ROW3_RIGHT_height = 18 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_doors_ROW1_DRIVERSIDE_isLocked = 19;.*?" + r"optional int32 Cabin_doors_ROW1_DRIVERSIDE_position = 20 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_doors_ROW1_PASSENGERSIDE_isLocked = 21;.*?" + r"optional int32 Cabin_doors_ROW1_PASSENGERSIDE_position = 22 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_doors_ROW2_DRIVERSIDE_isLocked = 23;.*?" + r"optional int32 Cabin_doors_ROW2_DRIVERSIDE_position = 24 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_doors_ROW2_PASSENGERSIDE_isLocked = 25;.*?" + r"optional int32 Cabin_doors_ROW2_PASSENGERSIDE_position = 26 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional float Cabin_temperature = 27 \[\(buf\.validate\.field\)\.float = \{gte: -100, lte: 100\}\];.*?" 
+ r"\}", + result, + re.DOTALL, + ), "Message with all flattened expanded instance fields" + + assert "SeatRowEnum" not in result + assert "SeatPositionEnum" not in result + assert "RowEnum" not in result + assert "SideEnum" not in result + + def test_expanded_instances_with_naming_config(self, test_schema_path: list[Path]) -> None: + """Test that naming config is applied to expanded instance field names in non-flatten mode.""" + naming_config = {"field": {"object": "MACROCASE"}} + graphql_schema = load_schema_with_naming(test_schema_path, naming_config) + selection_query = parse("query Selection { cabin { seats { isOccupied } doors { isLocked } temperature } }") + result = translate_to_protobuf( + graphql_schema, + root_type="Cabin", + expanded_instances=True, + naming_config=naming_config, + selection_query=selection_query, + ) + + assert re.search( + r"message Cabin \{.*?" + r'option \(source\) = "Cabin";.*?' + r"message Cabin_Seat \{.*?" + r"message Cabin_Seat_ROW1 \{.*?" + r"Seat LEFT = 1;.*?" + r"Seat CENTER = 2;.*?" + r"Seat RIGHT = 3;.*?" + r"\}.*?" + r"message Cabin_Seat_ROW2 \{.*?" + r"Seat LEFT = 1;.*?" + r"Seat CENTER = 2;.*?" + r"Seat RIGHT = 3;.*?" + r"\}.*?" + r"message Cabin_Seat_ROW3 \{.*?" + r"Seat LEFT = 1;.*?" + r"Seat CENTER = 2;.*?" + r"Seat RIGHT = 3;.*?" + r"\}.*?" + r"Cabin_Seat_ROW1 ROW1 = 1;.*?" + r"Cabin_Seat_ROW2 ROW2 = 2;.*?" + r"Cabin_Seat_ROW3 ROW3 = 3;.*?" + r"\}.*?" + r"message Cabin_Door \{.*?" + r"message Cabin_Door_ROW1 \{.*?" + r"Door DRIVERSIDE = 1;.*?" + r"Door PASSENGERSIDE = 2;.*?" + r"\}.*?" + r"message Cabin_Door_ROW2 \{.*?" + r"Door DRIVERSIDE = 1;.*?" + r"Door PASSENGERSIDE = 2;.*?" + r"\}.*?" + r"Cabin_Door_ROW1 ROW1 = 1;.*?" + r"Cabin_Door_ROW2 ROW2 = 2;.*?" + r"\}.*?" + r"Cabin_Seat SEAT = 1;.*?" + r"Cabin_Door DOOR = 2;.*?" + r"optional float TEMPERATURE = 3 \[\(buf\.validate\.field\)\.float = \{gte: -100, lte: 100\}\];.*?" 
+ r"\}", + result, + re.DOTALL, + ), "Cabin message with complete nested expanded instance structure and MACROCASE field names" + + assert re.search( + r"message Door \{.*?" + r'option \(source\) = "Door";.*?' + r"optional bool IS_LOCKED = 1;.*?" + r"optional int32 POSITION = 2 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"\}", + result, + re.DOTALL, + ), "Door message with MACROCASE fields" + + assert re.search( + r"message Seat \{.*?" + r'option \(source\) = "Seat";.*?' + r"optional bool IS_OCCUPIED = 1;.*?" + r"optional int32 HEIGHT = 2 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"\}", + result, + re.DOTALL, + ), "Seat message with MACROCASE fields" + + assert "SeatRowEnum" not in result + assert "SeatPositionEnum" not in result + assert "RowEnum" not in result + assert "SideEnum" not in result + + def test_flatten_mode_expanded_instances_with_naming_config(self, test_schema_path: list[Path]) -> None: + """Test that naming config is applied to type name in flattened prefix with expanded instances.""" + naming_config = {"field": {"object": "snake_case"}} + graphql_schema = load_schema_with_naming(test_schema_path, naming_config) + selection_query = parse( + "query Selection { cabin { seats { isOccupied height } doors { isLocked position } temperature } }" + ) + result = translate_to_protobuf( + graphql_schema, + root_type="Cabin", + flatten_naming=True, + expanded_instances=True, + naming_config=naming_config, + selection_query=selection_query, + ) + + assert re.search( + r"message Selection \{.*?" + r'option \(source\) = "query: Selection";.*?' + r"optional bool Cabin_seats_ROW1_LEFT_is_occupied = 1;.*?" + r"optional int32 Cabin_seats_ROW1_LEFT_height = 2 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW1_CENTER_is_occupied = 3;.*?" + r"optional int32 Cabin_seats_ROW1_CENTER_height = 4 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" 
+ r"optional bool Cabin_seats_ROW1_RIGHT_is_occupied = 5;.*?" + r"optional int32 Cabin_seats_ROW1_RIGHT_height = 6 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW2_LEFT_is_occupied = 7;.*?" + r"optional int32 Cabin_seats_ROW2_LEFT_height = 8 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW2_CENTER_is_occupied = 9;.*?" + r"optional int32 Cabin_seats_ROW2_CENTER_height = 10 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW2_RIGHT_is_occupied = 11;.*?" + r"optional int32 Cabin_seats_ROW2_RIGHT_height = 12 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW3_LEFT_is_occupied = 13;.*?" + r"optional int32 Cabin_seats_ROW3_LEFT_height = 14 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW3_CENTER_is_occupied = 15;.*?" + r"optional int32 Cabin_seats_ROW3_CENTER_height = 16 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_seats_ROW3_RIGHT_is_occupied = 17;.*?" + r"optional int32 Cabin_seats_ROW3_RIGHT_height = 18 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_doors_ROW1_DRIVERSIDE_is_locked = 19;.*?" + r"optional int32 Cabin_doors_ROW1_DRIVERSIDE_position = 20 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_doors_ROW1_PASSENGERSIDE_is_locked = 21;.*?" + r"optional int32 Cabin_doors_ROW1_PASSENGERSIDE_position = 22 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_doors_ROW2_DRIVERSIDE_is_locked = 23;.*?" + r"optional int32 Cabin_doors_ROW2_DRIVERSIDE_position = 24 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional bool Cabin_doors_ROW2_PASSENGERSIDE_is_locked = 25;.*?" 
+ r"optional int32 Cabin_doors_ROW2_PASSENGERSIDE_position = 26 " + r"\[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"optional float Cabin_temperature = 27 " + r"\[\(buf\.validate\.field\)\.float = \{gte: -100, lte: 100\}\];.*?" + r"\}", + result, + re.DOTALL, + ), "Message with all flattened expanded instance fields in snake_case" + + assert "SeatRowEnum" not in result + assert "SeatPositionEnum" not in result + assert "RowEnum" not in result + assert "SideEnum" not in result + + def test_complete_proto_file(self) -> None: + """Test that the complete Protobuf output includes syntax, imports, and source option definition.""" + schema_str = """ + directive @range(min: Float, max: Float) on FIELD_DEFINITION + + enum GearPosition { + PARK + DRIVE + REVERSE + } + + type Transmission { + currentGear: GearPosition + rpm: Int @range(min: 0, max: 8000) + } + + type Query { + transmission: Transmission + } + """ + schema = build_schema(schema_str) + selection_query = parse("query Selection { transmission { currentGear rpm } }") + result = translate_to_protobuf(schema, root_type="Transmission", selection_query=selection_query) + + assert re.search( + r'syntax = "proto3";.*?' + r'import "google/protobuf/descriptor\.proto";.*?' + r'import "buf/validate/validate\.proto";.*?' + r"extend google\.protobuf\.MessageOptions \{.*?" + r"string source = 50001;.*?" + r"\}", + result, + re.DOTALL, + ), "File header with syntax, imports, and source option definition" + + assert re.search( + r"message GearPosition \{.*?" + r'option \(source\) = "GearPosition";.*?' + r"enum Enum \{.*?" + r"GEARPOSITION_UNSPECIFIED = 0;.*?" + r"PARK = 1;.*?" + r"DRIVE = 2;.*?" + r"REVERSE = 3;.*?" + r"\}.*?" + r"\}", + result, + re.DOTALL, + ), "GearPosition enum with source option" + + assert re.search( + r"message Transmission \{.*?" + r'option \(source\) = "Transmission";.*?' + r"GearPosition\.Enum currentGear = 1;.*?" 
+ r"int32 rpm = 2 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 8000\}\];.*?" + r"\}", + result, + re.DOTALL, + ), "Transmission message with enum field and validation" + + def test_flatten_naming_multiple_root_types_unnamed_query(self, test_schema_path: list[Path]) -> None: + """Test that flatten mode without root_type flattens all root-level types with unnamed query.""" + graphql_schema = load_schema_with_naming(test_schema_path, None) + + vehicle_type = cast(GraphQLObjectType, graphql_schema.type_map["Vehicle"]) + cabin_type = cast(GraphQLObjectType, graphql_schema.type_map["Cabin"]) + door_type = cast(GraphQLObjectType, graphql_schema.type_map["Door"]) + + query_type = GraphQLObjectType( + "Query", + { + "vehicle": GraphQLField(vehicle_type), + "cabin": GraphQLField(cabin_type), + "door": GraphQLField(door_type), + }, + ) + + types = [type_def for type_def in graphql_schema.type_map.values() if type_def.name != "Query"] + graphql_schema = GraphQLSchema(query=query_type, types=types, directives=graphql_schema.directives) + + # Selection query that selects vehicle, cabin, and door at the top level + query_str = """ + query Selection { + vehicle { + doors { isLocked } + model + } + cabin { + seats { isOccupied } + temperature + } + door { + isLocked + position + instanceTag { row side } + } + } + """ + selection_query = parse(query_str) + graphql_schema = prune_schema_using_query_selection(graphql_schema, selection_query) + + result = translate_to_protobuf( + graphql_schema, flatten_naming=True, expanded_instances=False, selection_query=selection_query + ) + + assert re.search( + r"message Selection \{.*?" + r'option \(source\) = "query: Selection";.*?' + r"optional repeated Door Vehicle_doors = 1;.*?" + r"optional string Vehicle_model = 2;.*?" + r"optional repeated Seat Cabin_seats = 3;.*?" + r"optional float Cabin_temperature = 4 \[\(buf\.validate\.field\)\.float = \{gte: -100, lte: 100\}\];.*?" + r"optional bool Door_isLocked = 5;.*?" 
+ r"optional int32 Door_position = 6 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"RowEnum\.Enum Door_instanceTag_row = 7 \[\(buf\.validate\.field\)\.required = true\];.*?" + r"SideEnum\.Enum Door_instanceTag_side = 8 \[\(buf\.validate\.field\)\.required = true\];.*?" + r"\}", + result, + re.DOTALL, + ), "Message with source option and flattened fields from all root-level types (Vehicle, Cabin, Door)" + + assert "message Seat {" in result, "Seat message should be included as it's referenced by arrays" + + assert "message Vehicle {" not in result, "Vehicle should be completely flattened" + assert "message Cabin {" not in result, "Cabin should be completely flattened" + assert "message Door {" not in result, "Door should be completely flattened" + + def test_flatten_naming_multiple_root_types_named_query(self, test_schema_path: list[Path]) -> None: + """Test that flatten mode uses the selection query name for the output message.""" + graphql_schema = load_schema_with_naming(test_schema_path, None) + + vehicle_type = cast(GraphQLObjectType, graphql_schema.type_map["Vehicle"]) + cabin_type = cast(GraphQLObjectType, graphql_schema.type_map["Cabin"]) + door_type = cast(GraphQLObjectType, graphql_schema.type_map["Door"]) + + query_type = GraphQLObjectType( + "Query", + { + "vehicle": GraphQLField(vehicle_type), + "cabin": GraphQLField(cabin_type), + "door": GraphQLField(door_type), + }, + ) + + types = [type_def for type_def in graphql_schema.type_map.values() if type_def.name != "Query"] + graphql_schema = GraphQLSchema(query=query_type, types=types, directives=graphql_schema.directives) + + query_str = """ + query Selection { + vehicle { + doors { isLocked } + model + } + cabin { + seats { isOccupied } + temperature + } + door { + isLocked + position + instanceTag { row side } + } + } + """ + selection_query = parse(query_str) + graphql_schema = prune_schema_using_query_selection(graphql_schema, selection_query) + + result = translate_to_protobuf( 
+ graphql_schema, flatten_naming=True, expanded_instances=False, selection_query=selection_query + ) + + assert re.search( + r"message Selection \{.*?" + r'option \(source\) = "query: Selection";.*?' + r"optional repeated Door Vehicle_doors = 1;.*?" + r"optional string Vehicle_model = 2;.*?" + r"optional repeated Seat Cabin_seats = 3;.*?" + r"optional float Cabin_temperature = 4 \[\(buf\.validate\.field\)\.float = \{gte: -100, lte: 100\}\];.*?" + r"optional bool Door_isLocked = 5;.*?" + r"optional int32 Door_position = 6 \[\(buf\.validate\.field\)\.int32 = \{gte: 0, lte: 100\}\];.*?" + r"RowEnum\.Enum Door_instanceTag_row = 7 \[\(buf\.validate\.field\)\.required = true\];.*?" + r"SideEnum\.Enum Door_instanceTag_side = 8 \[\(buf\.validate\.field\)\.required = true\];.*?" + r"\}", + result, + re.DOTALL, + ), "Selection message with source option and flattened fields from all root-level types" + + assert "message Seat {" in result, "Seat message should be included as it's referenced by arrays" + + assert "message Vehicle {" not in result, "Vehicle should be completely flattened" + assert "message Cabin {" not in result, "Cabin should be completely flattened" + assert "message Door {" not in result, "Door should be completely flattened" diff --git a/uv.lock b/uv.lock index ff2834da..05b56b29 100644 --- a/uv.lock +++ b/uv.lock @@ -2177,6 +2177,7 @@ dependencies = [ { name = "case-converter" }, { name = "click" }, { name = "graphql-core" }, + { name = "jinja2" }, { name = "langcodes" }, { name = "pydantic" }, { name = "pyshacl" }, @@ -2209,6 +2210,7 @@ requires-dist = [ { name = "case-converter", specifier = ">=1.2.0" }, { name = "click", specifier = ">=8.1.7" }, { name = "graphql-core", specifier = ">=3.2.6" }, + { name = "jinja2", specifier = ">=3.1.0" }, { name = "langcodes", specifier = ">=3.5.0" }, { name = "pydantic", specifier = ">=2.10.6" }, { name = "pyshacl", specifier = ">=0.30.0" },