@@ -1707,6 +1707,36 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl
17071707 self .fields = [field .replace_table (current_table , new_table ) for field in self .fields ]
17081708
17091709
1710+ class ForeignKey :
1711+ """Represents a foreign key constraint."""
1712+
1713+ def __init__ (
1714+ self ,
1715+ columns : List [Column ],
1716+ reference_table : Union [str , Table ],
1717+ reference_columns : List [Column ],
1718+ on_delete : ReferenceOption = None ,
1719+ on_update : ReferenceOption = None ,
1720+ ) -> None :
1721+ self .columns = columns
1722+ self .reference_table = reference_table
1723+ self .reference_columns = reference_columns
1724+ self .on_delete = on_delete
1725+ self .on_update = on_update
1726+
1727+ def get_sql (self , ** kwargs : Any ) -> str :
1728+ foreign_key_sql = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})" .format (
1729+ columns = "," .join (column .get_name_sql (** kwargs ) for column in self .columns ),
1730+ table_name = self .reference_table .get_sql (** kwargs ),
1731+ reference_columns = "," .join (column .get_name_sql (** kwargs ) for column in self .reference_columns ),
1732+ )
1733+ if self .on_delete :
1734+ foreign_key_sql += " ON DELETE " + self .on_delete .value
1735+ if self .on_update :
1736+ foreign_key_sql += " ON UPDATE " + self .on_update .value
1737+ return foreign_key_sql
1738+
1739+
17101740class CreateQueryBuilder :
17111741 """
17121742 Query builder used to build CREATE queries.
@@ -1729,11 +1759,7 @@ def __init__(self, dialect: Optional[Dialects] = None) -> None:
17291759 self ._uniques = []
17301760 self ._if_not_exists = False
17311761 self .dialect = dialect
1732- self ._foreign_key = None
1733- self ._foreign_key_reference_table = None
1734- self ._foreign_key_reference = None
1735- self ._foreign_key_on_update : ReferenceOption = None
1736- self ._foreign_key_on_delete : ReferenceOption = None
1762+ self ._foreign_keys = []
17371763
17381764 def _set_kwargs_defaults (self , kwargs : dict ) -> None :
17391765 kwargs .setdefault ("quote_char" , self .QUOTE_CHAR )
@@ -1908,19 +1934,19 @@ def foreign_key(
19081934
19091935 Update option.
19101936
1911- :raises AttributeError:
1912- If the foreign key is already defined.
1913-
19141937 :return:
19151938 CreateQueryBuilder.
19161939 """
1917- if self ._foreign_key :
1918- raise AttributeError ("'Query' object already has attribute foreign_key" )
1919- self ._foreign_key = self ._prepare_columns_input (columns )
1920- self ._foreign_key_reference_table = reference_table
1921- self ._foreign_key_reference = self ._prepare_columns_input (reference_columns )
1922- self ._foreign_key_on_delete = on_delete
1923- self ._foreign_key_on_update = on_update
1940+
1941+ self ._foreign_keys .append (
1942+ ForeignKey (
1943+ columns = self ._prepare_columns_input (columns ),
1944+ reference_table = reference_table ,
1945+ reference_columns = self ._prepare_columns_input (reference_columns ),
1946+ on_delete = on_delete ,
1947+ on_update = on_update ,
1948+ )
1949+ )
19241950
19251951 @builder
19261952 def as_select (self , query_builder : QueryBuilder ) -> "CreateQueryBuilder" :
@@ -2017,28 +2043,17 @@ def _primary_key_clause(self, **kwargs) -> str:
20172043 columns = "," .join (column .get_name_sql (** kwargs ) for column in self ._primary_key )
20182044 )
20192045
2020- def _foreign_key_clause (self , ** kwargs ) -> str :
2021- clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})" .format (
2022- columns = "," .join (column .get_name_sql (** kwargs ) for column in self ._foreign_key ),
2023- table_name = self ._foreign_key_reference_table .get_sql (** kwargs ),
2024- reference_columns = "," .join (column .get_name_sql (** kwargs ) for column in self ._foreign_key_reference ),
2025- )
2026- if self ._foreign_key_on_delete :
2027- clause += " ON DELETE " + self ._foreign_key_on_delete .value
2028- if self ._foreign_key_on_update :
2029- clause += " ON UPDATE " + self ._foreign_key_on_update .value
2030-
2031- return clause
2046+ def _foreign_key_clauses (self , ** kwargs ) -> str :
2047+ return [foreign_key .get_sql (** kwargs ) for foreign_key in self ._foreign_keys ]
20322048
20332049 def _body_sql (self , ** kwargs ) -> str :
20342050 clauses = self ._column_clauses (** kwargs )
20352051 clauses += self ._period_for_clauses (** kwargs )
20362052 clauses += self ._unique_key_clauses (** kwargs )
2053+ clauses += self ._foreign_key_clauses (** kwargs )
20372054
20382055 if self ._primary_key :
20392056 clauses .append (self ._primary_key_clause (** kwargs ))
2040- if self ._foreign_key :
2041- clauses .append (self ._foreign_key_clause (** kwargs ))
20422057
20432058 return "," .join (clauses )
20442059
0 commit comments