Skip to content

Commit b8221e1

Browse files
committed
Add typing support for get | filters | all methods when using Pydantic
1 parent 521666f commit b8221e1

File tree

6 files changed

+836
-97
lines changed

6 files changed

+836
-97
lines changed

docs/docs/python-sdk/examples/schema_pydantic.py renamed to docs/docs/python-sdk/examples/pydantic_car.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,47 @@
44

55
from typing import Annotated
66

7-
from pydantic import BaseModel, Field
7+
from pydantic import BaseModel, Field, ConfigDict
88
from infrahub_sdk import InfrahubClient
99
from rich import print as rprint
10-
from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic
10+
from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericModel
1111

1212

13-
class Tag(BaseModel):
13+
class Tag(NodeModel):
14+
model_config = ConfigDict(
15+
node_schema=NodeSchema(name="Tag", namespace="Test", human_readable_fields=["name__value"])
16+
)
17+
1418
name: Annotated[str, AttrParam(unique=True), Field(description="The name of the tag")]
1519
label: str | None = Field(description="The label of the tag")
1620
description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None
1721

1822

19-
class Car(BaseModel):
23+
class TestCar(NodeModel):
2024
name: str = Field(description="The name of the car")
2125
tags: list[Tag]
22-
owner: Annotated[Person, RelParam(identifier="car__person")]
23-
secondary_owner: Person | None = None
26+
owner: Annotated[TestPerson, RelParam(identifier="car__person")]
27+
secondary_owner: TestPerson | None = None
2428

2529

26-
class Person(BaseModel):
30+
class TestPerson(GenericModel):
2731
name: str
28-
cars: Annotated[list[Car] | None, RelParam(identifier="car__person")] = None
32+
33+
class TestCarOwner(NodeModel, TestPerson):
34+
cars: Annotated[list[TestCar] | None, RelParam(identifier="car__person")] = None
2935

3036

3137
async def main():
3238
client = InfrahubClient()
33-
schema = from_pydantic(models=[Person, Car, Tag])
39+
schema = from_pydantic(models=[TestPerson, TestCar, Tag, TestPerson, TestCarOwner])
3440
rprint(schema.to_schema_dict())
3541
response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True)
3642
rprint(response)
3743

44+
# Create a Tag
45+
tag = await client.create("TestTag", name="Blue", label="Blue")
46+
await tag.save(allow_upsert=True)
47+
48+
3849
if __name__ == "__main__":
3950
aiorun(main())
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from __future__ import annotations
2+
3+
from asyncio import run as aiorun
4+
5+
from infrahub_sdk.async_typer import AsyncTyper
6+
7+
from typing import Annotated
8+
9+
from pydantic import BaseModel, Field, ConfigDict
10+
from infrahub_sdk import InfrahubClient
11+
from rich import print as rprint
12+
from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericSchema, GenericModel, RelationshipKind
13+
14+
15+
app = AsyncTyper()
16+
17+
18+
class Site(NodeModel):
19+
model_config = ConfigDict(
20+
node_schema=NodeSchema(name="Site", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"])
21+
)
22+
23+
name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the site")
24+
25+
26+
class Vlan(NodeModel):
27+
model_config = ConfigDict(
28+
node_schema=NodeSchema(name="Vlan", namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"])
29+
)
30+
31+
name: str
32+
vlan_id: int
33+
description: str | None = None
34+
35+
36+
class Device(NodeModel):
37+
model_config = ConfigDict(
38+
node_schema=NodeSchema(name="Device", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"])
39+
)
40+
41+
name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the car")
42+
site: Annotated[Site, RelParam(kind=RelationshipKind.ATTRIBUTE, identifier="device__site")]
43+
interfaces: Annotated[list[Interface], RelParam(kind=RelationshipKind.COMPONENT, identifier="device__interfaces")] = Field(default_factory=list)
44+
45+
46+
class Interface(GenericModel):
47+
model_config = ConfigDict(
48+
generic_schema=GenericSchema(name="Interface", namespace="Infra", human_friendly_id=["device__name__value", "name__value"], display_labels=["name__value"])
49+
)
50+
51+
device: Annotated[Device, RelParam(kind=RelationshipKind.PARENT, identifier="device__interfaces")]
52+
name: str
53+
description: str | None = None
54+
55+
class L2Interface(Interface):
56+
model_config = ConfigDict(
57+
node_schema=NodeSchema(name="L2Interface", namespace="Infra")
58+
)
59+
60+
vlans: list[Vlan] = Field(default_factory=list)
61+
62+
class LoopbackInterface(Interface):
63+
model_config = ConfigDict(
64+
node_schema=NodeSchema(name="LoopbackInterface", namespace="Infra")
65+
)
66+
67+
68+
69+
@app.command()
70+
async def load_schema():
71+
client = InfrahubClient()
72+
schema = from_pydantic(models=[Site, Device, Interface, L2Interface, LoopbackInterface, Vlan])
73+
rprint(schema.to_schema_dict())
74+
response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True)
75+
rprint(response)
76+
77+
78+
@app.command()
79+
async def load_data():
80+
client = InfrahubClient()
81+
82+
atl = await client.create("InfraSite", name="ATL")
83+
await atl.save(allow_upsert=True)
84+
cdg = await client.create("InfraSite", name="CDG")
85+
await cdg.save(allow_upsert=True)
86+
87+
device1 = await client.create("InfraDevice", name="atl1-dev1", site=atl)
88+
await device1.save(allow_upsert=True)
89+
device2 = await client.create("InfraDevice", name="atl1-dev2", site=atl)
90+
await device2.save(allow_upsert=True)
91+
92+
lo0dev1 = await client.create("InfraLoopbackInterface", name="lo0", device=device1)
93+
await lo0dev1.save(allow_upsert=True)
94+
lo0dev2 = await client.create("InfraLoopbackInterface", name="lo0", device=device2)
95+
await lo0dev2.save(allow_upsert=True)
96+
97+
for idx in range(1, 3):
98+
interface = await client.create("InfraL2Interface", name=f"Ethernet{idx}", device=device1)
99+
await interface.save(allow_upsert=True)
100+
101+
102+
@app.command()
103+
async def query_data():
104+
client = InfrahubClient()
105+
sites = await client.all(kind=Site)
106+
107+
breakpoint()
108+
devices = await client.all(kind=Device)
109+
for device in devices:
110+
rprint(device)
111+
112+
if __name__ == "__main__":
113+
app()

0 commit comments

Comments
 (0)