77from pathlib import Path
88from sys import stderr
99from types import ModuleType
10- from typing import Any , Dict , Final , List , Optional , Tuple
10+ from typing import Any , Final , Mapping , Optional , Sequence , Tuple
1111
1212import pandas as pd
1313import snsql
2525from sqlsynthgen .settings import get_settings
2626from sqlsynthgen .utils import create_db_engine , download_table , get_sync_engine
2727
28- PROVIDER_IMPORTS : Final [List [str ]] = []
28+ PROVIDER_IMPORTS : Final [list [str ]] = []
2929for entry_name , entry in inspect .getmembers (providers , inspect .isclass ):
3030 if issubclass (entry , BaseProvider ) and entry .__module__ == "sqlsynthgen.providers" :
3131 PROVIDER_IMPORTS .append (entry_name )
@@ -49,14 +49,14 @@ class FunctionCall:
4949 """Contains the ssg.py content related function calls."""
5050
5151 function_name : str
52- argument_values : List [str ]
52+ argument_values : list [str ]
5353
5454
5555@dataclass
5656class RowGeneratorInfo :
5757 """Contains the ssg.py content related to row generators of a table."""
5858
59- variable_names : List [str ]
59+ variable_names : list [str ]
6060 function_call : FunctionCall
6161 primary_key : bool = False
6262
@@ -68,8 +68,8 @@ class TableGeneratorInfo:
6868 class_name : str
6969 table_name : str
7070 rows_per_pass : int
71- row_gens : List [RowGeneratorInfo ] = field (default_factory = list )
72- unique_constraints : List [UniqueConstraint ] = field (default_factory = list )
71+ row_gens : list [RowGeneratorInfo ] = field (default_factory = list )
72+ unique_constraints : list [UniqueConstraint ] = field (default_factory = list )
7373
7474
7575@dataclass
@@ -100,38 +100,38 @@ def _orm_class_from_table_name(
100100
101101def _get_function_call (
102102 function_name : str ,
103- positional_arguments : Optional [List [Any ]] = None ,
104- keyword_arguments : Optional [Dict [str , Any ]] = None ,
103+ positional_arguments : Optional [Sequence [Any ]] = None ,
104+ keyword_arguments : Optional [Mapping [str , Any ]] = None ,
105105) -> FunctionCall :
106106 if positional_arguments is None :
107107 positional_arguments = []
108108
109109 if keyword_arguments is None :
110110 keyword_arguments = {}
111111
112- argument_values : List [str ] = [str (value ) for value in positional_arguments ]
112+ argument_values : list [str ] = [str (value ) for value in positional_arguments ]
113113 argument_values += [f"{ key } ={ value } " for key , value in keyword_arguments .items ()]
114114
115115 return FunctionCall (function_name = function_name , argument_values = argument_values )
116116
117117
118118def _get_row_generator (
119- table_config : dict [str , Any ],
120- ) -> tuple [List [RowGeneratorInfo ], list [str ]]:
119+ table_config : Mapping [str , Any ],
120+ ) -> tuple [list [RowGeneratorInfo ], list [str ]]:
121121 """Get the row generators information, for the given table."""
122- row_gen_info : List [RowGeneratorInfo ] = []
123- config : List [ Dict [str , Any ]] = table_config .get ("row_generators" , {})
122+ row_gen_info : list [RowGeneratorInfo ] = []
123+ config : list [ dict [str , Any ]] = table_config .get ("row_generators" , {})
124124 columns_covered = []
125125 for gen_conf in config :
126126 name : str = gen_conf ["name" ]
127127 columns_assigned = gen_conf ["columns_assigned" ]
128- keyword_arguments : Dict [str , Any ] = gen_conf .get ("kwargs" , {})
129- positional_arguments : List [str ] = gen_conf .get ("args" , [])
128+ keyword_arguments : Mapping [str , Any ] = gen_conf .get ("kwargs" , {})
129+ positional_arguments : Sequence [str ] = gen_conf .get ("args" , [])
130130
131131 if isinstance (columns_assigned , str ):
132132 columns_assigned = [columns_assigned ]
133133
134- variable_names : List [str ] = columns_assigned
134+ variable_names : list [str ] = columns_assigned
135135 try :
136136 columns_covered += columns_assigned
137137 except TypeError :
@@ -158,9 +158,9 @@ def _get_default_generator(
158158
159159 # If it's a foreign key column, pull random values from the column it
160160 # references.
161- variable_names : List [str ] = []
161+ variable_names : list [str ] = []
162162 generator_function : str = ""
163- generator_arguments : List [str ] = []
163+ generator_arguments : list [str ] = []
164164
165165 if column .foreign_keys :
166166 if len (column .foreign_keys ) > 1 :
@@ -202,19 +202,19 @@ def _get_default_generator(
202202 )
203203
204204
205- def _get_provider_for_column (column : Column ) -> Tuple [List [str ], str , List [str ]]:
205+ def _get_provider_for_column (column : Column ) -> Tuple [list [str ], str , list [str ]]:
206206 """
207207 Get a default Mimesis provider and its arguments for a SQL column type.
208208
209209 Args:
210210 column: SQLAlchemy column object
211211
212212 Returns:
213- Tuple[str, str, List [str]]: Tuple containing the variable names to assign to,
213+ Tuple[str, str, list [str]]: Tuple containing the variable names to assign to,
214214 generator function and any generator arguments.
215215 """
216- variable_names : List [str ] = [column .name ]
217- generator_arguments : List [str ] = []
216+ variable_names : list [str ] = [column .name ]
217+ generator_arguments : list [str ] = []
218218
219219 column_type = type (column .type )
220220 column_size : Optional [int ] = getattr (column .type , "length" , None )
@@ -291,7 +291,7 @@ def _enforce_unique_constraints(table_data: TableGeneratorInfo) -> None:
291291
292292
293293def _get_generator_for_table (
294- tables_module : ModuleType , table_config : dict [str , Any ], table : Table
294+ tables_module : ModuleType , table_config : Mapping [str , Any ], table : Table
295295) -> TableGeneratorInfo :
296296 """Get generator information for the given table."""
297297 unique_constraints = [
@@ -318,7 +318,7 @@ def _get_generator_for_table(
318318 return table_data
319319
320320
321- def _get_story_generators (config : dict ) -> List [StoryGeneratorInfo ]:
321+ def _get_story_generators (config : Mapping ) -> list [StoryGeneratorInfo ]:
322322 """Get story generators."""
323323 generators = []
324324 for gen in config .get ("story_generators" , []):
@@ -339,7 +339,7 @@ def _get_story_generators(config: dict) -> List[StoryGeneratorInfo]:
339339
340340def make_table_generators (
341341 tables_module : ModuleType ,
342- config : dict ,
342+ config : Mapping ,
343343 src_stats_filename : Optional [str ],
344344 overwrite_files : bool = False ,
345345) -> str :
@@ -359,14 +359,13 @@ def make_table_generators(
359359 story_generator_module_name = config .get ("story_generators_module" , None )
360360
361361 settings = get_settings ()
362- engine = get_sync_engine (
363- create_db_engine (
364- settings .src_dsn , schema_name = settings .src_schema # type: ignore
365- )
366- )
362+ src_dsn : str = settings .src_dsn or ""
363+ assert src_dsn != "" , "Missing SRC_DSN setting."
364+
365+ engine = get_sync_engine (create_db_engine (src_dsn , schema_name = settings .src_schema ))
367366
368- tables : List [TableGeneratorInfo ] = []
369- vocabulary_tables : List [VocabularyTableGeneratorInfo ] = []
367+ tables : list [TableGeneratorInfo ] = []
368+ vocabulary_tables : list [VocabularyTableGeneratorInfo ] = []
370369
371370 for table in tables_module .Base .metadata .sorted_tables :
372371 table_config = config .get ("tables" , {}).get (table .name , {})
@@ -398,7 +397,7 @@ def make_table_generators(
398397 )
399398
400399
401- def generate_ssg_content (template_context : Dict [str , Any ]) -> str :
400+ def generate_ssg_content (template_context : Mapping [str , Any ]) -> str :
402401 """Generate the content of the ssg.py file as a string."""
403402 environment : Environment = Environment (
404403 loader = FileSystemLoader (TEMPLATE_DIRECTORY ),
@@ -467,8 +466,8 @@ def make_tables_file(db_dsn: str, schema_name: Optional[str]) -> str:
467466
468467
469468async def make_src_stats (
470- dsn : str , config : dict , schema_name : Optional [str ] = None
471- ) -> Dict [str , List [dict ]]:
469+ dsn : str , config : Mapping , schema_name : Optional [str ] = None
470+ ) -> dict [str , list [dict ]]:
472471 """Run the src-stats queries specified by the configuration.
473472
474473 Query the src database with the queries in the src-stats block of the `config`
@@ -485,7 +484,7 @@ async def make_src_stats(
485484 use_asyncio = config .get ("use-asyncio" , False )
486485 engine = create_db_engine (dsn , schema_name = schema_name , use_asyncio = use_asyncio )
487486
488- async def execute_query (query_block : Dict [str , Any ]) -> Any :
487+ async def execute_query (query_block : Mapping [str , Any ]) -> Any :
489488 """Execute query in query_block."""
490489 query = text (query_block ["query" ])
491490 if isinstance (engine , AsyncEngine ):
0 commit comments