Skip to content

Commit b5fdcc8

Browse files
authored
Merge pull request #140 from alan-turing-institute/sqlalchemy2-iain
column defaults
2 parents 7b5bea3 + 050cec6 commit b5fdcc8

File tree

4 files changed

+93
-36
lines changed

4 files changed

+93
-36
lines changed

.github/workflows/pre-commit.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ jobs:
5757
uses: actions/cache@v3
5858
with:
5959
path: ${{ env.PRE_COMMIT_HOME }}
60-
key: hooks-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}
60+
key: hooks-${{ runner.os }}-${{ hashFiles('.pre-commit-config.yaml') }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('poetry.lock') }}
6161
- name: Install Pre-Commit Hooks
6262
shell: bash
6363
if: steps.pre-commit-cache.outputs.cache-hit != 'true'

poetry.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sqlsynthgen/create.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ def _populate_story(
9494
else:
9595
default_values = {}
9696
insert_values = {**default_values, **provided_values}
97-
stmt = insert(table).values(insert_values)
97+
stmt = insert(table).values(insert_values).return_defaults()
9898
cursor = dst_conn.execute(stmt)
9999
# We need to return all the default values etc. to the generator,
100100
# because other parts of the story may refer to them.
101101
if cursor.returned_defaults:
102102
# pylint: disable=protected-access
103-
return_values = dict(cursor.returned_defaults._mapping.items())
103+
return_values = cursor.returned_defaults._mapping
104104
# pylint: enable=protected-access
105105
else:
106106
return_values = {}

tests/test_create.py

Lines changed: 88 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
"""Tests for the create module."""
22
import itertools as itt
3+
from pathlib import Path
34
from typing import Any, Generator, Tuple
45
from unittest.mock import MagicMock, call, patch
56

7+
from sqlalchemy import Column, Integer, create_engine
8+
from sqlalchemy.orm import declarative_base
9+
610
from sqlsynthgen.create import (
11+
Story,
12+
_populate_story,
713
create_db_data,
814
create_db_tables,
915
create_db_vocab,
1016
populate,
1117
)
12-
from tests.utils import SSGTestCase, get_test_settings
18+
from tests.utils import RequiresDBTestCase, SSGTestCase, get_test_settings, run_psql
1319

1420

1521
class MyTestCase(SSGTestCase):
@@ -48,8 +54,7 @@ def test_create_db_tables(
4854
)
4955
mock_meta.create_all.assert_called_once_with(mock_create_engine.return_value)
5056

51-
@patch("sqlsynthgen.create.insert")
52-
def test_populate(self, mock_insert: MagicMock) -> None:
57+
def test_populate(self) -> None:
5358
"""Test the populate function."""
5459
table_name = "table_name"
5560

@@ -62,34 +67,47 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
6267
return story()
6368

6469
for num_stories_per_pass, num_rows_per_pass in itt.product([0, 2], [0, 3]):
65-
mock_dst_conn = MagicMock()
66-
mock_dst_conn.execute.return_value.returned_defaults = {}
67-
mock_table = MagicMock()
68-
mock_table.name = table_name
69-
mock_gen = MagicMock()
70-
mock_gen.num_rows_per_pass = num_rows_per_pass
71-
mock_gen.return_value = {}
72-
73-
tables = [mock_table]
74-
row_generators = {table_name: mock_gen}
75-
story_generators = (
76-
[{"name": mock_story_gen, "num_stories_per_pass": num_stories_per_pass}]
77-
if num_stories_per_pass > 0
78-
else []
79-
)
80-
populate(mock_dst_conn, tables, row_generators, story_generators)
81-
82-
mock_gen.assert_has_calls(
83-
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass)
84-
)
85-
mock_insert.return_value.values.assert_has_calls(
86-
[call(mock_gen.return_value)]
87-
* (num_stories_per_pass + num_rows_per_pass)
88-
)
89-
mock_dst_conn.execute.assert_has_calls(
90-
[call(mock_insert.return_value.values.return_value)]
91-
* (num_stories_per_pass + num_rows_per_pass)
92-
)
70+
with patch("sqlsynthgen.create.insert") as mock_insert:
71+
mock_values = mock_insert.return_value.values
72+
mock_dst_conn = MagicMock()
73+
mock_dst_conn.execute.return_value.returned_defaults = {}
74+
mock_table = MagicMock()
75+
mock_table.name = table_name
76+
mock_gen = MagicMock()
77+
mock_gen.num_rows_per_pass = num_rows_per_pass
78+
mock_gen.return_value = {}
79+
80+
tables = [mock_table]
81+
row_generators = {table_name: mock_gen}
82+
story_generators = (
83+
[
84+
{
85+
"name": mock_story_gen,
86+
"num_stories_per_pass": num_stories_per_pass,
87+
}
88+
]
89+
if num_stories_per_pass > 0
90+
else []
91+
)
92+
populate(mock_dst_conn, tables, row_generators, story_generators)
93+
94+
self.assertListEqual(
95+
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass),
96+
mock_gen.call_args_list,
97+
)
98+
self.assertListEqual(
99+
[call(mock_gen.return_value)]
100+
* (num_stories_per_pass + num_rows_per_pass),
101+
mock_values.call_args_list,
102+
)
103+
self.assertListEqual(
104+
(
105+
[call(mock_values.return_value.return_defaults.return_value)]
106+
* num_stories_per_pass
107+
)
108+
+ ([call(mock_values.return_value)] * num_rows_per_pass),
109+
mock_dst_conn.execute.call_args_list,
110+
)
93111

94112
@patch("sqlsynthgen.create.insert")
95113
def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
@@ -131,3 +149,42 @@ def test_create_db_vocab(
131149
)
132150
# Running the same insert twice should be fine.
133151
create_db_vocab(vocab_list)
152+
153+
154+
class TestStoryDefaults(RequiresDBTestCase):
155+
"""Test that we can handle column defaults in stories."""
156+
157+
# pylint: disable=invalid-name
158+
Base = declarative_base()
159+
# pylint: enable=invalid-name
160+
metadata = Base.metadata
161+
162+
class ColumnDefaultsTable(Base): # type: ignore
163+
"""A SQLAlchemy model."""
164+
165+
__tablename__ = "column_defaults"
166+
someval = Column(Integer, primary_key=True)
167+
otherval = Column(Integer, server_default="8")
168+
169+
def setUp(self) -> None:
170+
"""Ensure we have an empty DB to work with."""
171+
dump_file_path = Path("dst.dump")
172+
examples_dir = Path("tests/examples")
173+
run_psql(examples_dir / dump_file_path)
174+
175+
def test_populate(self) -> None:
176+
"""Check that we can populate a table that has column defaults."""
177+
engine = create_engine(
178+
"postgresql://postgres:password@localhost:5432/dst",
179+
)
180+
self.metadata.create_all(engine)
181+
182+
def my_story() -> Story:
183+
"""A story generator."""
184+
first_row = yield "column_defaults", {}
185+
self.assertEqual(1, first_row["someval"])
186+
self.assertEqual(8, first_row["otherval"])
187+
188+
with engine.connect() as conn:
189+
with conn.begin():
190+
_populate_story(my_story(), dict(self.metadata.tables), {}, conn)

0 commit comments

Comments
 (0)