1
- import copy
2
1
from io import StringIO
3
2
import json
4
3
from pathlib import Path
5
4
from parameterized import parameterized
6
5
import unittest
7
- from unittest .mock import MagicMock , Mock , patch
6
+ from unittest .mock import Mock , patch
8
7
9
8
from data_diff .cloud .datafold_api import (
10
9
TCloudApiDataSourceConfigSchema ,
13
12
TCloudApiDataSourceTestResult ,
14
13
TCloudDataSourceTestResult ,
15
14
TDsConfig ,
16
- TestDataSourceStatus ,
17
15
)
18
- from data_diff .dbt_parser import DbtParser
19
16
from data_diff .cloud .data_source import (
20
17
TDataSourceTestStage ,
21
18
TestDataSourceStatus ,
22
19
create_ds_config ,
23
20
_check_data_source_exists ,
21
+ _get_temp_schema ,
24
22
_test_data_source ,
25
23
)
26
24
27
25
28
- DATA_SOURCE_CONFIGS = [
29
- TDsConfig (
26
+ DATA_SOURCE_CONFIGS = {
27
+ "snowflake" : TDsConfig (
30
28
name = "ds_name" ,
31
29
type = "snowflake" ,
32
30
options = {
40
38
float_tolerance = 0.000001 ,
41
39
temp_schema = "database.temp_schema" ,
42
40
),
43
- TDsConfig (
41
+ "pg" : TDsConfig (
44
42
name = "ds_name" ,
45
43
type = "pg" ,
46
44
options = {
53
51
float_tolerance = 0.000001 ,
54
52
temp_schema = "database.temp_schema" ,
55
53
),
56
- TDsConfig (
54
+ "bigquery" : TDsConfig (
57
55
name = "ds_name" ,
58
56
type = "bigquery" ,
59
57
options = {
60
58
"projectId" : "project_id" ,
61
- "jsonKeyFile" : "some_string" ,
59
+ "jsonKeyFile" : '{"key1": "value1"}' ,
62
60
"location" : "US" ,
63
61
},
64
62
float_tolerance = 0.000001 ,
65
63
temp_schema = "database.temp_schema" ,
66
64
),
67
- TDsConfig (
65
+ "databricks" : TDsConfig (
68
66
name = "ds_name" ,
69
67
type = "databricks" ,
70
68
options = {
76
74
float_tolerance = 0.000001 ,
77
75
temp_schema = "database.temp_schema" ,
78
76
),
79
- TDsConfig (
77
+ "redshift" : TDsConfig (
80
78
name = "ds_name" ,
81
79
type = "redshift" ,
82
80
options = {
89
87
float_tolerance = 0.000001 ,
90
88
temp_schema = "database.temp_schema" ,
91
89
),
92
- TDsConfig (
90
+ "postgres_aurora" : TDsConfig (
93
91
name = "ds_name" ,
94
92
type = "postgres_aurora" ,
95
93
options = {
102
100
float_tolerance = 0.000001 ,
103
101
temp_schema = "database.temp_schema" ,
104
102
),
105
- TDsConfig (
103
+ "postgres_aws_rds" : TDsConfig (
106
104
name = "ds_name" ,
107
105
type = "postgres_aws_rds" ,
108
106
options = {
115
113
float_tolerance = 0.000001 ,
116
114
temp_schema = "database.temp_schema" ,
117
115
),
118
- ]
116
+ }
119
117
120
118
121
119
def format_data_source_config_test (testcase_func , param_num , param ):
@@ -144,7 +142,23 @@ def setUp(self) -> None:
144
142
self .api .get_data_source_schema_config .return_value = self .data_source_schema
145
143
self .api .get_data_sources .return_value = self .data_sources
146
144
147
- @parameterized .expand ([(c ,) for c in DATA_SOURCE_CONFIGS ], name_func = format_data_source_config_test )
145
+ @parameterized .expand ([(c ,) for c in DATA_SOURCE_CONFIGS .values ()], name_func = format_data_source_config_test )
146
+ @patch ("data_diff.dbt_parser.DbtParser.__new__" )
147
+ def test_get_temp_schema (self , config : TDsConfig , mock_dbt_parser ):
148
+ diff_vars = {
149
+ "prod_database" : "db" ,
150
+ "prod_schema" : "schema" ,
151
+ }
152
+ mock_dbt_parser .get_datadiff_variables .return_value = diff_vars
153
+ temp_schema = f'{ diff_vars ["prod_database" ]} .{ diff_vars ["prod_schema" ]} '
154
+ if config .type == "snowflake" :
155
+ temp_schema = temp_schema .upper ()
156
+ elif config .type in {"pg" , "postgres_aurora" , "postgres_aws_rds" , "redshift" }:
157
+ temp_schema = temp_schema .lower ()
158
+
159
+ assert _get_temp_schema (dbt_parser = mock_dbt_parser , db_type = config .type ) == temp_schema
160
+
161
+ @parameterized .expand ([(c ,) for c in DATA_SOURCE_CONFIGS .values ()], name_func = format_data_source_config_test )
148
162
def test_create_ds_config (self , config : TDsConfig ):
149
163
inputs = list (config .options .values ()) + [config .temp_schema , config .float_tolerance ]
150
164
with patch ("rich.prompt.Console.input" , side_effect = map (str , inputs )):
@@ -155,8 +169,8 @@ def test_create_ds_config(self, config: TDsConfig):
155
169
self .assertEqual (actual_config , config )
156
170
157
171
@patch ("data_diff.dbt_parser.DbtParser.__new__" )
158
- def test_create_ds_config_from_dbt_profiles (self , mock_dbt_parser ):
159
- config = DATA_SOURCE_CONFIGS [0 ]
172
+ def test_create_snowflake_ds_config_from_dbt_profiles (self , mock_dbt_parser ):
173
+ config = DATA_SOURCE_CONFIGS ["snowflake" ]
160
174
mock_dbt_parser .get_connection_creds .return_value = (config .options ,)
161
175
with patch ("rich.prompt.Console.input" , side_effect = ["y" , config .temp_schema , str (config .float_tolerance )]):
162
176
actual_config = create_ds_config (
@@ -166,11 +180,78 @@ def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
166
180
)
167
181
self .assertEqual (actual_config , config )
168
182
183
+ @patch ("data_diff.dbt_parser.DbtParser.__new__" )
184
+ def test_create_bigquery_ds_config_dbt_oauth (self , mock_dbt_parser ):
185
+ config = DATA_SOURCE_CONFIGS ["bigquery" ]
186
+ mock_dbt_parser .get_connection_creds .return_value = (config .options ,)
187
+ with patch ("rich.prompt.Console.input" , side_effect = ["y" , config .temp_schema , str (config .float_tolerance )]):
188
+ actual_config = create_ds_config (
189
+ ds_config = self .db_type_data_source_schemas [config .type ],
190
+ data_source_name = config .name ,
191
+ dbt_parser = mock_dbt_parser ,
192
+ )
193
+ self .assertEqual (actual_config , config )
194
+
195
+ @patch ("data_diff.dbt_parser.DbtParser.__new__" )
196
+ @patch ("data_diff.cloud.data_source._get_data_from_bigquery_json" )
197
+ def test_create_bigquery_ds_config_dbt_service_account (self , mock_get_data_from_bigquery_json , mock_dbt_parser ):
198
+ config = DATA_SOURCE_CONFIGS ["bigquery" ]
199
+
200
+ mock_get_data_from_bigquery_json .return_value = json .loads (config .options ["jsonKeyFile" ])
201
+ mock_dbt_parser .get_connection_creds .return_value = (
202
+ {
203
+ "type" : "bigquery" ,
204
+ "method" : "service-account" ,
205
+ "project" : config .options ["projectId" ],
206
+ "threads" : 1 ,
207
+ "keyfile" : "/some/path" ,
208
+ },
209
+ )
210
+
211
+ with patch (
212
+ "rich.prompt.Console.input" ,
213
+ side_effect = ["y" , config .options ["location" ], config .temp_schema , str (config .float_tolerance )],
214
+ ):
215
+ actual_config = create_ds_config (
216
+ ds_config = self .db_type_data_source_schemas [config .type ],
217
+ data_source_name = config .name ,
218
+ dbt_parser = mock_dbt_parser ,
219
+ )
220
+ self .assertEqual (actual_config , config )
221
+
222
+ @patch ("data_diff.dbt_parser.DbtParser.__new__" )
223
+ def test_create_bigquery_ds_config_dbt_service_account_json (self , mock_dbt_parser ):
224
+ config = DATA_SOURCE_CONFIGS ["bigquery" ]
225
+
226
+ mock_dbt_parser .get_connection_creds .return_value = (
227
+ {
228
+ "type" : "bigquery" ,
229
+ "method" : "service-account-json" ,
230
+ "project" : config .options ["projectId" ],
231
+ "threads" : 1 ,
232
+ "keyfile_json" : json .loads (config .options ["jsonKeyFile" ]),
233
+ },
234
+ )
235
+
236
+ with patch (
237
+ "rich.prompt.Console.input" ,
238
+ side_effect = ["y" , config .options ["location" ], config .temp_schema , str (config .float_tolerance )],
239
+ ):
240
+ actual_config = create_ds_config (
241
+ ds_config = self .db_type_data_source_schemas [config .type ],
242
+ data_source_name = config .name ,
243
+ dbt_parser = mock_dbt_parser ,
244
+ )
245
+ self .assertEqual (actual_config , config )
246
+
169
247
@patch ("sys.stdout" , new_callable = StringIO )
170
248
@patch ("data_diff.dbt_parser.DbtParser.__new__" )
171
- def test_create_ds_config_from_dbt_profiles_one_param_passed_through_input (self , mock_dbt_parser , mock_stdout ):
172
- config = DATA_SOURCE_CONFIGS [0 ]
173
- options = copy .copy (config .options )
249
+ def test_create_ds_snowflake_config_from_dbt_profiles_one_param_passed_through_input (
250
+ self , mock_dbt_parser , mock_stdout
251
+ ):
252
+ config = DATA_SOURCE_CONFIGS ["snowflake" ]
253
+ options = {** config .options , "type" : "snowflake" }
254
+ options ["database" ] = options .pop ("default_db" )
174
255
account = options .pop ("account" )
175
256
mock_dbt_parser .get_connection_creds .return_value = (options ,)
176
257
with patch (
0 commit comments