Skip to content

Commit 7d1ec47

Browse files
committed
sqlmodel: improve syntax
1 parent 56bd3a1 commit 7d1ec47

File tree

2 files changed

+95
-37
lines changed

2 files changed

+95
-37
lines changed

fquery/sqlmodel.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import fields, is_dataclass
1+
from dataclasses import _FIELD, dataclass, field, fields, is_dataclass
22
from datetime import date, datetime, time
33
from typing import (
44
ClassVar,
@@ -42,6 +42,33 @@
4242
SQL_PK = {"metadata": {"SQL": {"primary_key": True}}}
4343

4444

45+
def unique():
46+
pass
47+
48+
49+
def foreignkey(name):
50+
return field(
51+
default=None, metadata={"SQL": {"relationship": True, "back_populates": False}}
52+
)
53+
54+
55+
def one_to_many():
56+
return field(default=None, metadata={"SQL": {"relationship": True}})
57+
58+
59+
def many_to_one(back_populates=None):
60+
ret = field(
61+
default=None, metadata={"SQL": {"relationship": True, "many_to_one": True}}
62+
)
63+
if back_populates is not None:
64+
ret.metadata["SQL"][back_populates] = back_populates
65+
return ret
66+
67+
68+
def sqlmodel(cls):
69+
return model()(dataclass(kw_only=True)(cls))
70+
71+
4572
def model(table: bool = True, table_name: str = None, global_id: bool = False):
4673
"""
4774
A decorator that generates a SQLModel from a dataclass.
@@ -58,6 +85,8 @@ def sqlmodel(self) -> SQLModel:
5885
return self.__sqlmodel__(**attrs)
5986

6087
def get_field_def(cls, field) -> Union[Field, Relationship]:
88+
if field.default == unique:
89+
return Field(unique=True)
6190
sql_meta = field.metadata.get("SQL", {})
6291
has_foreign_key = bool(sql_meta.get("foreign_key", None))
6392
has_relationship = bool(sql_meta.get("relationship", None))
@@ -68,7 +97,7 @@ def get_field_def(cls, field) -> Union[Field, Relationship]:
6897
# TODO: revisit the idea of using string for unknown types
6998
sa_column=Column(
7099
SA_TYPEMAP.get(field.type, String),
71-
GLOBAL_ID_SEQ if global_id else None,
100+
GLOBAL_ID_SEQ if global_id else cls.id_seq,
72101
primary_key=(
73102
field.name == "id"
74103
or field.metadata.get("SQL", {}).get("primary_key", False)
@@ -135,15 +164,44 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
135164
):
136165
sqlmodel_cls.__annotations__[field.name] = Optional[sqlmodel_cls]
137166

167+
def default_table_name(clsname: str) -> str:
168+
return inflection.underscore(inflection.pluralize(clsname))
169+
138170
def decorator(cls):
139171
# Check if the class is a dataclass
140172
if not is_dataclass(cls):
141173
raise ValueError("The class must be a dataclass")
142174

143175
nonlocal table_name
144-
table_name = table_name or inflection.underscore(
145-
inflection.pluralize(cls.__name__)
146-
)
176+
table_name = table_name or default_table_name(cls.__name__)
177+
178+
if not global_id:
179+
cls.id_seq = Sequence(f"{table_name}_seq")
180+
181+
# Insert any foreign keys as necessary
182+
for cfield in fields(cls):
183+
sql_meta = cfield.metadata.get("SQL", {})
184+
has_relationship = bool(sql_meta.get("relationship", None))
185+
if has_relationship:
186+
many_to_one = sql_meta.get("many_to_one", False)
187+
foreign_key_name = cfield.name + "_id"
188+
key_table_name = table_name
189+
if many_to_one:
190+
type_class = cfield.type
191+
other_class = type_class.__args__[0]
192+
other_class = getattr(other_class, "__name__", None)
193+
key_table_name = default_table_name(other_class)
194+
back_populates = sql_meta.get("back_populates", None)
195+
if back_populates is False or many_to_one:
196+
new_field = field(
197+
default=None,
198+
metadata={"SQL": {"foreign_key": f"{key_table_name}.id"}},
199+
)
200+
new_field._field_type = _FIELD
201+
new_field.name = foreign_key_name
202+
new_field.type = Optional[int]
203+
cls.__dataclass_fields__[foreign_key_name] = new_field
204+
setattr(cls, new_field.name, new_field.default)
147205

148206
# Generate the SQLModel class
149207
sqlmodel_cls = type(
@@ -167,16 +225,17 @@ def decorator(cls):
167225
# For SQLModel's SQLModelMetaClass
168226
table=table,
169227
)
170-
171228
cls.__sqlmodel__ = sqlmodel_cls
172229
# Update type annotations in any class with a relationship with this class to point
173230
# to the SQLModel, not the dataclass
174-
for field in fields(cls):
175-
if not field.name in sqlmodel_cls.__sqlmodel_relationships__:
231+
for cfield in fields(cls):
232+
if not cfield.name in sqlmodel_cls.__sqlmodel_relationships__:
176233
continue
177-
rel = sqlmodel_cls.__sqlmodel_relationships__.get(field.name, None)
234+
rel = sqlmodel_cls.__sqlmodel_relationships__.get(cfield.name, None)
178235
if rel and hasattr(rel, "back_populates"):
179-
patch_back_populates_types(field, rel.back_populates, cls, sqlmodel_cls)
236+
patch_back_populates_types(
237+
cfield, rel.back_populates, cls, sqlmodel_cls
238+
)
180239
cls.sqlmodel = sqlmodel
181240
return cls
182241

tests/test_sqlmodel.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,41 @@
1-
from dataclasses import dataclass, field
1+
from dataclasses import field
22
from datetime import datetime
33
from typing import List, Optional
44

55
from sqlalchemy import create_engine
66
from sqlalchemy.orm import sessionmaker
77
from sqlmodel import SQLModel
88

9-
from fquery.sqlmodel import SQL_PK, model
9+
from fquery.sqlmodel import (
10+
SQL_PK,
11+
foreignkey,
12+
many_to_one,
13+
one_to_many,
14+
sqlmodel,
15+
unique,
16+
)
1017

1118

12-
@model(global_id=True)
13-
@dataclass(kw_only=True)
19+
@sqlmodel
1420
class User:
1521
id: int | None = None
1622
name: str
17-
email: str
23+
email: str = unique()
1824
created_at: datetime = None
1925
updated_at: datetime = None
20-
friend_id: Optional[int] = field(
21-
default=None, metadata={"SQL": {"foreign_key": "users.id"}}
22-
)
23-
friend: Optional["User"] = field(
24-
default=None, metadata={"SQL": {"relationship": True, "back_populates": False}}
25-
)
26-
reviews: List["Review"] = field(
27-
default=None, metadata={"SQL": {"relationship": True}}
28-
)
29-
30-
31-
@model(global_id=True)
32-
@dataclass(kw_only=True)
26+
27+
friend: Optional["User"] = foreignkey("users.id")
28+
reviews: List["Review"] = one_to_many()
29+
30+
31+
@sqlmodel
3332
class Review:
3433
id: int | None = None
3534
score: int
36-
user_id: Optional[int] = field(
37-
default=None, metadata={"SQL": {"foreign_key": "users.id"}}
38-
)
39-
user: Optional[User] = field(
40-
default=None, metadata={"SQL": {"relationship": True, "many_to_one": True}}
41-
)
35+
user: Optional[User] = many_to_one("users.id")
4236

4337

44-
@model(global_id=True)
45-
@dataclass
38+
@sqlmodel
4639
class Relation:
4740
src: int | None = field(**SQL_PK)
4841
type: int = field(**SQL_PK)
@@ -93,7 +86,13 @@ def test_sqlmodel():
9386
session.add(user1.sqlmodel())
9487
session.commit()
9588

96-
relation = Relation(user.id, 1, user1.id, datetime.now(), datetime.now())
89+
relation = Relation(
90+
src=user.id,
91+
type=1,
92+
dst=user1.id,
93+
created_at=datetime.now(),
94+
updated_at=datetime.now(),
95+
)
9796
session.add(relation.sqlmodel())
9897
session.commit()
9998
# Read all users from the database

0 commit comments

Comments
 (0)