Skip to content

Commit 362452a

Browse files
authored
Merge pull request #162 from alan-turing-institute/log-row-counts
Log row counts for create-data
2 parents f5a66b7 + 79a77f8 commit 362452a

File tree

7 files changed

+125
-70
lines changed

7 files changed

+125
-70
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sqlsynthgen"
3-
version = "0.4.0"
3+
version = "0.4.1"
44
description = "Synthetic SQL data generator"
55
authors = ["Iain <[email protected]>"]
66
license = "MIT"

sqlsynthgen/create.py

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
"""Functions and classes to create and populate the target database."""
2+
from collections import Counter
23
from typing import Any, Generator, Mapping, Sequence, Tuple
34

45
from sqlalchemy import Connection, insert
@@ -10,6 +11,7 @@
1011
from sqlsynthgen.utils import create_db_engine, get_sync_engine, logger
1112

1213
Story = Generator[Tuple[str, dict[str, Any]], dict[str, Any], None]
14+
RowCounts = Counter[str]
1315

1416

1517
def create_db_tables(metadata: MetaData) -> None:
@@ -57,7 +59,7 @@ def create_db_data(
5759
table_generator_dict: Mapping[str, TableGenerator],
5860
story_generator_list: Sequence[Mapping[str, Any]],
5961
num_passes: int,
60-
) -> None:
62+
) -> RowCounts:
6163
"""Connect to a database and populate it with data."""
6264
settings = get_settings()
6365
dst_dsn: str = settings.dst_dsn or ""
@@ -67,27 +69,30 @@ def create_db_data(
6769
create_db_engine(dst_dsn, schema_name=settings.dst_schema)
6870
)
6971

72+
row_counts: Counter[str] = Counter()
7073
with dst_engine.connect() as dst_conn:
7174
for _ in range(num_passes):
72-
populate(
75+
row_counts += populate(
7376
dst_conn,
7477
sorted_tables,
7578
table_generator_dict,
7679
story_generator_list,
7780
)
81+
return row_counts
7882

7983

8084
def _populate_story(
8185
story: Story,
8286
table_dict: Mapping[str, Table],
8387
table_generator_dict: Mapping[str, TableGenerator],
8488
dst_conn: Connection,
85-
) -> None:
89+
) -> RowCounts:
8690
"""Write to the database all the rows created by the given story."""
8791
# Loop over the rows generated by the story, insert them into their
8892
# respective tables. Ideally this would say
8993
# `for table_name, provided_values in story:`
9094
# but we have to loop more manually to be able to use the `send` function.
95+
row_counts: Counter[str] = Counter()
9196
try:
9297
table_name, provided_values = next(story)
9398
while True:
@@ -111,19 +116,22 @@ def _populate_story(
111116
else:
112117
return_values = {}
113118
final_values = {**insert_values, **return_values}
119+
row_counts[table_name] = row_counts.get(table_name, 0) + 1
114120
table_name, provided_values = story.send(final_values)
115121
except StopIteration:
116122
# The story has finished, it has no more rows to generate
117123
pass
124+
return row_counts
118125

119126

120127
def populate(
121128
dst_conn: Connection,
122129
tables: Sequence[Table],
123130
table_generator_dict: Mapping[str, TableGenerator],
124131
story_generator_list: Sequence[Mapping[str, Any]],
125-
) -> None:
132+
) -> RowCounts:
126133
"""Populate a database schema with synthetic data."""
134+
row_counts: Counter[str] = Counter()
127135
table_dict = {table.name: table for table in tables}
128136
# Generate stories
129137
# Each story generator returns a python generator (an unfortunate naming clash with
@@ -141,9 +149,11 @@ def populate(
141149
)
142150
for name, story in stories:
143151
# Run the inserts for each story within a transaction.
144-
logger.debug("Generating data for story %s", name)
152+
logger.debug('Generating data for story "%s".', name)
145153
with dst_conn.begin():
146-
_populate_story(story, table_dict, table_generator_dict, dst_conn)
154+
row_counts += _populate_story(
155+
story, table_dict, table_generator_dict, dst_conn
156+
)
147157

148158
# Generate individual rows, table by table.
149159
for table in tables:
@@ -154,9 +164,11 @@ def populate(
154164
table_generator = table_generator_dict[table.name]
155165
if table_generator.num_rows_per_pass == 0:
156166
continue
157-
logger.debug("Generating data for table %s", table.name)
167+
logger.debug('Generating data for table "%s".', table.name)
158168
# Run all the inserts for one table in a transaction
159169
with dst_conn.begin():
160170
for _ in range(table_generator.num_rows_per_pass):
161171
stmt = insert(table).values(table_generator(dst_conn))
162172
dst_conn.execute(stmt)
173+
row_counts[table.name] = row_counts.get(table.name, 0) + 1
174+
return row_counts

sqlsynthgen/main.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -96,13 +96,22 @@ def create_data(
9696
orm_metadata = get_orm_metadata(orm_module, tables_config)
9797
table_generator_dict = ssg_module.table_generator_dict
9898
story_generator_list = ssg_module.story_generator_list
99-
create_db_data(
99+
row_counts = create_db_data(
100100
orm_metadata.sorted_tables,
101101
table_generator_dict,
102102
story_generator_list,
103103
num_passes,
104104
)
105-
logger.debug("Data created in %s passes.", num_passes)
105+
logger.debug(
106+
"Data created in %s %s.", num_passes, "pass" if num_passes == 1 else "passes"
107+
)
108+
for table_name, row_count in row_counts.items():
109+
logger.debug(
110+
"%s: %s %s created.",
111+
table_name,
112+
row_count,
113+
"row" if row_count == 1 else "rows",
114+
)
106115

107116

108117
@app.command()

sqlsynthgen/remove.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ def remove_db_data(
2929
for table in reversed(metadata.sorted_tables):
3030
# We presume that all tables that aren't vocab should be truncated
3131
if table.name not in ssg_module.vocab_dict:
32-
logger.debug("Truncating table %s", table.name)
32+
logger.debug('Truncating table "%s".', table.name)
3333
dst_conn.execute(delete(table))
3434
dst_conn.commit()
3535

@@ -50,7 +50,7 @@ def remove_db_vocab(
5050
for table in reversed(metadata.sorted_tables):
5151
# We presume that all tables that are vocab should be truncated
5252
if table.name in ssg_module.vocab_dict:
53-
logger.debug("Truncating vocabulary table %s", table.name)
53+
logger.debug('Truncating vocabulary table "%s".', table.name)
5454
dst_conn.execute(delete(table))
5555
dst_conn.commit()
5656

tests/test_create.py

Lines changed: 27 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""Tests for the create module."""
22
import itertools as itt
3+
from collections import Counter
34
from pathlib import Path
45
from typing import Any, Generator, Tuple
56
from unittest.mock import MagicMock, call, patch
@@ -34,11 +35,13 @@ def test_create_db_data(
3435
) -> None:
3536
"""Test the generate function."""
3637
mock_get_settings.return_value = get_test_settings()
38+
mock_populate.return_value = {}
3739

3840
num_passes = 23
39-
create_db_data([], {}, [], num_passes)
41+
row_counts = create_db_data([], {}, [], num_passes)
4042

4143
self.assertEqual(len(mock_populate.call_args_list), num_passes)
44+
self.assertEqual(row_counts, {})
4245
mock_create_engine.assert_called()
4346

4447
@patch("sqlsynthgen.create.get_settings")
@@ -62,13 +65,15 @@ def test_populate(self) -> None:
6265

6366
def story() -> Generator[Tuple[str, dict], None, None]:
6467
"""Mock story."""
65-
yield "table_name", {}
68+
yield table_name, {}
6669

6770
def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
6871
"""A function that returns mock stories."""
6972
return story()
7073

71-
for num_stories_per_pass, num_rows_per_pass in itt.product([0, 2], [0, 3]):
74+
for num_stories_per_pass, num_rows_per_pass, num_initial_rows in itt.product(
75+
[0, 2], [0, 3], [0, 17]
76+
):
7277
with patch("sqlsynthgen.create.insert") as mock_insert:
7378
mock_values = mock_insert.return_value.values
7479
mock_dst_conn = MagicMock(spec=Connection)
@@ -78,9 +83,10 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
7883
mock_gen = MagicMock(spec=TableGenerator)
7984
mock_gen.num_rows_per_pass = num_rows_per_pass
8085
mock_gen.return_value = {}
86+
row_counts = Counter(
87+
{table_name: num_initial_rows} if num_initial_rows > 0 else {}
88+
)
8189

82-
tables: list[Table] = [mock_table]
83-
row_generators: dict[str, TableGenerator] = {table_name: mock_gen}
8490
story_generators: list[dict[str, Any]] = (
8591
[
8692
{
@@ -92,13 +98,24 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
9298
if num_stories_per_pass > 0
9399
else []
94100
)
95-
populate(
101+
row_counts += populate(
96102
mock_dst_conn,
97-
tables,
98-
row_generators,
103+
[mock_table],
104+
{table_name: mock_gen},
99105
story_generators,
100106
)
101107

108+
expected_row_count = (
109+
num_stories_per_pass + num_rows_per_pass + num_initial_rows
110+
)
111+
self.assertEqual(
112+
Counter(
113+
{table_name: expected_row_count}
114+
if expected_row_count > 0
115+
else {}
116+
),
117+
row_counts,
118+
)
102119
self.assertListEqual(
103120
[call(mock_dst_conn)] * (num_stories_per_pass + num_rows_per_pass),
104121
mock_gen.call_args_list,
@@ -135,7 +152,8 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
135152
"three": mock_gen_three,
136153
}
137154

138-
populate(mock_dst_conn, tables, row_generators, [])
155+
row_counts = populate(mock_dst_conn, tables, row_generators, [])
156+
self.assertEqual(row_counts, {"two": 1, "three": 1})
139157
self.assertListEqual(
140158
[call(mock_table_two), call(mock_table_three)], mock_insert.call_args_list
141159
)

0 commit comments

Comments
 (0)