|
5 | 5 |
|
6 | 6 | from ariadne_codegen.plugins.base import Plugin
|
7 | 7 |
|
8 |
| -from .return_type import GraphQLReturnTypeModel |
9 |
| - |
10 | 8 | if TYPE_CHECKING:
|
11 | 9 | from graphql import ExecutableDefinitionNode
|
12 | 10 |
|
13 | 11 |
|
14 |
| -class InfrahubPlugin(Plugin): |
| 12 | +class FutureAnnotationPlugin(Plugin): |
15 | 13 | @staticmethod
|
16 |
| - def find_base_model_index(module: ast.Module) -> int: |
17 |
| - for idx, item in enumerate(module.body): |
18 |
| - if isinstance(item, ast.ImportFrom) and item.module == "base_model": |
19 |
| - return idx |
20 |
| - return -1 |
| 14 | + def insert_future_annotation(module: ast.Module) -> ast.Module: |
| 15 | + # First check if the future annotation is already present |
| 16 | + for item in module.body: |
| 17 | + if isinstance(item, ast.ImportFrom) and item.module == "__future__" and item.names[0].name == "annotations": |
| 18 | + return module |
21 | 19 |
|
22 |
| - @classmethod |
23 |
| - def replace_base_model_import(cls, module: ast.Module) -> ast.Module: |
24 |
| - base_model_index = cls.find_base_model_index(module) |
25 |
| - if base_model_index == -1: |
26 |
| - raise ValueError("BaseModel not found in module") |
27 |
| - module.body[base_model_index] = ast.ImportFrom( |
28 |
| - module="infrahub_sdk.graphql", names=[ast.alias(name=GraphQLReturnTypeModel.__name__)], level=2 |
29 |
| - ) |
| 20 | + module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")], level=0)) |
30 | 21 | return module
|
31 | 22 |
|
32 |
| - @staticmethod |
33 |
| - def replace_base_model_class(module: ast.Module) -> ast.Module: |
34 |
| - """Replace the BaseModel inserted by Ariadne with the GraphQLReturnTypeModel class.""" |
35 |
| - for item in module.body: |
36 |
| - if not isinstance(item, ast.ClassDef): |
37 |
| - continue |
| 23 | + def generate_result_types_module( |
| 24 | + self, |
| 25 | + module: ast.Module, |
| 26 | + operation_definition: ExecutableDefinitionNode, # noqa: ARG002 |
| 27 | + ) -> ast.Module: |
| 28 | + module = self.insert_future_annotation(module) |
38 | 29 |
|
39 |
| - for base in item.bases: |
40 |
| - if isinstance(base, ast.Name) and base.id == "BaseModel": |
41 |
| - base.id = GraphQLReturnTypeModel.__name__ |
42 | 30 | return module
|
43 | 31 |
|
44 |
| - @staticmethod |
45 |
| - def insert_future_annotation(module: ast.Module) -> ast.Module: |
46 |
| - """Insert the future annotation at the beginning of the module.""" |
47 |
| - module.body.insert(0, ast.ImportFrom(module="__future__", names=[ast.alias(name="annotations")], level=0)) |
48 |
| - return module |
49 | 32 |
|
| 33 | +class StandardTypeHintPlugin(Plugin): |
50 | 34 | @classmethod
|
51 | 35 | def replace_list_in_subscript(cls, subscript: ast.Subscript) -> ast.Subscript:
|
52 | 36 | if isinstance(subscript.value, ast.Name) and subscript.value.id == "List":
|
@@ -76,10 +60,33 @@ def generate_result_types_module(
|
76 | 60 | module: ast.Module,
|
77 | 61 | operation_definition: ExecutableDefinitionNode, # noqa: ARG002
|
78 | 62 | ) -> ast.Module:
|
79 |
| - module = self.insert_future_annotation(module) |
80 |
| - module = self.replace_base_model_import(module) |
81 |
| - module = self.replace_base_model_class(module) |
82 |
| - |
| 63 | + module = FutureAnnotationPlugin.insert_future_annotation(module) |
83 | 64 | module = self.replace_list_annotations(module)
|
84 | 65 |
|
85 | 66 | return module
|
| 67 | + |
| 68 | + |
| 69 | +class PydanticBaseModelPlugin(Plugin): |
| 70 | + @staticmethod |
| 71 | + def find_base_model_index(module: ast.Module) -> int: |
| 72 | + for idx, item in enumerate(module.body): |
| 73 | + if isinstance(item, ast.ImportFrom) and item.module == "base_model": |
| 74 | + return idx |
| 75 | + return -1 |
| 76 | + |
| 77 | + @classmethod |
| 78 | + def replace_base_model_import(cls, module: ast.Module) -> ast.Module: |
| 79 | + base_model_index = cls.find_base_model_index(module) |
| 80 | + if base_model_index == -1: |
| 81 | + raise ValueError("BaseModel not found in module") |
| 82 | + module.body[base_model_index] = ast.ImportFrom(module="pydantic", names=[ast.alias(name="BaseModel")], level=0) |
| 83 | + return module |
| 84 | + |
| 85 | + def generate_result_types_module( |
| 86 | + self, |
| 87 | + module: ast.Module, |
| 88 | + operation_definition: ExecutableDefinitionNode, # noqa: ARG002 |
| 89 | + ) -> ast.Module: |
| 90 | + module = self.replace_base_model_import(module) |
| 91 | + |
| 92 | + return module |
0 commit comments