1- from dataclasses import fields , is_dataclass
1+ from dataclasses import _FIELD , dataclass , field , fields , is_dataclass
22from datetime import date , datetime , time
33from typing import (
44 ClassVar ,
4242SQL_PK = {"metadata" : {"SQL" : {"primary_key" : True }}}
4343
4444
45+ def unique ():
46+ pass
47+
48+
49+ def foreignkey (name ):
50+ return field (
51+ default = None , metadata = {"SQL" : {"relationship" : True , "back_populates" : False }}
52+ )
53+
54+
55+ def one_to_many ():
56+ return field (default = None , metadata = {"SQL" : {"relationship" : True }})
57+
58+
59+ def many_to_one (back_populates = None ):
60+ ret = field (
61+ default = None , metadata = {"SQL" : {"relationship" : True , "many_to_one" : True }}
62+ )
63+ if back_populates is not None :
64+ ret .metadata ["SQL" ][back_populates ] = back_populates
65+ return ret
66+
67+
68+ def sqlmodel (cls ):
69+ return model ()(dataclass (kw_only = True )(cls ))
70+
71+
4572def model (table : bool = True , table_name : str = None , global_id : bool = False ):
4673 """
4774 A decorator that generates a SQLModel from a dataclass.
@@ -58,6 +85,8 @@ def sqlmodel(self) -> SQLModel:
5885 return self .__sqlmodel__ (** attrs )
5986
6087 def get_field_def (cls , field ) -> Union [Field , Relationship ]:
88+ if field .default == unique :
89+ return Field (unique = True )
6190 sql_meta = field .metadata .get ("SQL" , {})
6291 has_foreign_key = bool (sql_meta .get ("foreign_key" , None ))
6392 has_relationship = bool (sql_meta .get ("relationship" , None ))
@@ -68,7 +97,7 @@ def get_field_def(cls, field) -> Union[Field, Relationship]:
6897 # TODO: revisit the idea of using string for unknown types
6998 sa_column = Column (
7099 SA_TYPEMAP .get (field .type , String ),
71- GLOBAL_ID_SEQ if global_id else None ,
100+ GLOBAL_ID_SEQ if global_id else cls . id_seq ,
72101 primary_key = (
73102 field .name == "id"
74103 or field .metadata .get ("SQL" , {}).get ("primary_key" , False )
@@ -135,15 +164,44 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls):
135164 ):
136165 sqlmodel_cls .__annotations__ [field .name ] = Optional [sqlmodel_cls ]
137166
167+ def default_table_name (clsname : str ) -> str :
168+ return inflection .underscore (inflection .pluralize (clsname ))
169+
138170 def decorator (cls ):
139171 # Check if the class is a dataclass
140172 if not is_dataclass (cls ):
141173 raise ValueError ("The class must be a dataclass" )
142174
143175 nonlocal table_name
144- table_name = table_name or inflection .underscore (
145- inflection .pluralize (cls .__name__ )
146- )
176+ table_name = table_name or default_table_name (cls .__name__ )
177+
178+ if not global_id :
179+ cls .id_seq = Sequence (f"{ table_name } _seq" )
180+
181+ # Insert any foreign keys as necessary
182+ for cfield in fields (cls ):
183+ sql_meta = cfield .metadata .get ("SQL" , {})
184+ has_relationship = bool (sql_meta .get ("relationship" , None ))
185+ if has_relationship :
186+ many_to_one = sql_meta .get ("many_to_one" , False )
187+ foreign_key_name = cfield .name + "_id"
188+ key_table_name = table_name
189+ if many_to_one :
190+ type_class = cfield .type
191+ other_class = type_class .__args__ [0 ]
192+ other_class = getattr (other_class , "__name__" , None )
193+ key_table_name = default_table_name (other_class )
194+ back_populates = sql_meta .get ("back_populates" , None )
195+ if back_populates is False or many_to_one :
196+ new_field = field (
197+ default = None ,
198+ metadata = {"SQL" : {"foreign_key" : f"{ key_table_name } .id" }},
199+ )
200+ new_field ._field_type = _FIELD
201+ new_field .name = foreign_key_name
202+ new_field .type = Optional [int ]
203+ cls .__dataclass_fields__ [foreign_key_name ] = new_field
204+ setattr (cls , new_field .name , new_field .default )
147205
148206 # Generate the SQLModel class
149207 sqlmodel_cls = type (
@@ -167,16 +225,17 @@ def decorator(cls):
167225 # For SQLModel's SQLModelMetaClass
168226 table = table ,
169227 )
170-
171228 cls .__sqlmodel__ = sqlmodel_cls
172229 # Update type annotations in any class with a relationship with this class to point
173230 # to the SQLModel, not the dataclass
174- for field in fields (cls ):
175- if not field .name in sqlmodel_cls .__sqlmodel_relationships__ :
231+ for cfield in fields (cls ):
232+ if not cfield .name in sqlmodel_cls .__sqlmodel_relationships__ :
176233 continue
177- rel = sqlmodel_cls .__sqlmodel_relationships__ .get (field .name , None )
234+ rel = sqlmodel_cls .__sqlmodel_relationships__ .get (cfield .name , None )
178235 if rel and hasattr (rel , "back_populates" ):
179- patch_back_populates_types (field , rel .back_populates , cls , sqlmodel_cls )
236+ patch_back_populates_types (
237+ cfield , rel .back_populates , cls , sqlmodel_cls
238+ )
180239 cls .sqlmodel = sqlmodel
181240 return cls
182241
0 commit comments