22import logging
33from typing import Any , Dict , Generator , List , Tuple
44
5- from sqlalchemy import insert
5+ from sqlalchemy import Connection , insert
66from sqlalchemy .exc import IntegrityError
7- from sqlalchemy .schema import CreateSchema
7+ from sqlalchemy .schema import CreateSchema , MetaData , Table
88
9+ from sqlsynthgen .base import FileUploader , TableGenerator
910from sqlsynthgen .settings import get_settings
10- from sqlsynthgen .utils import create_db_engine
11+ from sqlsynthgen .utils import create_db_engine , get_sync_engine
1112
1213Story = Generator [Tuple [str , Dict [str , Any ]], Dict [str , Any ], None ]
1314
1415
15- def create_db_tables (metadata : Any ) -> Any :
16+ def create_db_tables (metadata : MetaData ) -> None :
1617 """Create tables described by the sqlalchemy metadata object."""
1718 settings = get_settings ()
1819
19- engine = create_db_engine (settings .dst_dsn ) # type: ignore
20+ engine = get_sync_engine ( create_db_engine (settings .dst_dsn ) ) # type: ignore
2021
2122 # Create schema, if necessary.
2223 if settings .dst_schema :
2324 schema_name = settings .dst_schema
24- if not engine .dialect .has_schema (engine , schema = schema_name ):
25- engine .execute (CreateSchema (schema_name , if_not_exists = True ))
25+ with engine .connect () as connection :
26+ if not engine .dialect .has_schema (connection , schema_name = schema_name ):
27+ connection .execute (CreateSchema (schema_name , if_not_exists = True ))
2628
2729 # Recreate the engine, this time with a schema specified
28- engine = create_db_engine (
29- settings .dst_dsn , schema_name = schema_name # type: ignore
30+ engine = get_sync_engine (
31+ create_db_engine ( settings .dst_dsn , schema_name = schema_name ) # type: ignore
3032 )
3133
3234 metadata .create_all (engine )
3335
3436
35- def create_db_vocab (vocab_dict : Dict [str , Any ]) -> None :
37+ def create_db_vocab (vocab_dict : Dict [str , FileUploader ]) -> None :
3638 """Load vocabulary tables from files."""
3739 settings = get_settings ()
3840
39- dst_engine = create_db_engine (
40- settings .dst_dsn , schema_name = settings .dst_schema # type: ignore
41+ dst_engine = get_sync_engine (
42+ create_db_engine (
43+ settings .dst_dsn , schema_name = settings .dst_schema # type: ignore
44+ )
4145 )
4246
4347 with dst_engine .connect () as dst_conn :
@@ -51,16 +55,18 @@ def create_db_vocab(vocab_dict: Dict[str, Any]) -> None:
5155
5256
5357def create_db_data (
54- sorted_tables : list ,
55- table_generator_dict : dict ,
56- story_generator_list : list ,
58+ sorted_tables : list [ Table ] ,
59+ table_generator_dict : dict [ str , TableGenerator ] ,
60+ story_generator_list : list [ dict [ str , Any ]] ,
5761 num_passes : int ,
5862) -> None :
5963 """Connect to a database and populate it with data."""
6064 settings = get_settings ()
6165
62- dst_engine = create_db_engine (
63- settings .dst_dsn , schema_name = settings .dst_schema # type: ignore
66+ dst_engine = get_sync_engine (
67+ create_db_engine (
68+ settings .dst_dsn , schema_name = settings .dst_schema # type: ignore
69+ )
6470 )
6571
6672 with dst_engine .connect () as dst_conn :
@@ -75,9 +81,9 @@ def create_db_data(
7581
7682def _populate_story (
7783 story : Story ,
78- table_dict : Dict [str , Any ],
79- table_generator_dict : Dict [str , Any ],
80- dst_conn : Any ,
84+ table_dict : Dict [str , Table ],
85+ table_generator_dict : Dict [str , TableGenerator ],
86+ dst_conn : Connection ,
8187) -> None :
8288 """Write to the database all the rows created by the given story."""
8389 # Loop over the rows generated by the story, insert them into their
@@ -100,7 +106,9 @@ def _populate_story(
100106 # because other parts of the story may refer to them.
101107 if cursor .returned_defaults :
102108 # pylint: disable=protected-access
103- return_values = cursor .returned_defaults ._mapping
109+ return_values = {
110+ str (k ): v for k , v in cursor .returned_defaults ._mapping .items ()
111+ }
104112 # pylint: enable=protected-access
105113 else :
106114 return_values = {}
@@ -112,10 +120,10 @@ def _populate_story(
112120
113121
114122def populate (
115- dst_conn : Any ,
116- tables : list ,
117- table_generator_dict : dict ,
118- story_generator_list : list ,
123+ dst_conn : Connection ,
124+ tables : list [ Table ] ,
125+ table_generator_dict : dict [ str , TableGenerator ] ,
126+ story_generator_list : list [ dict [ str , Any ]] ,
119127) -> None :
120128 """Populate a database schema with synthetic data."""
121129 table_dict = {table .name : table for table in tables }
0 commit comments