Skip to content

Commit 0959fa5

Browse files
authored
Merge pull request #138 from alan-turing-institute/sqla2-type-hints
Add more type hints
2 parents 8259c0f + 27eb43c commit 0959fa5

File tree

12 files changed

+317
-258
lines changed

12 files changed

+317
-258
lines changed

poetry.lock

Lines changed: 151 additions & 159 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlsynthgen/base.py

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,41 @@
11
"""Base table generator classes."""
22
import logging
3+
from abc import ABC, abstractmethod
34
from dataclasses import dataclass
45
from pathlib import Path
56
from typing import Any
67

78
import yaml
8-
from sqlalchemy import insert
9+
from sqlalchemy import Connection, insert
910
from sqlalchemy.exc import SQLAlchemyError
11+
from sqlalchemy.schema import Table
12+
13+
14+
class TableGenerator(ABC):
15+
"""Abstract base class for table generator classes."""
16+
17+
num_rows_per_pass: int = 1
18+
19+
@abstractmethod
20+
def __call__(self, dst_db_conn: Connection) -> dict[str, Any]:
21+
"""Return, as a dictionary, a new row for the table that we are generating.
22+
23+
The only argument, `dst_db_conn`, should be a database connection to the
24+
database to which the data is being written. Most generators won't use it, but
25+
some do, and thus it's required by the interface.
26+
27+
The return value should be a dictionary with column names as strings for keys,
28+
and the values being the values for the new row.
29+
"""
1030

1131

1232
@dataclass
1333
class FileUploader:
1434
"""For uploading data files."""
1535

16-
table: Any
36+
table: Table
1737

18-
def load(self, connection: Any) -> None:
38+
def load(self, connection: Connection) -> None:
1939
"""Load the data from file."""
2040
yaml_file = Path(self.table.fullname + ".yaml")
2141
if not yaml_file.exists():

sqlsynthgen/create.py

Lines changed: 33 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -2,42 +2,46 @@
22
import logging
33
from typing import Any, Dict, Generator, List, Tuple
44

5-
from sqlalchemy import insert
5+
from sqlalchemy import Connection, insert
66
from sqlalchemy.exc import IntegrityError
7-
from sqlalchemy.schema import CreateSchema
7+
from sqlalchemy.schema import CreateSchema, MetaData, Table
88

9+
from sqlsynthgen.base import FileUploader, TableGenerator
910
from sqlsynthgen.settings import get_settings
10-
from sqlsynthgen.utils import create_db_engine
11+
from sqlsynthgen.utils import create_db_engine, get_sync_engine
1112

1213
Story = Generator[Tuple[str, Dict[str, Any]], Dict[str, Any], None]
1314

1415

15-
def create_db_tables(metadata: Any) -> Any:
16+
def create_db_tables(metadata: MetaData) -> None:
1617
"""Create tables described by the sqlalchemy metadata object."""
1718
settings = get_settings()
1819

19-
engine = create_db_engine(settings.dst_dsn) # type: ignore
20+
engine = get_sync_engine(create_db_engine(settings.dst_dsn)) # type: ignore
2021

2122
# Create schema, if necessary.
2223
if settings.dst_schema:
2324
schema_name = settings.dst_schema
24-
if not engine.dialect.has_schema(engine, schema=schema_name):
25-
engine.execute(CreateSchema(schema_name, if_not_exists=True))
25+
with engine.connect() as connection:
26+
if not engine.dialect.has_schema(connection, schema_name=schema_name):
27+
connection.execute(CreateSchema(schema_name, if_not_exists=True))
2628

2729
# Recreate the engine, this time with a schema specified
28-
engine = create_db_engine(
29-
settings.dst_dsn, schema_name=schema_name # type: ignore
30+
engine = get_sync_engine(
31+
create_db_engine(settings.dst_dsn, schema_name=schema_name) # type: ignore
3032
)
3133

3234
metadata.create_all(engine)
3335

3436

35-
def create_db_vocab(vocab_dict: Dict[str, Any]) -> None:
37+
def create_db_vocab(vocab_dict: Dict[str, FileUploader]) -> None:
3638
"""Load vocabulary tables from files."""
3739
settings = get_settings()
3840

39-
dst_engine = create_db_engine(
40-
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
41+
dst_engine = get_sync_engine(
42+
create_db_engine(
43+
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
44+
)
4145
)
4246

4347
with dst_engine.connect() as dst_conn:
@@ -51,16 +55,18 @@ def create_db_vocab(vocab_dict: Dict[str, Any]) -> None:
5155

5256

5357
def create_db_data(
54-
sorted_tables: list,
55-
table_generator_dict: dict,
56-
story_generator_list: list,
58+
sorted_tables: list[Table],
59+
table_generator_dict: dict[str, TableGenerator],
60+
story_generator_list: list[dict[str, Any]],
5761
num_passes: int,
5862
) -> None:
5963
"""Connect to a database and populate it with data."""
6064
settings = get_settings()
6165

62-
dst_engine = create_db_engine(
63-
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
66+
dst_engine = get_sync_engine(
67+
create_db_engine(
68+
settings.dst_dsn, schema_name=settings.dst_schema # type: ignore
69+
)
6470
)
6571

6672
with dst_engine.connect() as dst_conn:
@@ -75,9 +81,9 @@ def create_db_data(
7581

7682
def _populate_story(
7783
story: Story,
78-
table_dict: Dict[str, Any],
79-
table_generator_dict: Dict[str, Any],
80-
dst_conn: Any,
84+
table_dict: Dict[str, Table],
85+
table_generator_dict: Dict[str, TableGenerator],
86+
dst_conn: Connection,
8187
) -> None:
8288
"""Write to the database all the rows created by the given story."""
8389
# Loop over the rows generated by the story, insert them into their
@@ -100,7 +106,9 @@ def _populate_story(
100106
# because other parts of the story may refer to them.
101107
if cursor.returned_defaults:
102108
# pylint: disable=protected-access
103-
return_values = cursor.returned_defaults._mapping
109+
return_values = {
110+
str(k): v for k, v in cursor.returned_defaults._mapping.items()
111+
}
104112
# pylint: enable=protected-access
105113
else:
106114
return_values = {}
@@ -112,10 +120,10 @@ def _populate_story(
112120

113121

114122
def populate(
115-
dst_conn: Any,
116-
tables: list,
117-
table_generator_dict: dict,
118-
story_generator_list: list,
123+
dst_conn: Connection,
124+
tables: list[Table],
125+
table_generator_dict: dict[str, TableGenerator],
126+
story_generator_list: list[dict[str, Any]],
119127
) -> None:
120128
"""Populate a database schema with synthetic data."""
121129
table_dict = {table.name: table for table in tables}

sqlsynthgen/main.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import yaml
1111
from jsonschema.exceptions import ValidationError
1212
from jsonschema.validators import validate
13+
from sqlalchemy.schema import MetaData
1314
from typer import Option, Typer, echo
1415

1516
from sqlsynthgen.create import create_db_data, create_db_tables, create_db_vocab
@@ -82,10 +83,13 @@ def create_data(
8283
echo("Creating data.")
8384
orm_module = import_file(orm_file)
8485
ssg_module = import_file(ssg_file)
86+
orm_metadata: MetaData = orm_module.Base.metadata
87+
table_generator_dict = ssg_module.table_generator_dict
88+
story_generator_list = ssg_module.story_generator_list
8589
create_db_data(
86-
orm_module.Base.metadata.sorted_tables,
87-
ssg_module.table_generator_dict,
88-
ssg_module.story_generator_list,
90+
orm_metadata.sorted_tables,
91+
table_generator_dict,
92+
story_generator_list,
8993
num_passes,
9094
)
9195
if verbose:
@@ -136,7 +140,8 @@ def create_tables(
136140
echo("Creating tables.")
137141

138142
orm_module = import_file(orm_file)
139-
create_db_tables(orm_module.Base.metadata)
143+
orm_metadata: MetaData = orm_module.Base.metadata
144+
create_db_tables(orm_metadata)
140145

141146
if verbose:
142147
echo("Tables created.")

0 commit comments

Comments
 (0)