Skip to content

Commit 521666f

Browse files
committed
Initial Prototype to generate Infrahub Schema from Pydantic models
1 parent f3334a6 commit 521666f

File tree

5 files changed

+580
-1
lines changed

5 files changed

+580
-1
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import annotations
2+
3+
from asyncio import run as aiorun
4+
5+
from typing import Annotated
6+
7+
from pydantic import BaseModel, Field
8+
from infrahub_sdk import InfrahubClient
9+
from rich import print as rprint
10+
from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic
11+
12+
13+
class Tag(BaseModel):
14+
name: Annotated[str, AttrParam(unique=True), Field(description="The name of the tag")]
15+
label: str | None = Field(description="The label of the tag")
16+
description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None
17+
18+
19+
class Car(BaseModel):
20+
name: str = Field(description="The name of the car")
21+
tags: list[Tag]
22+
owner: Annotated[Person, RelParam(identifier="car__person")]
23+
secondary_owner: Person | None = None
24+
25+
26+
class Person(BaseModel):
27+
name: str
28+
cars: Annotated[list[Car] | None, RelParam(identifier="car__person")] = None
29+
30+
31+
async def main():
32+
client = InfrahubClient()
33+
schema = from_pydantic(models=[Person, Car, Tag])
34+
rprint(schema.to_schema_dict())
35+
response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True)
36+
rprint(response)
37+
38+
if __name__ == "__main__":
39+
aiorun(main())

infrahub_sdk/schema/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ..graphql import Mutation
2020
from ..queries import SCHEMA_HASH_SYNC_STATUS
2121
from .main import (
22+
AttributeKind,
2223
AttributeSchema,
2324
AttributeSchemaAPI,
2425
BranchSchema,
@@ -36,6 +37,7 @@
3637
SchemaRootAPI,
3738
TemplateSchemaAPI,
3839
)
40+
from .pydantic_utils import InfrahubAttributeParam, InfrahubRelationshipParam, from_pydantic
3941

4042
if TYPE_CHECKING:
4143
from ..client import InfrahubClient, InfrahubClientSync, SchemaType, SchemaTypeSync
@@ -45,11 +47,14 @@
4547

4648

4749
__all__ = [
50+
"AttributeKind",
4851
"AttributeSchema",
4952
"AttributeSchemaAPI",
5053
"BranchSupportType",
5154
"GenericSchema",
5255
"GenericSchemaAPI",
56+
"InfrahubAttributeParam",
57+
"InfrahubRelationshipParam",
5358
"NodeSchema",
5459
"NodeSchemaAPI",
5560
"ProfileSchemaAPI",
@@ -60,6 +65,7 @@
6065
"SchemaRoot",
6166
"SchemaRootAPI",
6267
"TemplateSchemaAPI",
68+
"from_pydantic",
6369
]
6470

6571

infrahub_sdk/schema/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ class SchemaRoot(BaseModel):
338338
node_extensions: list[NodeExtensionSchema] = Field(default_factory=list)
339339

340340
def to_schema_dict(self) -> dict[str, Any]:
341-
return self.model_dump(exclude_unset=True, exclude_defaults=True)
341+
return self.model_dump(exclude_defaults=True, mode="json")
342342

343343

344344
class SchemaRootAPI(BaseModel):
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
from __future__ import annotations
2+
3+
import typing
4+
from dataclasses import dataclass
5+
from types import UnionType
6+
from typing import Any
7+
8+
from pydantic import BaseModel
9+
from pydantic.fields import FieldInfo, PydanticUndefined
10+
11+
from infrahub_sdk.schema.main import AttributeSchema, NodeSchema, RelationshipSchema, SchemaRoot
12+
13+
from .main import AttributeKind, BranchSupportType, SchemaState
14+
15+
KIND_MAPPING: dict[type, AttributeKind] = {
16+
int: AttributeKind.NUMBER,
17+
float: AttributeKind.NUMBER,
18+
str: AttributeKind.TEXT,
19+
bool: AttributeKind.BOOLEAN,
20+
}
21+
22+
23+
@dataclass
24+
class InfrahubAttributeParam:
25+
state: SchemaState = SchemaState.PRESENT
26+
kind: AttributeKind | None = None
27+
label: str | None = None
28+
unique: bool = False
29+
branch: BranchSupportType | None = None
30+
31+
32+
@dataclass
33+
class InfrahubRelationshipParam:
34+
identifier: str | None = None
35+
branch: BranchSupportType | None = None
36+
37+
38+
@dataclass
39+
class InfrahubFieldInfo:
40+
name: str
41+
types: list[type]
42+
optional: bool
43+
default: Any
44+
45+
@property
46+
def primary_type(self) -> type:
47+
if len(self.types) == 0:
48+
raise ValueError("No types found")
49+
if self.is_list:
50+
return typing.get_args(self.types[0])[0]
51+
52+
return self.types[0]
53+
54+
@property
55+
def is_attribute(self) -> bool:
56+
return self.primary_type in KIND_MAPPING
57+
58+
@property
59+
def is_relationship(self) -> bool:
60+
return issubclass(self.primary_type, BaseModel)
61+
62+
@property
63+
def is_list(self) -> bool:
64+
return typing.get_origin(self.types[0]) is list
65+
66+
def to_dict(self) -> dict:
67+
return {
68+
"name": self.name,
69+
"primary_type": self.primary_type,
70+
"optional": self.optional,
71+
"default": self.default,
72+
"is_attribute": self.is_attribute,
73+
"is_relationship": self.is_relationship,
74+
"is_list": self.is_list,
75+
}
76+
77+
78+
def analyze_field(field_name: str, field: FieldInfo) -> InfrahubFieldInfo:
79+
clean_types = []
80+
if isinstance(field.annotation, UnionType) or (
81+
hasattr(field.annotation, "_name") and field.annotation._name == "Optional" # type: ignore[union-attr]
82+
):
83+
clean_types = [t for t in field.annotation.__args__ if t is not type(None)] # type: ignore[union-attr]
84+
else:
85+
clean_types.append(field.annotation)
86+
87+
return InfrahubFieldInfo(
88+
name=field.alias or field_name,
89+
types=clean_types,
90+
optional=not field.is_required(),
91+
default=field.default if field.default is not PydanticUndefined else None,
92+
)
93+
94+
95+
def get_attribute_kind(field: FieldInfo) -> AttributeKind:
96+
if field.annotation in KIND_MAPPING:
97+
return KIND_MAPPING[field.annotation]
98+
99+
if isinstance(field.annotation, UnionType) or (
100+
hasattr(field.annotation, "_name") and field.annotation._name == "Optional" # type: ignore[union-attr]
101+
):
102+
valid_types = [t for t in field.annotation.__args__ if t is not type(None)] # type: ignore[union-attr]
103+
if len(valid_types) == 1 and valid_types[0] in KIND_MAPPING:
104+
return KIND_MAPPING[valid_types[0]]
105+
106+
raise ValueError(f"Unknown field type: {field.annotation}")
107+
108+
109+
def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo) -> AttributeSchema: # noqa: ARG001
110+
field_param = InfrahubAttributeParam()
111+
field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubAttributeParam)]
112+
if len(field_params) == 1:
113+
field_param = field_params[0]
114+
115+
return AttributeSchema(
116+
name=field_name,
117+
label=field_param.label,
118+
description=field.description,
119+
kind=field_param.kind or get_attribute_kind(field),
120+
optional=not field.is_required(),
121+
unique=field_param.unique,
122+
branch=field_param.branch,
123+
)
124+
125+
126+
def field_to_relationship(
127+
field_name: str,
128+
field_info: InfrahubFieldInfo,
129+
field: FieldInfo,
130+
namespace: str = "Testing",
131+
) -> RelationshipSchema:
132+
field_param = InfrahubRelationshipParam()
133+
field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubRelationshipParam)]
134+
if len(field_params) == 1:
135+
field_param = field_params[0]
136+
137+
return RelationshipSchema(
138+
name=field_name,
139+
description=field.description,
140+
peer=f"{namespace}{field_info.primary_type.__name__}",
141+
identifier=field_param.identifier,
142+
cardinality="many" if field_info.is_list else "one",
143+
optional=field_info.optional,
144+
branch=field_param.branch,
145+
)
146+
147+
148+
def from_pydantic(models: list[type[BaseModel]], namespace: str = "Testing") -> SchemaRoot:
149+
schema = SchemaRoot(version="1.0")
150+
151+
for model in models:
152+
node = NodeSchema(
153+
name=model.__name__,
154+
namespace=namespace,
155+
)
156+
157+
for field_name, field in model.model_fields.items():
158+
field_info = analyze_field(field_name, field)
159+
160+
if field_info.is_attribute:
161+
node.attributes.append(field_to_attribute(field_name, field_info, field))
162+
elif field_info.is_relationship:
163+
node.relationships.append(field_to_relationship(field_name, field_info, field, namespace))
164+
165+
schema.nodes.append(node)
166+
167+
return schema

0 commit comments

Comments
 (0)