Skip to content

Commit 8259c0f

Browse files
authored
Merge pull request #137 from alan-turing-institute/sqlalchemy2
Sqlalchemy2
2 parents 35bdc2d + b5fdcc8 commit 8259c0f

File tree

17 files changed

+521
-431
lines changed

17 files changed

+521
-431
lines changed

.github/workflows/pre-commit.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ jobs:
5757
uses: actions/cache@v3
5858
with:
5959
path: ${{ env.PRE_COMMIT_HOME }}
60-
key: hooks-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}
60+
key: hooks-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }}
6161
- name: Install Pre-Commit Hooks
6262
shell: bash
6363
if: steps.pre-commit-cache.outputs.cache-hit != 'true'

mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ warn_return_any = True
1010
show_error_codes = True
1111
warn_unused_ignores = True
1212
ignore_missing_imports = True
13+
plugins = pydantic.mypy

poetry.lock

Lines changed: 381 additions & 375 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ sqlalchemy-utils = "^0.38.3"
2323
mimesis = "^6.1.1"
2424
typer = "^0.7.0"
2525
pyyaml = "^5.0"
26-
sqlalchemy = {version = "^1.4", extras = ["asyncio"]}
26+
sqlalchemy = "^2"
2727
sphinx-rtd-theme = {version = "^1.2.0", optional = true}
2828
sphinxcontrib-napoleon = {version = "^0.7", optional = true}
29-
smartnoise-sql = "^0.2.11"
29+
smartnoise-sql = "^1"
3030
jinja2 = "^3.1.2"
3131
black = "^23.3.0"
3232
jsonschema = "^4.17.3"
33-
sqlacodegen = "3.0.0rc1"
33+
sqlacodegen = "^3.0.0rc3"
3434
asyncpg = "^0.27.0"
3535
greenlet = "^2.0.2"
3636
pymysql = "^1.1.0"
@@ -39,7 +39,7 @@ pandas = "^2"
3939
[tool.poetry.group.dev.dependencies]
4040
isort = "^5.10.1"
4141
pylint = "^2.15.8"
42-
mypy = "^0.991"
42+
mypy = "^1.5"
4343
types-pyyaml = "^6.0.12.4"
4444
pydocstyle = "^6.3.0"
4545
restructuredtext-lint = "^1.4.0"

sqlsynthgen/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def load(self, connection: Any) -> None:
3535
try:
3636
stmt = insert(self.table).values(list(rows))
3737
connection.execute(stmt)
38+
connection.commit()
3839
except SQLAlchemyError as e:
3940
logging.warning(
4041
"Error inserting rows into table %s: %s", self.table.fullname, e

sqlsynthgen/create.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,16 @@ def _populate_story(
9494
else:
9595
default_values = {}
9696
insert_values = {**default_values, **provided_values}
97-
stmt = insert(table).values(insert_values)
97+
stmt = insert(table).values(insert_values).return_defaults()
9898
cursor = dst_conn.execute(stmt)
9999
# We need to return all the default values etc. to the generator,
100100
# because other parts of the story may refer to them.
101-
return_values = dict(cursor.returned_defaults or {})
101+
if cursor.returned_defaults:
102+
# pylint: disable=protected-access
103+
return_values = cursor.returned_defaults._mapping
104+
# pylint: enable=protected-access
105+
else:
106+
return_values = {}
102107
final_values = {**insert_values, **return_values}
103108
table_name, provided_values = story.send(final_values)
104109
except StopIteration:

sqlsynthgen/make.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def _get_provider_for_column(column: Any) -> Tuple[List[str], str, List[str]]:
222222
(sqltypes.DateTime, False): "generic.datetime.datetime",
223223
(sqltypes.Numeric, False): "generic.numeric.float_number",
224224
(sqltypes.LargeBinary, False): "generic.bytes_provider.bytes",
225+
(sqltypes.Uuid, False): "generic.cryptographic.uuid",
225226
(postgresql.UUID, False): "generic.cryptographic.uuid",
226227
(sqltypes.String, False): "generic.text.color",
227228
(sqltypes.String, True): "generic.person.password",

sqlsynthgen/providers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from mimesis import Datetime, Text
77
from mimesis.providers.base import BaseDataProvider, BaseProvider
8-
from sqlalchemy.sql import func, select
8+
from sqlalchemy.sql import functions, select
99

1010

1111
class ColumnValueProvider(BaseProvider):
@@ -18,7 +18,7 @@ class Meta:
1818

1919
def column_value(self, db_connection: Any, orm_class: Any, column_name: str) -> Any:
2020
"""Return a random value from the column specified."""
21-
query = select(orm_class).order_by(func.random()).limit(1)
21+
query = select(orm_class).order_by(functions.random()).limit(1)
2222
random_row = db_connection.execute(query).first()
2323

2424
if random_row:

sqlsynthgen/remove.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def remove_db_data(orm_module: ModuleType, ssg_module: ModuleType) -> None:
1919
# We presume that all tables that aren't vocab should be truncated
2020
if table.name not in ssg_module.vocab_dict:
2121
dst_conn.execute(delete(table))
22+
dst_conn.commit()
2223

2324

2425
def remove_db_vocab(orm_module: ModuleType, ssg_module: ModuleType) -> None:
@@ -33,6 +34,7 @@ def remove_db_vocab(orm_module: ModuleType, ssg_module: ModuleType) -> None:
3334
# We presume that all tables that are vocab should be truncated
3435
if table.name in ssg_module.vocab_dict:
3536
dst_conn.execute(delete(table))
37+
dst_conn.commit()
3638

3739

3840
def remove_db_tables(orm_module: ModuleType) -> None:

sqlsynthgen/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from importlib import import_module
66
from pathlib import Path
77
from types import ModuleType
8-
from typing import Any, Final, Optional
8+
from typing import Any, Final, Optional, Union
99

1010
import yaml
1111
from jsonschema.exceptions import ValidationError
@@ -64,11 +64,11 @@ def import_file(file_path: str) -> ModuleType:
6464
return module
6565

6666

67-
def download_table(table: Any, engine: Any, yaml_file_name: str) -> None:
67+
def download_table(table: Any, engine: Any, yaml_file_name: Union[str, Path]) -> None:
6868
"""Download a Table and store it as a .yaml file."""
69-
stmt = select([table])
69+
stmt = select(table)
7070
with engine.connect() as conn:
71-
result = [dict(row) for row in conn.execute(stmt)]
71+
result = [dict(row) for row in conn.execute(stmt).mappings()]
7272

7373
with Path(yaml_file_name).open("w", newline="", encoding="utf-8") as yamlfile:
7474
yamlfile.write(yaml.dump(result))
@@ -83,7 +83,7 @@ def create_db_engine(
8383
"""Create a SQLAlchemy Engine."""
8484
if use_asyncio:
8585
async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://")
86-
engine = create_async_engine(async_dsn, **kwargs)
86+
engine: Any = create_async_engine(async_dsn, **kwargs)
8787
event_engine = engine.sync_engine
8888
else:
8989
engine = create_engine(db_dsn, **kwargs)
@@ -93,7 +93,7 @@ def create_db_engine(
9393

9494
@event.listens_for(event_engine, "connect", insert=True)
9595
def connect(dbapi_connection: Any, _: Any) -> None:
96-
set_search_path(dbapi_connection, schema_name) # type: ignore
96+
set_search_path(dbapi_connection, schema_name)
9797

9898
return engine
9999

@@ -105,7 +105,8 @@ def set_search_path(connection: Any, schema: str) -> None:
105105
connection.autocommit = True
106106

107107
cursor = connection.cursor()
108-
cursor.execute("SET search_path to %s;", (schema,))
108+
# Parametrised queries don't work with asyncpg, hence the f-string.
109+
cursor.execute(f"SET search_path TO {schema};")
109110
cursor.close()
110111

111112
connection.autocommit = existing_autocommit

0 commit comments

Comments
 (0)