This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit bbb406f: Merge pull request #509 from pik94/make-temp-schema-optional

Make temp schema optional

2 parents 4a20f0c + 5890b57

4 files changed (+165 / -25 lines)

data_diff/cloud/data_source.py (53 additions, 1 deletion)
```diff
@@ -1,3 +1,4 @@
+import json
 import time
 from typing import List, Optional, Union, overload
 
```

```diff
@@ -50,14 +51,33 @@ def _validate_temp_schema(temp_schema: str):
         raise ValueError("Temporary schema should have a format <database>.<schema>")
 
 
+def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]:
+    diff_vars = dbt_parser.get_datadiff_variables()
+    config_prod_database = diff_vars.get("prod_database")
+    config_prod_schema = diff_vars.get("prod_schema")
+    if config_prod_database is not None and config_prod_schema is not None:
+        temp_schema = f"{config_prod_database}.{config_prod_schema}"
+        if db_type == "snowflake":
+            return temp_schema.upper()
+        elif db_type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}:
+            return temp_schema.lower()
+        return temp_schema
+    return
+
+
 def create_ds_config(
     ds_config: TCloudApiDataSourceConfigSchema,
     data_source_name: str,
     dbt_parser: Optional[DbtParser] = None,
 ) -> TDsConfig:
     options = _parse_ds_credentials(ds_config=ds_config, only_basic_settings=True, dbt_parser=dbt_parser)
 
-    temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
+    temp_schema = _get_temp_schema(dbt_parser=dbt_parser, db_type=ds_config.db_type) if dbt_parser else None
+    if temp_schema:
+        temp_schema = TemporarySchemaPrompt.ask("Temporary schema", default=temp_schema)
+    else:
+        temp_schema = TemporarySchemaPrompt.ask("Temporary schema (<database>.<schema>)")
+
     float_tolerance = FloatPrompt.ask("Float tolerance", default=0.000001)
 
     return TDsConfig(
```
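With this change, `create_ds_config` only asks for a bare `<database>.<schema>` string when no default can be derived; when a dbt parser is available and both `prod_database` and `prod_schema` datadiff variables are set, the prompt is pre-filled. A minimal runnable sketch of the case normalization the helper applies (the stub parser and its variable values are hypothetical stand-ins for a real `DbtParser`):

```python
from typing import Optional


class StubDbtParser:
    """Hypothetical stand-in for DbtParser: only supplies the datadiff vars."""

    def get_datadiff_variables(self) -> dict:
        return {"prod_database": "Analytics", "prod_schema": "Tmp"}


def get_temp_schema(dbt_parser, db_type: str) -> Optional[str]:
    # Mirrors the _get_temp_schema helper added above.
    diff_vars = dbt_parser.get_datadiff_variables()
    db, schema = diff_vars.get("prod_database"), diff_vars.get("prod_schema")
    if db is None or schema is None:
        return None
    temp_schema = f"{db}.{schema}"
    if db_type == "snowflake":
        return temp_schema.upper()  # Snowflake folds unquoted identifiers to upper case
    if db_type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}:
        return temp_schema.lower()  # Postgres-family databases fold to lower case
    return temp_schema


print(get_temp_schema(StubDbtParser(), "snowflake"))   # ANALYTICS.TMP
print(get_temp_schema(StubDbtParser(), "redshift"))    # analytics.tmp
print(get_temp_schema(StubDbtParser(), "databricks"))  # Analytics.Tmp (unchanged)
```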
```diff
@@ -92,6 +112,37 @@ def _cast_value(value: str, type_: str) -> Union[bool, int, str]:
     return value
 
 
+def _get_data_from_bigquery_json(path: str):
+    with open(path, "r") as file:
+        return json.load(file)
+
+
+def _align_dbt_cred_params_with_datafold_params(dbt_creds: dict) -> dict:
+    db_type = dbt_creds["type"]
+    if db_type == "bigquery":
+        method = dbt_creds["method"]
+        if method == "service-account":
+            data = _get_data_from_bigquery_json(path=dbt_creds["keyfile"])
+            dbt_creds["jsonKeyFile"] = json.dumps(data)
+        elif method == "service-account-json":
+            dbt_creds["jsonKeyFile"] = json.dumps(dbt_creds["keyfile_json"])
+        else:
+            rich.print(
+                f'[red]Cannot extract bigquery credentials from dbt_project.yml for "{method}" type. '
+                f"If you want to provide credentials via dbt_project.yml, "
+                f'please, use "service-account" or "service-account-json" '
+                f"(more in docs: https://docs.getdbt.com/reference/warehouse-setups/bigquery-setup). "
+                f"Otherwise, you can provide a path to a json key file or a json key file data as an input."
+            )
+        dbt_creds["projectId"] = dbt_creds["project"]
+    elif db_type == "snowflake":
+        dbt_creds["default_db"] = dbt_creds["database"]
+    elif db_type == "databricks":
+        dbt_creds["http_password"] = dbt_creds["token"]
+        dbt_creds["database"] = dbt_creds.get("catalog")
+    return dbt_creds
+
+
 def _parse_ds_credentials(
     ds_config: TCloudApiDataSourceConfigSchema, only_basic_settings: bool = True, dbt_parser: Optional[DbtParser] = None
 ):
```
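The new helper renames dbt profile keys to the names the Datafold API expects. Assuming the module path shown in this diff, a quick usage sketch for the Databricks branch (the credential values are hypothetical):

```python
from data_diff.cloud.data_source import _align_dbt_cred_params_with_datafold_params

# Hypothetical dbt profiles.yml entries; only the keys the mapping touches are shown.
creds = _align_dbt_cred_params_with_datafold_params(
    dbt_creds={"type": "databricks", "token": "dapi-xxxx", "catalog": "main"}
)
assert creds["http_password"] == "dapi-xxxx"  # dbt "token" becomes Datafold "http_password"
assert creds["database"] == "main"            # dbt "catalog" becomes Datafold "database"
```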
```diff
@@ -101,6 +152,7 @@ def _parse_ds_credentials(
         use_dbt_data = Confirm.ask("Would you like to extract database credentials from dbt profiles.yml?")
         try:
             creds = dbt_parser.get_connection_creds()[0]
+            creds = _align_dbt_cred_params_with_datafold_params(dbt_creds=creds)
         except Exception as e:
             rich.print(f"[red]Cannot parse database credentials from dbt profiles.yml. Reason: {e}")
 
```

data_diff/cloud/datafold_api.py (10 additions, 3 deletions)
```diff
@@ -1,3 +1,4 @@
+import base64
 import dataclasses
 import enum
 import time
```
```diff
@@ -162,7 +163,7 @@ class TCloudDataSourceTestResult(pydantic.BaseModel):
 class TCloudApiDataSourceTestResult(pydantic.BaseModel):
     name: str
     status: str
-    result: TCloudDataSourceTestResult
+    result: Optional[TCloudDataSourceTestResult]
 
 
 @dataclasses.dataclass
```
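Since a data source test can finish without producing a result payload, `result` is now nullable. A small sketch of the parsing behavior this enables (the stand-in field types are assumptions, not the full models):

```python
from typing import Optional

import pydantic


class TCloudDataSourceTestResult(pydantic.BaseModel):
    # Stand-in with assumed field types; the real model may differ.
    status: str
    message: str
    outcome: str


class TCloudApiDataSourceTestResult(pydantic.BaseModel):
    name: str
    status: str
    result: Optional[TCloudDataSourceTestResult]


# With the Optional annotation, a null result parses instead of raising a ValidationError:
item = TCloudApiDataSourceTestResult(name="connection", status="error", result=None)
print(item.result)  # None
```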
```diff
@@ -194,7 +195,11 @@ def get_data_sources(self) -> List[TCloudApiDataSource]:
         return [TCloudApiDataSource(**item) for item in rv.json()]
 
     def create_data_source(self, config: TDsConfig) -> TCloudApiDataSource:
-        rv = self.make_post_request(url="api/v1/data_sources", payload=config.dict())
+        payload = config.dict()
+        if config.type == "bigquery":
+            json_string = payload["options"]["jsonKeyFile"].encode("utf-8")
+            payload["options"]["jsonKeyFile"] = base64.b64encode(json_string).decode("utf-8")
+        rv = self.make_post_request(url="api/v1/data_sources", payload=payload)
         return TCloudApiDataSource(**rv.json())
 
     def get_data_source_schema_config(
```
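The BigQuery key file is sent base64-encoded. A round-trip sketch of the encoding `create_data_source` now applies before posting the payload (the key contents are hypothetical):

```python
import base64
import json

json_key = json.dumps({"key1": "value1"})  # hypothetical key file contents

# Encode exactly as create_data_source does before the POST:
encoded = base64.b64encode(json_key.encode("utf-8")).decode("utf-8")
print(encoded)  # eyJrZXkxIjogInZhbHVlMSJ9

# The server can recover the original JSON:
assert json.loads(base64.b64decode(encoded)) == {"key1": "value1"}
```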
```diff
@@ -254,7 +259,9 @@ def check_data_source_test_results(self, job_id: int) -> List[TCloudApiDataSourceTestResult]:
                     status=item["result"]["code"].lower(),
                     message=item["result"]["message"],
                     outcome=item["result"]["outcome"],
-                ),
+                )
+                if item["result"] is not None
+                else None,
             )
             for item in rv.json()["results"]
         ]
```

data_diff/dbt_parser.py (1 addition, 1 deletion)
```diff
@@ -145,7 +145,7 @@ def set_connection(self):
             "role": credentials.get("role"),
             "schema": credentials.get("schema"),
             "insecure_mode": credentials.get("insecure_mode", False),
-            "client_session_keep_alive": credentials.get("client_session_keep_alive", False)
+            "client_session_keep_alive": credentials.get("client_session_keep_alive", False),
         }
         self.threads = credentials.get("threads")
         self.requires_upper = True
```

tests/cloud/test_data_source.py (101 additions, 20 deletions)
```diff
@@ -1,10 +1,9 @@
-import copy
 from io import StringIO
 import json
 from pathlib import Path
 from parameterized import parameterized
 import unittest
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import Mock, patch
 
 from data_diff.cloud.datafold_api import (
     TCloudApiDataSourceConfigSchema,
```
```diff
@@ -13,20 +12,19 @@
     TCloudApiDataSourceTestResult,
     TCloudDataSourceTestResult,
     TDsConfig,
-    TestDataSourceStatus,
 )
-from data_diff.dbt_parser import DbtParser
 from data_diff.cloud.data_source import (
     TDataSourceTestStage,
     TestDataSourceStatus,
     create_ds_config,
     _check_data_source_exists,
+    _get_temp_schema,
     _test_data_source,
 )
 
 
-DATA_SOURCE_CONFIGS = [
-    TDsConfig(
+DATA_SOURCE_CONFIGS = {
+    "snowflake": TDsConfig(
         name="ds_name",
         type="snowflake",
         options={
@@ -40,7 +38,7 @@
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "pg": TDsConfig(
         name="ds_name",
         type="pg",
         options={
@@ -53,18 +51,18 @@
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "bigquery": TDsConfig(
         name="ds_name",
         type="bigquery",
         options={
             "projectId": "project_id",
-            "jsonKeyFile": "some_string",
+            "jsonKeyFile": '{"key1": "value1"}',
             "location": "US",
         },
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "databricks": TDsConfig(
         name="ds_name",
         type="databricks",
         options={
@@ -76,7 +74,7 @@
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "redshift": TDsConfig(
         name="ds_name",
         type="redshift",
         options={
@@ -89,7 +87,7 @@
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "postgres_aurora": TDsConfig(
         name="ds_name",
         type="postgres_aurora",
         options={
@@ -102,7 +100,7 @@
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-    TDsConfig(
+    "postgres_aws_rds": TDsConfig(
         name="ds_name",
         type="postgres_aws_rds",
         options={
@@ -115,7 +113,7 @@
         float_tolerance=0.000001,
         temp_schema="database.temp_schema",
     ),
-]
+}
 
 
 def format_data_source_config_test(testcase_func, param_num, param):
```
```diff
@@ -144,7 +142,23 @@ def setUp(self) -> None:
         self.api.get_data_source_schema_config.return_value = self.data_source_schema
         self.api.get_data_sources.return_value = self.data_sources
 
-    @parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS], name_func=format_data_source_config_test)
+    @parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS.values()], name_func=format_data_source_config_test)
+    @patch("data_diff.dbt_parser.DbtParser.__new__")
+    def test_get_temp_schema(self, config: TDsConfig, mock_dbt_parser):
+        diff_vars = {
+            "prod_database": "db",
+            "prod_schema": "schema",
+        }
+        mock_dbt_parser.get_datadiff_variables.return_value = diff_vars
+        temp_schema = f'{diff_vars["prod_database"]}.{diff_vars["prod_schema"]}'
+        if config.type == "snowflake":
+            temp_schema = temp_schema.upper()
+        elif config.type in {"pg", "postgres_aurora", "postgres_aws_rds", "redshift"}:
+            temp_schema = temp_schema.lower()
+
+        assert _get_temp_schema(dbt_parser=mock_dbt_parser, db_type=config.type) == temp_schema
+
+    @parameterized.expand([(c,) for c in DATA_SOURCE_CONFIGS.values()], name_func=format_data_source_config_test)
     def test_create_ds_config(self, config: TDsConfig):
         inputs = list(config.options.values()) + [config.temp_schema, config.float_tolerance]
         with patch("rich.prompt.Console.input", side_effect=map(str, inputs)):
```
```diff
@@ -155,8 +169,8 @@ def test_create_ds_config(self, config: TDsConfig):
         self.assertEqual(actual_config, config)
 
     @patch("data_diff.dbt_parser.DbtParser.__new__")
-    def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
-        config = DATA_SOURCE_CONFIGS[0]
+    def test_create_snowflake_ds_config_from_dbt_profiles(self, mock_dbt_parser):
+        config = DATA_SOURCE_CONFIGS["snowflake"]
         mock_dbt_parser.get_connection_creds.return_value = (config.options,)
         with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
             actual_config = create_ds_config(
```
```diff
@@ -166,11 +180,78 @@ def test_create_ds_config_from_dbt_profiles(self, mock_dbt_parser):
             )
         self.assertEqual(actual_config, config)
 
+    @patch("data_diff.dbt_parser.DbtParser.__new__")
+    def test_create_bigquery_ds_config_dbt_oauth(self, mock_dbt_parser):
+        config = DATA_SOURCE_CONFIGS["bigquery"]
+        mock_dbt_parser.get_connection_creds.return_value = (config.options,)
+        with patch("rich.prompt.Console.input", side_effect=["y", config.temp_schema, str(config.float_tolerance)]):
+            actual_config = create_ds_config(
+                ds_config=self.db_type_data_source_schemas[config.type],
+                data_source_name=config.name,
+                dbt_parser=mock_dbt_parser,
+            )
+        self.assertEqual(actual_config, config)
+
+    @patch("data_diff.dbt_parser.DbtParser.__new__")
+    @patch("data_diff.cloud.data_source._get_data_from_bigquery_json")
+    def test_create_bigquery_ds_config_dbt_service_account(self, mock_get_data_from_bigquery_json, mock_dbt_parser):
+        config = DATA_SOURCE_CONFIGS["bigquery"]
+
+        mock_get_data_from_bigquery_json.return_value = json.loads(config.options["jsonKeyFile"])
+        mock_dbt_parser.get_connection_creds.return_value = (
+            {
+                "type": "bigquery",
+                "method": "service-account",
+                "project": config.options["projectId"],
+                "threads": 1,
+                "keyfile": "/some/path",
+            },
+        )
+
+        with patch(
+            "rich.prompt.Console.input",
+            side_effect=["y", config.options["location"], config.temp_schema, str(config.float_tolerance)],
+        ):
+            actual_config = create_ds_config(
+                ds_config=self.db_type_data_source_schemas[config.type],
+                data_source_name=config.name,
+                dbt_parser=mock_dbt_parser,
+            )
+        self.assertEqual(actual_config, config)
+
+    @patch("data_diff.dbt_parser.DbtParser.__new__")
+    def test_create_bigquery_ds_config_dbt_service_account_json(self, mock_dbt_parser):
+        config = DATA_SOURCE_CONFIGS["bigquery"]
+
+        mock_dbt_parser.get_connection_creds.return_value = (
+            {
+                "type": "bigquery",
+                "method": "service-account-json",
+                "project": config.options["projectId"],
+                "threads": 1,
+                "keyfile_json": json.loads(config.options["jsonKeyFile"]),
+            },
+        )
+
+        with patch(
+            "rich.prompt.Console.input",
+            side_effect=["y", config.options["location"], config.temp_schema, str(config.float_tolerance)],
+        ):
+            actual_config = create_ds_config(
+                ds_config=self.db_type_data_source_schemas[config.type],
+                data_source_name=config.name,
+                dbt_parser=mock_dbt_parser,
+            )
+        self.assertEqual(actual_config, config)
+
     @patch("sys.stdout", new_callable=StringIO)
     @patch("data_diff.dbt_parser.DbtParser.__new__")
-    def test_create_ds_config_from_dbt_profiles_one_param_passed_through_input(self, mock_dbt_parser, mock_stdout):
-        config = DATA_SOURCE_CONFIGS[0]
-        options = copy.copy(config.options)
+    def test_create_ds_snowflake_config_from_dbt_profiles_one_param_passed_through_input(
+        self, mock_dbt_parser, mock_stdout
+    ):
+        config = DATA_SOURCE_CONFIGS["snowflake"]
+        options = {**config.options, "type": "snowflake"}
+        options["database"] = options.pop("default_db")
         account = options.pop("account")
         mock_dbt_parser.get_connection_creds.return_value = (options,)
         with patch(
```
