Skip to content

Commit bc6bc7f

Browse files
committed
Add unit test, breakdown plugin into multiple plugins
1 parent 8ce0e10 commit bc6bc7f

File tree

12 files changed

+20017
-299
lines changed

12 files changed

+20017
-299
lines changed

infrahub_sdk/ctl/graphql.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ariadne_codegen.plugins.manager import PluginManager
1010
from ariadne_codegen.schema import (
1111
filter_operations_definitions,
12+
filter_fragments_definitions,
1213
get_graphql_schema_from_path,
1314
)
1415
from ariadne_codegen.settings import ClientSettings, CommentsStrategy
@@ -23,7 +24,11 @@
2324
app = AsyncTyper()
2425
console = Console()
2526

26-
ARIADNE_PLUGINS = ["infrahub_sdk.graphql.plugin.InfrahubPlugin"]
27+
ARIADNE_PLUGINS = [
28+
"infrahub_sdk.graphql.plugin.PydanticBaseModelPlugin",
29+
"infrahub_sdk.graphql.plugin.FutureAnnotationPlugin",
30+
"infrahub_sdk.graphql.plugin.StandardTypeHintPlugin",
31+
]
2732

2833

2934
def find_gql_files(query_path: Path) -> list[Path]:
@@ -108,33 +113,36 @@ async def generate_return_types(
108113
for gql_file in gql_files:
109114
gql_per_directory[gql_file.parent].append(gql_file)
110115

116+
111117
# Generate the Pydantic Models for the GraphQL queries
112118
for directory, gql_files in gql_per_directory.items():
113-
package_generator = get_package_generator(
114-
schema=graphql_schema,
115-
fragments=[],
116-
settings=ClientSettings(
117-
schema_path=str(schema),
118-
target_package_name=directory.name,
119-
queries_path=str(directory),
120-
include_comments=CommentsStrategy.NONE,
121-
),
122-
plugin_manager=plugin_manager,
123-
)
124-
125119
for gql_file in gql_files:
126120
try:
127121
definitions = get_graphql_query(queries_path=gql_file, schema=graphql_schema)
128122
except ValueError as e:
129123
print(f"Error generating result types for {gql_file}: {e}")
130124
continue
131125
queries = filter_operations_definitions(definitions)
126+
fragments = filter_fragments_definitions(definitions)
127+
128+
package_generator = get_package_generator(
129+
schema=graphql_schema,
130+
fragments=fragments,
131+
settings=ClientSettings(
132+
schema_path=str(schema),
133+
target_package_name=directory.name,
134+
queries_path=str(directory),
135+
include_comments=CommentsStrategy.NONE,
136+
),
137+
plugin_manager=plugin_manager,
138+
)
139+
132140

133141
for query_operation in queries:
134142
package_generator.add_operation(query_operation)
135143

136-
# package_generator._generate_result_types()
137-
generate_result_types(directory=directory, package=package_generator)
144+
package_generator._generate_fragments()
145+
generate_result_types(directory=directory, package=package_generator)
138146

139-
for file_name in package_generator._result_types_files.keys():
140-
print(f"Generated {file_name} in {directory}")
147+
for file_name in package_generator._result_types_files.keys():
148+
print(f"Generated {file_name} in {directory}")

infrahub_sdk/graphql/plugin.py

Lines changed: 41 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,48 +5,32 @@
55

66
from ariadne_codegen.plugins.base import Plugin
77

8-
from .return_type import GraphQLReturnTypeModel
9-
108
if TYPE_CHECKING:
119
from graphql import ExecutableDefinitionNode
1210

1311

14-
class InfrahubPlugin(Plugin):
12+
class FutureAnnotationPlugin(Plugin):
1513
@staticmethod
16-
def find_base_model_index(module: ast.Module) -> int:
17-
for idx, item in enumerate(module.body):
18-
if isinstance(item, ast.ImportFrom) and item.module == "base_model":
19-
return idx
20-
return -1
14+
def insert_future_annotation(module: ast.Module) -> ast.Module:
15+
# First check if the future annotation is already present
16+
for item in module.body:
17+
if isinstance(item, ast.ImportFrom) and item.module == "__future__" and item.names[0].name == "annotations":
18+
return module
2119

22-
@classmethod
23-
def replace_base_model_import(cls, module: ast.Module) -> ast.Module:
24-
base_model_index = cls.find_base_model_index(module)
25-
if base_model_index == -1:
26-
raise ValueError("BaseModel not found in module")
27-
module.body[base_model_index] = ast.ImportFrom(
28-
module="infrahub_sdk.graphql", names=[ast.alias(name=GraphQLReturnTypeModel.__name__)], level=2
29-
)
20+
module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")], level=0))
3021
return module
3122

32-
@staticmethod
33-
def replace_base_model_class(module: ast.Module) -> ast.Module:
34-
"""Replace the BaseModel inserted by Ariadne with the GraphQLReturnTypeModel class."""
35-
for item in module.body:
36-
if not isinstance(item, ast.ClassDef):
37-
continue
23+
def generate_result_types_module(
24+
self,
25+
module: ast.Module,
26+
operation_definition: ExecutableDefinitionNode, # noqa: ARG002
27+
) -> ast.Module:
28+
module = self.insert_future_annotation(module)
3829

39-
for base in item.bases:
40-
if isinstance(base, ast.Name) and base.id == "BaseModel":
41-
base.id = GraphQLReturnTypeModel.__name__
4230
return module
4331

44-
@staticmethod
45-
def insert_future_annotation(module: ast.Module) -> ast.Module:
46-
"""Insert the future annotation at the beginning of the module."""
47-
module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")], level=0))
48-
return module
4932

33+
class StandardTypeHintPlugin(Plugin):
5034
@classmethod
5135
def replace_list_in_subscript(cls, subscript: ast.Subscript) -> ast.Subscript:
5236
if isinstance(subscript.value, ast.Name) and subscript.value.id == "List":
@@ -76,10 +60,33 @@ def generate_result_types_module(
7660
module: ast.Module,
7761
operation_definition: ExecutableDefinitionNode, # noqa: ARG002
7862
) -> ast.Module:
79-
module = self.insert_future_annotation(module)
80-
module = self.replace_base_model_import(module)
81-
module = self.replace_base_model_class(module)
82-
63+
module = FutureAnnotationPlugin.insert_future_annotation(module)
8364
module = self.replace_list_annotations(module)
8465

8566
return module
67+
68+
69+
class PydanticBaseModelPlugin(Plugin):
70+
@staticmethod
71+
def find_base_model_index(module: ast.Module) -> int:
72+
for idx, item in enumerate(module.body):
73+
if isinstance(item, ast.ImportFrom) and item.module == "base_model":
74+
return idx
75+
return -1
76+
77+
@classmethod
78+
def replace_base_model_import(cls, module: ast.Module) -> ast.Module:
79+
base_model_index = cls.find_base_model_index(module)
80+
if base_model_index == -1:
81+
raise ValueError("BaseModel not found in module")
82+
module.body[base_model_index] = ast.ImportFrom(module="pydantic", names=[ast.alias(name="BaseModel")], level=0)
83+
return module
84+
85+
def generate_result_types_module(
86+
self,
87+
module: ast.Module,
88+
operation_definition: ExecutableDefinitionNode, # noqa: ARG002
89+
) -> ast.Module:
90+
module = self.replace_base_model_import(module)
91+
92+
return module

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ exclude = [
132132
"build",
133133
"dist",
134134
"examples",
135+
"tests/fixtures/unit/test_graphql_plugin",
135136
]
136137

137138

@@ -257,7 +258,6 @@ max-complexity = 17
257258
"S105", # 'PASS' is not a password but a state
258259
]
259260

260-
261261
"tests/**/*.py" = [
262262
"PLR2004", # Magic value used in comparison
263263
"S101", # Use of assert detected
@@ -277,6 +277,11 @@ max-complexity = 17
277277
"tests/unit/sdk/test_client.py" = [
278278
"W293", # Blank line contains whitespace (used within output check)
279279
]
280+
281+
"tests/fixtures/unit/test_graphql_plugin/*.py" = [
282+
"FA100", # Add `from __future__ import annotations` to simplify `typing.Optional`
283+
]
284+
280285
"tasks.py" = [
281286
"PLC0415", # `import` should be at the top-level of a file
282287
]
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
3+
from typing import Optional
4+
5+
from pydantic import Field
6+
7+
from infrahub_sdk.graphql import GraphQLReturnTypeModel
8+
9+
10+
class CreateDevice(GraphQLReturnTypeModel):
11+
infra_device_upsert: Optional[CreateDeviceInfraDeviceUpsert] = Field(alias="InfraDeviceUpsert")
12+
13+
14+
class CreateDeviceInfraDeviceUpsert(GraphQLReturnTypeModel):
15+
ok: Optional[bool]
16+
object: Optional[CreateDeviceInfraDeviceUpsertObject]
17+
18+
19+
class CreateDeviceInfraDeviceUpsertObject(GraphQLReturnTypeModel):
20+
id: str
21+
name: Optional[CreateDeviceInfraDeviceUpsertObjectName]
22+
description: Optional[CreateDeviceInfraDeviceUpsertObjectDescription]
23+
status: Optional[CreateDeviceInfraDeviceUpsertObjectStatus]
24+
25+
26+
class CreateDeviceInfraDeviceUpsertObjectName(GraphQLReturnTypeModel):
27+
value: Optional[str]
28+
29+
30+
class CreateDeviceInfraDeviceUpsertObjectDescription(GraphQLReturnTypeModel):
31+
value: Optional[str]
32+
33+
34+
class CreateDeviceInfraDeviceUpsertObjectStatus(GraphQLReturnTypeModel):
35+
value: Optional[str]
36+
37+
38+
CreateDevice.model_rebuild()
39+
CreateDeviceInfraDeviceUpsert.model_rebuild()
40+
CreateDeviceInfraDeviceUpsertObject.model_rebuild()
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
3+
from typing import Optional
4+
5+
from pydantic import Field
6+
7+
from infrahub_sdk.graphql import GraphQLReturnTypeModel
8+
9+
10+
class CreateDevice(GraphQLReturnTypeModel):
11+
infra_device_upsert: Optional[CreateDeviceInfraDeviceUpsert] = Field(alias="InfraDeviceUpsert")
12+
13+
14+
class CreateDeviceInfraDeviceUpsert(GraphQLReturnTypeModel):
15+
ok: Optional[bool]
16+
object: Optional[CreateDeviceInfraDeviceUpsertObject]
17+
18+
19+
class CreateDeviceInfraDeviceUpsertObject(GraphQLReturnTypeModel):
20+
id: str
21+
name: Optional[CreateDeviceInfraDeviceUpsertObjectName]
22+
description: Optional[CreateDeviceInfraDeviceUpsertObjectDescription]
23+
status: Optional[CreateDeviceInfraDeviceUpsertObjectStatus]
24+
25+
26+
class CreateDeviceInfraDeviceUpsertObjectName(GraphQLReturnTypeModel):
27+
value: Optional[str]
28+
29+
30+
class CreateDeviceInfraDeviceUpsertObjectDescription(GraphQLReturnTypeModel):
31+
value: Optional[str]
32+
33+
34+
class CreateDeviceInfraDeviceUpsertObjectStatus(GraphQLReturnTypeModel):
35+
value: Optional[str]
36+
37+
38+
CreateDevice.model_rebuild()
39+
CreateDeviceInfraDeviceUpsert.model_rebuild()
40+
CreateDeviceInfraDeviceUpsertObject.model_rebuild()
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Optional
2+
3+
from pydantic import Field
4+
5+
from infrahub_sdk.graphql import GraphQLReturnTypeModel
6+
7+
8+
class CreateDevice(GraphQLReturnTypeModel):
9+
infra_device_upsert: Optional["CreateDeviceInfraDeviceUpsert"] = Field(alias="InfraDeviceUpsert")
10+
11+
12+
class CreateDeviceInfraDeviceUpsert(GraphQLReturnTypeModel):
13+
ok: Optional[bool]
14+
object: Optional[dict]
15+
16+
17+
CreateDevice.model_rebuild()
18+
CreateDeviceInfraDeviceUpsert.model_rebuild()
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from __future__ import annotations
2+
3+
from typing import Optional
4+
5+
from pydantic import Field
6+
7+
from infrahub_sdk.graphql import GraphQLReturnTypeModel
8+
9+
10+
class CreateDevice(GraphQLReturnTypeModel):
11+
infra_device_upsert: Optional["CreateDeviceInfraDeviceUpsert"] = Field(alias="InfraDeviceUpsert")
12+
13+
14+
class CreateDeviceInfraDeviceUpsert(GraphQLReturnTypeModel):
15+
ok: Optional[bool]
16+
object: Optional[dict]
17+
18+
19+
CreateDevice.model_rebuild()
20+
CreateDeviceInfraDeviceUpsert.model_rebuild()

0 commit comments

Comments
 (0)