Skip to content

Commit 9142d91

Browse files
authored
Merge pull request #166 from alan-turing-institute/log-row-counts-iain
Log row counts (Iain's additions)
2 parents 9b5592a + 8596657 commit 9142d91

File tree

7 files changed

+91
-70
lines changed

7 files changed

+91
-70
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sqlsynthgen"
3-
version = "0.3.3"
3+
version = "0.4.1"
44
description = "Synthetic SQL data generator"
55
authors = ["Iain <[email protected]>"]
66
license = "MIT"

sqlsynthgen/create.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Functions and classes to create and populate the target database."""
2-
from typing import Any, Generator, Mapping, Optional, Sequence, Tuple
2+
from collections import Counter
3+
from typing import Any, Generator, Mapping, Sequence, Tuple
34

45
from sqlalchemy import Connection, insert
56
from sqlalchemy.exc import IntegrityError
@@ -10,7 +11,7 @@
1011
from sqlsynthgen.utils import create_db_engine, get_sync_engine, logger
1112

1213
Story = Generator[Tuple[str, dict[str, Any]], dict[str, Any], None]
13-
RowCounts = dict[str, int]
14+
RowCounts = Counter[str]
1415

1516

1617
def create_db_tables(metadata: MetaData) -> None:
@@ -68,15 +69,14 @@ def create_db_data(
6869
create_db_engine(dst_dsn, schema_name=settings.dst_schema)
6970
)
7071

71-
row_counts: RowCounts = {}
72+
row_counts: Counter[str] = Counter()
7273
with dst_engine.connect() as dst_conn:
7374
for _ in range(num_passes):
74-
row_counts = populate(
75+
row_counts += populate(
7576
dst_conn,
7677
sorted_tables,
7778
table_generator_dict,
7879
story_generator_list,
79-
row_counts,
8080
)
8181
return row_counts
8282

@@ -86,13 +86,13 @@ def _populate_story(
8686
table_dict: Mapping[str, Table],
8787
table_generator_dict: Mapping[str, TableGenerator],
8888
dst_conn: Connection,
89-
row_counts: RowCounts,
9089
) -> RowCounts:
9190
"""Write to the database all the rows created by the given story."""
9291
# Loop over the rows generated by the story and insert them into their
9392
# respective tables. Ideally this would say
9493
# `for table_name, provided_values in story:`
9594
# but we have to loop more manually to be able to use the `send` function.
95+
row_counts: Counter[str] = Counter()
9696
try:
9797
table_name, provided_values = next(story)
9898
while True:
@@ -129,11 +129,9 @@ def populate(
129129
tables: Sequence[Table],
130130
table_generator_dict: Mapping[str, TableGenerator],
131131
story_generator_list: Sequence[Mapping[str, Any]],
132-
row_counts: Optional[RowCounts] = None,
133132
) -> RowCounts:
134133
"""Populate a database schema with synthetic data."""
135-
if row_counts is None:
136-
row_counts = {}
134+
row_counts: Counter[str] = Counter()
137135
table_dict = {table.name: table for table in tables}
138136
# Generate stories
139137
# Each story generator returns a Python generator (an unfortunate naming clash with
@@ -151,10 +149,10 @@ def populate(
151149
)
152150
for name, story in stories:
153151
# Run the inserts for each story within a transaction.
154-
logger.debug("Generating data for story %s", name)
152+
logger.debug('Generating data for story "%s".', name)
155153
with dst_conn.begin():
156-
row_counts = _populate_story(
157-
story, table_dict, table_generator_dict, dst_conn, row_counts
154+
row_counts += _populate_story(
155+
story, table_dict, table_generator_dict, dst_conn
158156
)
159157

160158
# Generate individual rows, table by table.
@@ -166,7 +164,7 @@ def populate(
166164
table_generator = table_generator_dict[table.name]
167165
if table_generator.num_rows_per_pass == 0:
168166
continue
169-
logger.debug("Generating data for table %s", table.name)
167+
logger.debug('Generating data for table "%s".', table.name)
170168
# Run all the inserts for one table in a transaction
171169
with dst_conn.begin():
172170
for _ in range(table_generator.num_rows_per_pass):

sqlsynthgen/main.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,16 @@ def create_data(
102102
story_generator_list,
103103
num_passes,
104104
)
105-
logger.debug("Data created in %s passes.", num_passes)
105+
logger.debug(
106+
"Data created in %s %s.", num_passes, "pass" if num_passes == 1 else "passes"
107+
)
106108
for table_name, row_count in row_counts.items():
107-
logger.debug("%s: %s rows created", table_name, row_count)
109+
logger.debug(
110+
"%s: %s %s created.",
111+
table_name,
112+
row_count,
113+
"row" if row_count == 1 else "rows",
114+
)
108115

109116

110117
@app.command()

sqlsynthgen/remove.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def remove_db_data(
2929
for table in reversed(metadata.sorted_tables):
3030
# We presume that all tables that aren't vocab should be truncated
3131
if table.name not in ssg_module.vocab_dict:
32-
logger.debug("Truncating table %s", table.name)
32+
logger.debug('Truncating table "%s".', table.name)
3333
dst_conn.execute(delete(table))
3434
dst_conn.commit()
3535

@@ -50,7 +50,7 @@ def remove_db_vocab(
5050
for table in reversed(metadata.sorted_tables):
5151
# We presume that all tables that are vocab should be truncated
5252
if table.name in ssg_module.vocab_dict:
53-
logger.debug("Truncating vocabulary table %s", table.name)
53+
logger.debug('Truncating vocabulary table "%s".', table.name)
5454
dst_conn.execute(delete(table))
5555
dst_conn.commit()
5656

tests/test_create.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for the create module."""
22
import itertools as itt
3+
from collections import Counter
34
from pathlib import Path
45
from typing import Any, Generator, Tuple
56
from unittest.mock import MagicMock, call, patch
@@ -82,7 +83,7 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
8283
mock_gen = MagicMock(spec=TableGenerator)
8384
mock_gen.num_rows_per_pass = num_rows_per_pass
8485
mock_gen.return_value = {}
85-
row_counts = (
86+
row_counts = Counter(
8687
{table_name: num_initial_rows} if num_initial_rows > 0 else {}
8788
)
8889

@@ -97,20 +98,23 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
9798
if num_stories_per_pass > 0
9899
else []
99100
)
100-
row_counts = populate(
101+
row_counts += populate(
101102
mock_dst_conn,
102103
[mock_table],
103104
{table_name: mock_gen},
104105
story_generators,
105-
row_counts,
106106
)
107107

108108
expected_row_count = (
109109
num_stories_per_pass + num_rows_per_pass + num_initial_rows
110110
)
111111
self.assertEqual(
112+
Counter(
113+
{table_name: expected_row_count}
114+
if expected_row_count > 0
115+
else {}
116+
),
112117
row_counts,
113-
{table_name: expected_row_count} if expected_row_count > 0 else {},
114118
)
115119
self.assertListEqual(
116120
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass),
@@ -148,7 +152,7 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
148152
"three": mock_gen_three,
149153
}
150154

151-
row_counts = populate(mock_dst_conn, tables, row_generators, [], {})
155+
row_counts = populate(mock_dst_conn, tables, row_generators, [])
152156
self.assertEqual(row_counts, {"two": 1, "three": 1})
153157
self.assertListEqual(
154158
[call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list
@@ -221,4 +225,4 @@ def my_story() -> Story:
221225

222226
with engine.connect() as conn:
223227
with conn.begin():
224-
_populate_story(my_story(), dict(self.metadata.tables), {}, conn, {})
228+
_populate_story(my_story(), dict(self.metadata.tables), {}, conn)

tests/test_functional.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -336,40 +336,40 @@ def test_workflow_maximal_args(self) -> None:
336336
self.assertSuccess(completed_process)
337337
self.assertEqual(
338338
"Creating data.\n"
339-
"Generating data for story story_generators.short_story\n"
340-
"Generating data for story story_generators.short_story\n"
341-
"Generating data for story story_generators.short_story\n"
342-
"Generating data for story story_generators.full_row_story\n"
343-
"Generating data for story story_generators.long_story\n"
344-
"Generating data for story story_generators.long_story\n"
345-
"Generating data for table data_type_test\n"
346-
"Generating data for table no_pk_test\n"
347-
"Generating data for table person\n"
348-
"Generating data for table unique_constraint_test\n"
349-
"Generating data for table unique_constraint_test2\n"
350-
"Generating data for table test_entity\n"
351-
"Generating data for table hospital_visit\n"
352-
"Generating data for story story_generators.short_story\n"
353-
"Generating data for story story_generators.short_story\n"
354-
"Generating data for story story_generators.short_story\n"
355-
"Generating data for story story_generators.full_row_story\n"
356-
"Generating data for story story_generators.long_story\n"
357-
"Generating data for story story_generators.long_story\n"
358-
"Generating data for table data_type_test\n"
359-
"Generating data for table no_pk_test\n"
360-
"Generating data for table person\n"
361-
"Generating data for table unique_constraint_test\n"
362-
"Generating data for table unique_constraint_test2\n"
363-
"Generating data for table test_entity\n"
364-
"Generating data for table hospital_visit\n"
339+
'Generating data for story "story_generators.short_story".\n'
340+
'Generating data for story "story_generators.short_story".\n'
341+
'Generating data for story "story_generators.short_story".\n'
342+
'Generating data for story "story_generators.full_row_story".\n'
343+
'Generating data for story "story_generators.long_story".\n'
344+
'Generating data for story "story_generators.long_story".\n'
345+
'Generating data for table "data_type_test".\n'
346+
'Generating data for table "no_pk_test".\n'
347+
'Generating data for table "person".\n'
348+
'Generating data for table "unique_constraint_test".\n'
349+
'Generating data for table "unique_constraint_test2".\n'
350+
'Generating data for table "test_entity".\n'
351+
'Generating data for table "hospital_visit".\n'
352+
'Generating data for story "story_generators.short_story".\n'
353+
'Generating data for story "story_generators.short_story".\n'
354+
'Generating data for story "story_generators.short_story".\n'
355+
'Generating data for story "story_generators.full_row_story".\n'
356+
'Generating data for story "story_generators.long_story".\n'
357+
'Generating data for story "story_generators.long_story".\n'
358+
'Generating data for table "data_type_test".\n'
359+
'Generating data for table "no_pk_test".\n'
360+
'Generating data for table "person".\n'
361+
'Generating data for table "unique_constraint_test".\n'
362+
'Generating data for table "unique_constraint_test2".\n'
363+
'Generating data for table "test_entity".\n'
364+
'Generating data for table "hospital_visit".\n'
365365
"Data created in 2 passes.\n"
366-
f"person: {2*(3+1+2+2)} rows created\n"
367-
f"hospital_visit: {2*(2*2+3)} rows created\n"
368-
"data_type_test: 2 rows created\n"
369-
"no_pk_test: 2 rows created\n"
370-
"unique_constraint_test: 2 rows created\n"
371-
"unique_constraint_test2: 2 rows created\n"
372-
"test_entity: 2 rows created\n",
366+
f"person: {2*(3+1+2+2)} rows created.\n"
367+
f"hospital_visit: {2*(2*2+3)} rows created.\n"
368+
"data_type_test: 2 rows created.\n"
369+
"no_pk_test: 2 rows created.\n"
370+
"unique_constraint_test: 2 rows created.\n"
371+
"unique_constraint_test2: 2 rows created.\n"
372+
"test_entity: 2 rows created.\n",
373373
completed_process.stdout.decode("utf-8"),
374374
)
375375

@@ -390,13 +390,13 @@ def test_workflow_maximal_args(self) -> None:
390390
self.assertSuccess(completed_process)
391391
self.assertEqual(
392392
"Truncating non-vocabulary tables.\n"
393-
"Truncating table hospital_visit\n"
394-
"Truncating table test_entity\n"
395-
"Truncating table unique_constraint_test2\n"
396-
"Truncating table unique_constraint_test\n"
397-
"Truncating table person\n"
398-
"Truncating table no_pk_test\n"
399-
"Truncating table data_type_test\n"
393+
'Truncating table "hospital_visit".\n'
394+
'Truncating table "test_entity".\n'
395+
'Truncating table "unique_constraint_test2".\n'
396+
'Truncating table "unique_constraint_test".\n'
397+
'Truncating table "person".\n'
398+
'Truncating table "no_pk_test".\n'
399+
'Truncating table "data_type_test".\n'
400400
"Non-vocabulary tables truncated.\n",
401401
completed_process.stdout.decode("utf-8"),
402402
)
@@ -418,11 +418,11 @@ def test_workflow_maximal_args(self) -> None:
418418
self.assertSuccess(completed_process)
419419
self.assertEqual(
420420
"Truncating vocabulary tables.\n"
421-
"Truncating vocabulary table concept\n"
422-
"Truncating vocabulary table concept_type\n"
423-
"Truncating vocabulary table ref_to_unignorable_table\n"
424-
"Truncating vocabulary table mitigation_type\n"
425-
"Truncating vocabulary table empty_vocabulary\n"
421+
'Truncating vocabulary table "concept".\n'
422+
'Truncating vocabulary table "concept_type".\n'
423+
'Truncating vocabulary table "ref_to_unignorable_table".\n'
424+
'Truncating vocabulary table "mitigation_type".\n'
425+
'Truncating vocabulary table "empty_vocabulary".\n'
426426
"Vocabulary tables truncated.\n",
427427
completed_process.stdout.decode("utf-8"),
428428
)

tests/test_main.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,15 @@ def test_create_tables(
152152
mock_create.assert_called_once_with(mock_import.return_value.Base.metadata)
153153
self.assertSuccess(result)
154154

155+
@patch("sqlsynthgen.main.logger")
155156
@patch("sqlsynthgen.main.import_file")
156157
@patch("sqlsynthgen.main.create_db_data")
157-
def test_create_data(self, mock_create: MagicMock, mock_import: MagicMock) -> None:
158+
def test_create_data(
159+
self, mock_create: MagicMock, mock_import: MagicMock, mock_logger: MagicMock
160+
) -> None:
158161
"""Test the create-data sub-command."""
159162

163+
mock_create.return_value = {"a": 1}
160164
result = runner.invoke(
161165
app,
162166
[
@@ -180,6 +184,14 @@ def test_create_data(self, mock_create: MagicMock, mock_import: MagicMock) -> No
180184
)
181185
self.assertSuccess(result)
182186

187+
mock_logger.debug.assert_has_calls(
188+
[
189+
call("Creating data."),
190+
call("Data created in %s %s.", 1, "pass"),
191+
call("%s: %s %s created.", "a", 1, "row"),
192+
]
193+
)
194+
183195
@patch("sqlsynthgen.main.Path")
184196
@patch("sqlsynthgen.main.make_tables_file")
185197
@patch("sqlsynthgen.main.get_settings")

0 commit comments

Comments
 (0)