Skip to content

Commit 130a156

Browse files
authored
Merge pull request #142 from alan-turing-institute/sqla2-type-hints-iain
Sqla2 type hints iain
2 parents 0959fa5 + 8eebb1c commit 130a156

File tree

6 files changed

+95
-87
lines changed

6 files changed

+95
-87
lines changed

sqlsynthgen/create.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Functions and classes to create and populate the target database."""
22
import logging
3-
from typing import Any, Dict, Generator, List, Tuple
3+
from typing import Any, Generator, Mapping, Sequence, Tuple
44

55
from sqlalchemy import Connection, insert
66
from sqlalchemy.exc import IntegrityError
@@ -10,14 +10,16 @@
1010
from sqlsynthgen.settings import get_settings
1111
from sqlsynthgen.utils import create_db_engine, get_sync_engine
1212

13-
Story = Generator[Tuple[str, Dict[str, Any]], Dict[str, Any], None]
13+
Story = Generator[Tuple[str, dict[str, Any]], dict[str, Any], None]
1414

1515

1616
def create_db_tables(metadata: MetaData) -> None:
1717
"""Create tables described by the sqlalchemy metadata object."""
1818
settings = get_settings()
19+
dst_dsn: str = settings.dst_dsn or ""
20+
assert dst_dsn != "", "Missing DST_DSN setting."
1921

20-
engine = get_sync_engine(create_db_engine(settings.dst_dsn)) # type: ignore
22+
engine = get_sync_engine(create_db_engine(dst_dsn))
2123

2224
# Create schema, if necessary.
2325
if settings.dst_schema:
@@ -27,21 +29,19 @@ def create_db_tables(metadata: MetaData) -> None:
2729
connection.execute(CreateSchema(schema_name, if_not_exists=True))
2830

2931
# Recreate the engine, this time with a schema specified
30-
engine = get_sync_engine(
31-
create_db_engine(settings.dst_dsn, schema_name=schema_name) # type: ignore
32-
)
32+
engine = get_sync_engine(create_db_engine(dst_dsn, schema_name=schema_name))
3333

3434
metadata.create_all(engine)
3535

3636

37-
def create_db_vocab(vocab_dict: Dict[str, FileUploader]) -> None:
37+
def create_db_vocab(vocab_dict: Mapping[str, FileUploader]) -> None:
3838
"""Load vocabulary tables from files."""
3939
settings = get_settings()
40+
dst_dsn: str = settings.dst_dsn or ""
41+
assert dst_dsn != "", "Missing DST_DSN setting."
4042

4143
dst_engine = get_sync_engine(
42-
create_db_engine(
43-
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
44-
)
44+
create_db_engine(dst_dsn, schema_name=settings.dst_schema)
4545
)
4646

4747
with dst_engine.connect() as dst_conn:
@@ -55,18 +55,18 @@ def create_db_vocab(vocab_dict: Dict[str, FileUploader]) -> None:
5555

5656

5757
def create_db_data(
58-
sorted_tables: list[Table],
59-
table_generator_dict: dict[str, TableGenerator],
60-
story_generator_list: list[dict[str, Any]],
58+
sorted_tables: Sequence[Table],
59+
table_generator_dict: Mapping[str, TableGenerator],
60+
story_generator_list: Sequence[Mapping[str, Any]],
6161
num_passes: int,
6262
) -> None:
6363
"""Connect to a database and populate it with data."""
6464
settings = get_settings()
65+
dst_dsn: str = settings.dst_dsn or ""
66+
assert dst_dsn != "", "Missing DST_DSN setting."
6567

6668
dst_engine = get_sync_engine(
67-
create_db_engine(
68-
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
69-
)
69+
create_db_engine(dst_dsn, schema_name=settings.dst_schema)
7070
)
7171

7272
with dst_engine.connect() as dst_conn:
@@ -81,8 +81,8 @@ def create_db_data(
8181

8282
def _populate_story(
8383
story: Story,
84-
table_dict: Dict[str, Table],
85-
table_generator_dict: Dict[str, TableGenerator],
84+
table_dict: Mapping[str, Table],
85+
table_generator_dict: Mapping[str, TableGenerator],
8686
dst_conn: Connection,
8787
) -> None:
8888
"""Write to the database all the rows created by the given story."""
@@ -121,17 +121,17 @@ def _populate_story(
121121

122122
def populate(
123123
dst_conn: Connection,
124-
tables: list[Table],
125-
table_generator_dict: dict[str, TableGenerator],
126-
story_generator_list: list[dict[str, Any]],
124+
tables: Sequence[Table],
125+
table_generator_dict: Mapping[str, TableGenerator],
126+
story_generator_list: Sequence[Mapping[str, Any]],
127127
) -> None:
128128
"""Populate a database schema with synthetic data."""
129129
table_dict = {table.name: table for table in tables}
130130
# Generate stories
131131
# Each story generator returns a python generator (an unfortunate naming clash with
132132
# what we call generators). Iterating over it yields individual rows for the
133133
# database. First, collect all of the python generators into a single list.
134-
stories: List[Story] = sum(
134+
stories: list[Story] = sum(
135135
[
136136
[sg["name"](dst_conn) for _ in range(sg["num_stories_per_pass"])]
137137
for sg in story_generator_list

sqlsynthgen/make.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pathlib import Path
88
from sys import stderr
99
from types import ModuleType
10-
from typing import Any, Dict, Final, List, Optional, Tuple
10+
from typing import Any, Final, Mapping, Optional, Sequence, Tuple
1111

1212
import pandas as pd
1313
import snsql
@@ -25,7 +25,7 @@
2525
from sqlsynthgen.settings import get_settings
2626
from sqlsynthgen.utils import create_db_engine, download_table, get_sync_engine
2727

28-
PROVIDER_IMPORTS: Final[List[str]] = []
28+
PROVIDER_IMPORTS: Final[list[str]] = []
2929
for entry_name, entry in inspect.getmembers(providers, inspect.isclass):
3030
if issubclass(entry, BaseProvider) and entry.__module__ == "sqlsynthgen.providers":
3131
PROVIDER_IMPORTS.append(entry_name)
@@ -49,14 +49,14 @@ class FunctionCall:
4949
"""Contains the ssg.py content related function calls."""
5050

5151
function_name: str
52-
argument_values: List[str]
52+
argument_values: list[str]
5353

5454

5555
@dataclass
5656
class RowGeneratorInfo:
5757
"""Contains the ssg.py content related to row generators of a table."""
5858

59-
variable_names: List[str]
59+
variable_names: list[str]
6060
function_call: FunctionCall
6161
primary_key: bool = False
6262

@@ -68,8 +68,8 @@ class TableGeneratorInfo:
6868
class_name: str
6969
table_name: str
7070
rows_per_pass: int
71-
row_gens: List[RowGeneratorInfo] = field(default_factory=list)
72-
unique_constraints: List[UniqueConstraint] = field(default_factory=list)
71+
row_gens: list[RowGeneratorInfo] = field(default_factory=list)
72+
unique_constraints: list[UniqueConstraint] = field(default_factory=list)
7373

7474

7575
@dataclass
@@ -100,38 +100,38 @@ def _orm_class_from_table_name(
100100

101101
def _get_function_call(
102102
function_name: str,
103-
positional_arguments: Optional[List[Any]] = None,
104-
keyword_arguments: Optional[Dict[str, Any]] = None,
103+
positional_arguments: Optional[Sequence[Any]] = None,
104+
keyword_arguments: Optional[Mapping[str, Any]] = None,
105105
) -> FunctionCall:
106106
if positional_arguments is None:
107107
positional_arguments = []
108108

109109
if keyword_arguments is None:
110110
keyword_arguments = {}
111111

112-
argument_values: List[str] = [str(value) for value in positional_arguments]
112+
argument_values: list[str] = [str(value) for value in positional_arguments]
113113
argument_values += [f"{key}={value}" for key, value in keyword_arguments.items()]
114114

115115
return FunctionCall(function_name=function_name, argument_values=argument_values)
116116

117117

118118
def _get_row_generator(
119-
table_config: dict[str, Any],
120-
) -> tuple[List[RowGeneratorInfo], list[str]]:
119+
table_config: Mapping[str, Any],
120+
) -> tuple[list[RowGeneratorInfo], list[str]]:
121121
"""Get the row generators information, for the given table."""
122-
row_gen_info: List[RowGeneratorInfo] = []
123-
config: List[Dict[str, Any]] = table_config.get("row_generators", {})
122+
row_gen_info: list[RowGeneratorInfo] = []
123+
config: list[dict[str, Any]] = table_config.get("row_generators", {})
124124
columns_covered = []
125125
for gen_conf in config:
126126
name: str = gen_conf["name"]
127127
columns_assigned = gen_conf["columns_assigned"]
128-
keyword_arguments: Dict[str, Any] = gen_conf.get("kwargs", {})
129-
positional_arguments: List[str] = gen_conf.get("args", [])
128+
keyword_arguments: Mapping[str, Any] = gen_conf.get("kwargs", {})
129+
positional_arguments: Sequence[str] = gen_conf.get("args", [])
130130

131131
if isinstance(columns_assigned, str):
132132
columns_assigned = [columns_assigned]
133133

134-
variable_names: List[str] = columns_assigned
134+
variable_names: list[str] = columns_assigned
135135
try:
136136
columns_covered += columns_assigned
137137
except TypeError:
@@ -158,9 +158,9 @@ def _get_default_generator(
158158

159159
# If it's a foreign key column, pull random values from the column it
160160
# references.
161-
variable_names: List[str] = []
161+
variable_names: list[str] = []
162162
generator_function: str = ""
163-
generator_arguments: List[str] = []
163+
generator_arguments: list[str] = []
164164

165165
if column.foreign_keys:
166166
if len(column.foreign_keys) > 1:
@@ -202,19 +202,19 @@ def _get_default_generator(
202202
)
203203

204204

205-
def _get_provider_for_column(column: Column) -> Tuple[List[str], str, List[str]]:
205+
def _get_provider_for_column(column: Column) -> Tuple[list[str], str, list[str]]:
206206
"""
207207
Get a default Mimesis provider and its arguments for a SQL column type.
208208
209209
Args:
210210
column: SQLAlchemy column object
211211
212212
Returns:
213-
Tuple[str, str, List[str]]: Tuple containing the variable names to assign to,
213+
Tuple[str, str, list[str]]: Tuple containing the variable names to assign to,
214214
generator function and any generator arguments.
215215
"""
216-
variable_names: List[str] = [column.name]
217-
generator_arguments: List[str] = []
216+
variable_names: list[str] = [column.name]
217+
generator_arguments: list[str] = []
218218

219219
column_type = type(column.type)
220220
column_size: Optional[int] = getattr(column.type, "length", None)
@@ -291,7 +291,7 @@ def _enforce_unique_constraints(table_data: TableGeneratorInfo) -> None:
291291

292292

293293
def _get_generator_for_table(
294-
tables_module: ModuleType, table_config: dict[str, Any], table: Table
294+
tables_module: ModuleType, table_config: Mapping[str, Any], table: Table
295295
) -> TableGeneratorInfo:
296296
"""Get generator information for the given table."""
297297
unique_constraints = [
@@ -318,7 +318,7 @@ def _get_generator_for_table(
318318
return table_data
319319

320320

321-
def _get_story_generators(config: dict) -> List[StoryGeneratorInfo]:
321+
def _get_story_generators(config: Mapping) -> list[StoryGeneratorInfo]:
322322
"""Get story generators."""
323323
generators = []
324324
for gen in config.get("story_generators", []):
@@ -339,7 +339,7 @@ def _get_story_generators(config: dict) -> List[StoryGeneratorInfo]:
339339

340340
def make_table_generators(
341341
tables_module: ModuleType,
342-
config: dict,
342+
config: Mapping,
343343
src_stats_filename: Optional[str],
344344
overwrite_files: bool = False,
345345
) -> str:
@@ -359,14 +359,13 @@ def make_table_generators(
359359
story_generator_module_name = config.get("story_generators_module", None)
360360

361361
settings = get_settings()
362-
engine = get_sync_engine(
363-
create_db_engine(
364-
settings.src_dsn, schema_name=settings.src_schema # type: ignore
365-
)
366-
)
362+
src_dsn: str = settings.src_dsn or ""
363+
assert src_dsn != "", "Missing SRC_DSN setting."
364+
365+
engine = get_sync_engine(create_db_engine(src_dsn, schema_name=settings.src_schema))
367366

368-
tables: List[TableGeneratorInfo] = []
369-
vocabulary_tables: List[VocabularyTableGeneratorInfo] = []
367+
tables: list[TableGeneratorInfo] = []
368+
vocabulary_tables: list[VocabularyTableGeneratorInfo] = []
370369

371370
for table in tables_module.Base.metadata.sorted_tables:
372371
table_config = config.get("tables", {}).get(table.name, {})
@@ -398,7 +397,7 @@ def make_table_generators(
398397
)
399398

400399

401-
def generate_ssg_content(template_context: Dict[str, Any]) -> str:
400+
def generate_ssg_content(template_context: Mapping[str, Any]) -> str:
402401
"""Generate the content of the ssg.py file as a string."""
403402
environment: Environment = Environment(
404403
loader=FileSystemLoader(TEMPLATE_DIRECTORY),
@@ -467,8 +466,8 @@ def make_tables_file(db_dsn: str, schema_name: Optional[str]) -> str:
467466

468467

469468
async def make_src_stats(
470-
dsn: str, config: dict, schema_name: Optional[str] = None
471-
) -> Dict[str, List[dict]]:
469+
dsn: str, config: Mapping, schema_name: Optional[str] = None
470+
) -> dict[str, list[dict]]:
472471
"""Run the src-stats queries specified by the configuration.
473472
474473
Query the src database with the queries in the src-stats block of the `config`
@@ -485,7 +484,7 @@ async def make_src_stats(
485484
use_asyncio = config.get("use-asyncio", False)
486485
engine = create_db_engine(dsn, schema_name=schema_name, use_asyncio=use_asyncio)
487486

488-
async def execute_query(query_block: Dict[str, Any]) -> Any:
487+
async def execute_query(query_block: Mapping[str, Any]) -> Any:
489488
"""Execute query in query_block."""
490489
query = text(query_block["query"])
491490
if isinstance(engine, AsyncEngine):

sqlsynthgen/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def create_db_engine(
9393
db_dsn: str,
9494
schema_name: Optional[str] = None,
9595
use_asyncio: bool = False,
96-
**kwargs: dict,
96+
**kwargs: Any,
9797
) -> MaybeAsyncEngine:
9898
"""Create a SQLAlchemy Engine."""
9999
if use_asyncio:

0 commit comments

Comments
 (0)