11"""Tests for the create module."""
22import itertools as itt
3+ from pathlib import Path
34from typing import Any , Generator , Tuple
45from unittest .mock import MagicMock , call , patch
56
7+ from sqlalchemy import Column , Integer , create_engine
8+ from sqlalchemy .orm import declarative_base
9+
610from sqlsynthgen .create import (
11+ Story ,
12+ _populate_story ,
713 create_db_data ,
814 create_db_tables ,
915 create_db_vocab ,
1016 populate ,
1117)
12- from tests .utils import SSGTestCase , get_test_settings
18+ from tests .utils import RequiresDBTestCase , SSGTestCase , get_test_settings , run_psql
1319
1420
1521class MyTestCase (SSGTestCase ):
@@ -48,8 +54,7 @@ def test_create_db_tables(
4854 )
4955 mock_meta .create_all .assert_called_once_with (mock_create_engine .return_value )
5056
51- @patch ("sqlsynthgen.create.insert" )
52- def test_populate (self , mock_insert : MagicMock ) -> None :
57+ def test_populate (self ) -> None :
5358 """Test the populate function."""
5459 table_name = "table_name"
5560
@@ -62,34 +67,47 @@ def mock_story_gen(_: Any) -> Generator[Tuple[str, dict], None, None]:
6267 return story ()
6368
6469 for num_stories_per_pass , num_rows_per_pass in itt .product ([0 , 2 ], [0 , 3 ]):
65- mock_dst_conn = MagicMock ()
66- mock_dst_conn .execute .return_value .returned_defaults = {}
67- mock_table = MagicMock ()
68- mock_table .name = table_name
69- mock_gen = MagicMock ()
70- mock_gen .num_rows_per_pass = num_rows_per_pass
71- mock_gen .return_value = {}
72-
73- tables = [mock_table ]
74- row_generators = {table_name : mock_gen }
75- story_generators = (
76- [{"name" : mock_story_gen , "num_stories_per_pass" : num_stories_per_pass }]
77- if num_stories_per_pass > 0
78- else []
79- )
80- populate (mock_dst_conn , tables , row_generators , story_generators )
81-
82- mock_gen .assert_has_calls (
83- [call (mock_dst_conn )] * (num_stories_per_pass + num_rows_per_pass )
84- )
85- mock_insert .return_value .values .assert_has_calls (
86- [call (mock_gen .return_value )]
87- * (num_stories_per_pass + num_rows_per_pass )
88- )
89- mock_dst_conn .execute .assert_has_calls (
90- [call (mock_insert .return_value .values .return_value )]
91- * (num_stories_per_pass + num_rows_per_pass )
92- )
70+ with patch ("sqlsynthgen.create.insert" ) as mock_insert :
71+ mock_values = mock_insert .return_value .values
72+ mock_dst_conn = MagicMock ()
73+ mock_dst_conn .execute .return_value .returned_defaults = {}
74+ mock_table = MagicMock ()
75+ mock_table .name = table_name
76+ mock_gen = MagicMock ()
77+ mock_gen .num_rows_per_pass = num_rows_per_pass
78+ mock_gen .return_value = {}
79+
80+ tables = [mock_table ]
81+ row_generators = {table_name : mock_gen }
82+ story_generators = (
83+ [
84+ {
85+ "name" : mock_story_gen ,
86+ "num_stories_per_pass" : num_stories_per_pass ,
87+ }
88+ ]
89+ if num_stories_per_pass > 0
90+ else []
91+ )
92+ populate (mock_dst_conn , tables , row_generators , story_generators )
93+
94+ self .assertListEqual (
95+ [call (mock_dst_conn )] * (num_stories_per_pass + num_rows_per_pass ),
96+ mock_gen .call_args_list ,
97+ )
98+ self .assertListEqual (
99+ [call (mock_gen .return_value )]
100+ * (num_stories_per_pass + num_rows_per_pass ),
101+ mock_values .call_args_list ,
102+ )
103+ self .assertListEqual (
104+ (
105+ [call (mock_values .return_value .return_defaults .return_value )]
106+ * num_stories_per_pass
107+ )
108+ + ([call (mock_values .return_value )] * num_rows_per_pass ),
109+ mock_dst_conn .execute .call_args_list ,
110+ )
93111
94112 @patch ("sqlsynthgen.create.insert" )
95113 def test_populate_diff_length (self , mock_insert : MagicMock ) -> None :
@@ -131,3 +149,42 @@ def test_create_db_vocab(
131149 )
132150 # Running the same insert twice should be fine.
133151 create_db_vocab (vocab_list )
152+
153+
154+ class TestStoryDefaults (RequiresDBTestCase ):
155+ """Test that we can handle column defaults in stories."""
156+
157+ # pylint: disable=invalid-name
158+ Base = declarative_base ()
159+ # pylint: enable=invalid-name
160+ metadata = Base .metadata
161+
162+ class ColumnDefaultsTable (Base ): # type: ignore
163+ """A SQLAlchemy model."""
164+
165+ __tablename__ = "column_defaults"
166+ someval = Column (Integer , primary_key = True )
167+ otherval = Column (Integer , server_default = "8" )
168+
169+ def setUp (self ) -> None :
170+ """Ensure we have an empty DB to work with."""
171+ dump_file_path = Path ("dst.dump" )
172+ examples_dir = Path ("tests/examples" )
173+ run_psql (examples_dir / dump_file_path )
174+
175+ def test_populate (self ) -> None :
176+ """Check that we can populate a table that has column defaults."""
177+ engine = create_engine (
178+ "postgresql://postgres:password@localhost:5432/dst" ,
179+ )
180+ self .metadata .create_all (engine )
181+
182+ def my_story () -> Story :
183+ """A story generator."""
184+ first_row = yield "column_defaults" , {}
185+ self .assertEqual (1 , first_row ["someval" ])
186+ self .assertEqual (8 , first_row ["otherval" ])
187+
188+ with engine .connect () as conn :
189+ with conn .begin ():
190+ _populate_story (my_story (), dict (self .metadata .tables ), {}, conn )
0 commit comments