11"""Functions and classes to create and populate the target database."""
2+ from collections import Counter
23from typing import Any , Generator , Mapping , Sequence , Tuple
34
45from sqlalchemy import Connection , insert
1011from sqlsynthgen .utils import create_db_engine , get_sync_engine , logger
1112
1213Story = Generator [Tuple [str , dict [str , Any ]], dict [str , Any ], None ]
14+ RowCounts = Counter [str ]
1315
1416
1517def create_db_tables (metadata : MetaData ) -> None :
@@ -57,7 +59,7 @@ def create_db_data(
5759 table_generator_dict : Mapping [str , TableGenerator ],
5860 story_generator_list : Sequence [Mapping [str , Any ]],
5961 num_passes : int ,
60- ) -> None :
62+ ) -> RowCounts :
6163 """Connect to a database and populate it with data."""
6264 settings = get_settings ()
6365 dst_dsn : str = settings .dst_dsn or ""
@@ -67,27 +69,30 @@ def create_db_data(
6769 create_db_engine (dst_dsn , schema_name = settings .dst_schema )
6870 )
6971
72+ row_counts : Counter [str ] = Counter ()
7073 with dst_engine .connect () as dst_conn :
7174 for _ in range (num_passes ):
72- populate (
75+ row_counts += populate (
7376 dst_conn ,
7477 sorted_tables ,
7578 table_generator_dict ,
7679 story_generator_list ,
7780 )
81+ return row_counts
7882
7983
8084def _populate_story (
8185 story : Story ,
8286 table_dict : Mapping [str , Table ],
8387 table_generator_dict : Mapping [str , TableGenerator ],
8488 dst_conn : Connection ,
85- ) -> None :
89+ ) -> RowCounts :
8690 """Write to the database all the rows created by the given story."""
8791 # Loop over the rows generated by the story, insert them into their
8892 # respective tables. Ideally this would say
8993 # `for table_name, provided_values in story:`
9094 # but we have to loop more manually to be able to use the `send` function.
95+ row_counts : Counter [str ] = Counter ()
9196 try :
9297 table_name , provided_values = next (story )
9398 while True :
@@ -111,19 +116,22 @@ def _populate_story(
111116 else :
112117 return_values = {}
113118 final_values = {** insert_values , ** return_values }
119+ row_counts [table_name ] = row_counts .get (table_name , 0 ) + 1
114120 table_name , provided_values = story .send (final_values )
115121 except StopIteration :
116122 # The story has finished, it has no more rows to generate
117123 pass
124+ return row_counts
118125
119126
120127def populate (
121128 dst_conn : Connection ,
122129 tables : Sequence [Table ],
123130 table_generator_dict : Mapping [str , TableGenerator ],
124131 story_generator_list : Sequence [Mapping [str , Any ]],
125- ) -> None :
132+ ) -> RowCounts :
126133 """Populate a database schema with synthetic data."""
134+ row_counts : Counter [str ] = Counter ()
127135 table_dict = {table .name : table for table in tables }
128136 # Generate stories
129137 # Each story generator returns a python generator (an unfortunate naming clash with
@@ -141,9 +149,11 @@ def populate(
141149 )
142150 for name , story in stories :
143151 # Run the inserts for each story within a transaction.
144- logger .debug (" Generating data for story %s" , name )
152+ logger .debug (' Generating data for story " %s".' , name )
145153 with dst_conn .begin ():
146- _populate_story (story , table_dict , table_generator_dict , dst_conn )
154+ row_counts += _populate_story (
155+ story , table_dict , table_generator_dict , dst_conn
156+ )
147157
148158 # Generate individual rows, table by table.
149159 for table in tables :
@@ -154,9 +164,11 @@ def populate(
154164 table_generator = table_generator_dict [table .name ]
155165 if table_generator .num_rows_per_pass == 0 :
156166 continue
157- logger .debug (" Generating data for table %s" , table .name )
167+ logger .debug (' Generating data for table " %s".' , table .name )
158168 # Run all the inserts for one table in a transaction
159169 with dst_conn .begin ():
160170 for _ in range (table_generator .num_rows_per_pass ):
161171 stmt = insert (table ).values (table_generator (dst_conn ))
162172 dst_conn .execute (stmt )
173+ row_counts [table .name ] = row_counts .get (table .name , 0 ) + 1
174+ return row_counts
0 commit comments