Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 0d624fa

Browse files
authored
Merge pull request #600 from datafold/DX-713
add --state flag feature to dbt integration
2 parents 20fbeb8 + 9121902 commit 0d624fa

File tree

8 files changed

+411
-259
lines changed

8 files changed

+411
-259
lines changed

data_diff/__main__.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,14 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
234234
"-s",
235235
default=None,
236236
metavar="PATH",
237-
help="select dbt resources to compare using dbt selection syntax",
237+
help="select dbt resources to compare using dbt selection syntax.",
238+
)
239+
@click.option(
240+
"--state",
241+
"-s",
242+
default=None,
243+
metavar="PATH",
244+
help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.",
238245
)
239246
def main(conf, run, **kw):
240247
if kw["table2"] is None and kw["database2"]:
@@ -267,6 +274,9 @@ def main(conf, run, **kw):
267274
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=DATE_FORMAT)
268275

269276
try:
277+
state = kw.pop("state", None)
278+
if state:
279+
state = os.path.expanduser(state)
270280
profiles_dir_override = kw.pop("dbt_profiles_dir", None)
271281
if profiles_dir_override:
272282
profiles_dir_override = os.path.expanduser(profiles_dir_override)
@@ -279,11 +289,12 @@ def main(conf, run, **kw):
279289
project_dir_override=project_dir_override,
280290
is_cloud=kw["cloud"],
281291
dbt_selection=kw["select"],
292+
state=state,
282293
)
283294
else:
284-
return _data_diff(dbt_project_dir=project_dir_override,
285-
dbt_profiles_dir=profiles_dir_override,
286-
**kw)
295+
return _data_diff(
296+
dbt_project_dir=project_dir_override, dbt_profiles_dir=profiles_dir_override, state=state, **kw
297+
)
287298
except Exception as e:
288299
logging.error(e)
289300
if kw["debug"]:
@@ -324,6 +335,7 @@ def _data_diff(
324335
dbt_profiles_dir,
325336
dbt_project_dir,
326337
select,
338+
state,
327339
threads1=None,
328340
threads2=None,
329341
__conf__=None,

data_diff/cloud/data_source.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def _validate_temp_schema(temp_schema: str):
5252

5353

5454
def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]:
55-
diff_vars = dbt_parser.get_datadiff_variables()
56-
config_prod_database = diff_vars.get("prod_database")
57-
config_prod_schema = diff_vars.get("prod_schema")
55+
config = dbt_parser.get_datadiff_config()
56+
config_prod_database = config.prod_database
57+
config_prod_schema = config.prod_schema
5858
if config_prod_database is not None and config_prod_schema is not None:
5959
temp_schema = f"{config_prod_database}.{config_prod_schema}"
6060
if db_type == "snowflake":

data_diff/dbt.py

Lines changed: 72 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,17 @@
22
import re
33
import time
44
import webbrowser
5-
from typing import List, Optional, Dict
5+
from typing import List, Optional, Dict, Tuple, Union
66
import keyring
7-
87
import pydantic
98
import rich
109
from rich.prompt import Confirm, Prompt
1110

11+
from data_diff.errors import DataDiffCustomSchemaNoConfigError, DataDiffDbtProjectVarsNotFoundError
12+
1213
from . import connect_to_table, diff_tables, Algorithm
1314
from .cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta, get_or_create_data_source
14-
from .dbt_parser import DbtParser, PROJECT_FILE
15+
from .dbt_parser import DbtParser, PROJECT_FILE, TDatadiffConfig
1516
from .tracking import (
1617
bool_ask_for_email,
1718
create_email_signup_event_json,
@@ -55,22 +56,21 @@ def dbt_diff(
5556
project_dir_override: Optional[str] = None,
5657
is_cloud: bool = False,
5758
dbt_selection: Optional[str] = None,
59+
state: Optional[str] = None,
5860
) -> None:
5961
print_version_info()
6062
diff_threads = []
6163
set_entrypoint_name("CLI-dbt")
62-
dbt_parser = DbtParser(profiles_dir_override, project_dir_override)
64+
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state)
6365
models = dbt_parser.get_models(dbt_selection)
64-
datadiff_variables = dbt_parser.get_datadiff_variables()
65-
config_prod_database = datadiff_variables.get("prod_database")
66-
config_prod_schema = datadiff_variables.get("prod_schema")
67-
config_prod_custom_schema = datadiff_variables.get("prod_custom_schema")
68-
datasource_id = datadiff_variables.get("datasource_id")
66+
config = dbt_parser.get_datadiff_config()
6967
_initialize_events(dbt_parser.dbt_user_id, dbt_parser.dbt_version, dbt_parser.dbt_project_id)
7068

71-
if datadiff_variables.get("custom_schemas") is not None:
72-
logger.warning(
73-
"vars: data_diff: custom_schemas: is no longer used and can be removed.\nTo utilize custom schemas, see the documentation here: https://docs.datafold.com/development_testing/open_source"
69+
70+
if not state and not (config.prod_database or config.prod_schema):
71+
doc_url = "https://docs.datafold.com/development_testing/open_source#configure-your-dbt-project"
72+
raise DataDiffDbtProjectVarsNotFoundError(
73+
f"""vars: data_diff: section not found in dbt_project.yml.\n\nTo solve this, please configure your dbt project: \n{doc_url}\n\nOr specify a production manifest using the `--state` flag."""
7474
)
7575

7676
if is_cloud:
@@ -80,13 +80,13 @@ def dbt_diff(
8080
return
8181
org_meta = api.get_org_meta()
8282

83-
if datasource_id is None:
83+
if config.datasource_id is None:
8484
rich.print("[red]Data source ID not found in dbt_project.yml")
8585
is_create_data_source = Confirm.ask("Would you like to create a new data source?")
8686
if is_create_data_source:
87-
datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser)
87+
config.datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser)
8888
rich.print(f'To use the data source in next runs, please, update your "{PROJECT_FILE}" with a block:')
89-
rich.print(f"[green]vars:\n data_diff:\n datasource_id: {datasource_id}\n")
89+
rich.print(f"[green]vars:\n data_diff:\n datasource_id: {config.datasource_id}\n")
9090
rich.print(
9191
"Read more about Datafold vars in docs: "
9292
"https://docs.datafold.com/os_diff/dbt_integration/#configure-a-data-source\n"
@@ -97,21 +97,29 @@ def dbt_diff(
9797
"\nvars:\n data_diff:\n datasource_id: 1234"
9898
)
9999

100-
data_source = api.get_data_source(datasource_id)
100+
data_source = api.get_data_source(config.datasource_id)
101101
dbt_parser.set_casing_policy_for(connection_type=data_source.type)
102102
rich.print("[green][bold]\nDiffs in progress...[/][/]\n")
103103

104104
else:
105105
dbt_parser.set_connection()
106106

107107
for model in models:
108-
diff_vars = _get_diff_vars(
109-
dbt_parser, config_prod_database, config_prod_schema, config_prod_custom_schema, model
110-
)
108+
diff_vars = _get_diff_vars(dbt_parser, config, model)
109+
110+
# we won't always have a prod path when using state
111+
# when the model DNE in prod manifest, skip the model diff
112+
if (
113+
state and len(diff_vars.prod_path) < 2
114+
): # < 2 because some providers like databricks can legitimately have *only* 2
115+
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
116+
diff_output_str += "[green]New model: nothing to diff![/] \n"
117+
rich.print(diff_output_str)
118+
continue
111119

112120
if diff_vars.primary_keys:
113121
if is_cloud:
114-
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, api, org_meta)
122+
diff_thread = run_as_daemon(_cloud_diff, diff_vars, config.datasource_id, api, org_meta)
115123
diff_threads.append(diff_thread)
116124
else:
117125
_local_diff(diff_vars)
@@ -129,41 +137,19 @@ def dbt_diff(
129137

130138
def _get_diff_vars(
131139
dbt_parser: "DbtParser",
132-
config_prod_database: Optional[str],
133-
config_prod_schema: Optional[str],
134-
config_prod_custom_schema: Optional[str],
140+
config: TDatadiffConfig,
135141
model,
136142
) -> TDiffVars:
137143
dev_database = model.database
138144
dev_schema = model.schema_
139145

140146
primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")
141147

142-
# "custom" dbt config database
143-
if model.config.database:
144-
prod_database = model.config.database
145-
elif config_prod_database:
146-
prod_database = config_prod_database
148+
# prod path is constructed via configuration or the prod manifest via --state
149+
if dbt_parser.prod_manifest_obj:
150+
prod_database, prod_schema = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj)
147151
else:
148-
prod_database = dev_database
149-
150-
# prod schema name differs from dev schema name
151-
if config_prod_schema:
152-
custom_schema = model.config.schema_
153-
154-
# the model has a custom schema config(schema='some_schema')
155-
if custom_schema:
156-
if not config_prod_custom_schema:
157-
raise ValueError(
158-
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value!\n"
159-
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
160-
)
161-
prod_schema = config_prod_custom_schema.replace("<custom_schema>", custom_schema)
162-
# no custom schema, use the default
163-
else:
164-
prod_schema = config_prod_schema
165-
else:
166-
prod_schema = dev_schema
152+
prod_database, prod_schema = _get_prod_path_from_config(config, model, dev_database, dev_schema)
167153

168154
if dbt_parser.requires_upper:
169155
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias] if x]
@@ -187,6 +173,45 @@ def _get_diff_vars(
187173
)
188174

189175

176+
def _get_prod_path_from_config(config, model, dev_database, dev_schema) -> Tuple[str, str]:
177+
# "custom" dbt config database
178+
if model.config.database:
179+
prod_database = model.config.database
180+
elif config.prod_database:
181+
prod_database = config.prod_database
182+
else:
183+
prod_database = dev_database
184+
185+
# prod schema name differs from dev schema name
186+
if config.prod_schema:
187+
custom_schema = model.config.schema_
188+
189+
# the model has a custom schema config(schema='some_schema')
190+
if custom_schema:
191+
if not config.prod_custom_schema:
192+
raise DataDiffCustomSchemaNoConfigError(
193+
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value or utilize the `--state` flag!\n\n"
194+
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
195+
)
196+
prod_schema = config.prod_custom_schema.replace("<custom_schema>", custom_schema)
197+
# no custom schema, use the default
198+
else:
199+
prod_schema = config.prod_schema
200+
else:
201+
prod_schema = dev_schema
202+
return prod_database, prod_schema
203+
204+
205+
def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str], Tuple[None, None]]:
206+
prod_database = None
207+
prod_schema = None
208+
prod_model = prod_manifest.nodes.get(model.unique_id, None)
209+
if prod_model:
210+
prod_database = prod_model.database
211+
prod_schema = prod_model.schema_
212+
return prod_database, prod_schema
213+
214+
190215
def _local_diff(diff_vars: TDiffVars) -> None:
191216
dev_qualified_str = ".".join(diff_vars.dev_path)
192217
prod_qualified_str = ".".join(diff_vars.prod_path)

data_diff/dbt_parser.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
DataDiffDbtCoreNoRunnerError,
1717
DataDiffDbtNoSuccessfulModelsInRunError,
1818
DataDiffDbtProfileNotFoundError,
19-
DataDiffDbtProjectVarsNotFoundError,
2019
DataDiffDbtRedshiftPasswordOnlyError,
2120
DataDiffDbtRunResultsVersionError,
2221
DataDiffDbtSelectNoMatchingModelsError,
@@ -88,29 +87,52 @@ class TDatadiffModelConfig(pydantic.BaseModel):
8887
exclude_columns: List[str] = []
8988

9089

90+
class TDatadiffConfig(pydantic.BaseModel):
91+
prod_database: Optional[str] = None
92+
prod_schema: Optional[str] = None
93+
prod_custom_schema: Optional[str] = None
94+
datasource_id: Optional[int] = None
95+
96+
9197
class DbtParser:
92-
def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
98+
def __init__(
99+
self,
100+
profiles_dir_override: Optional[str] = None,
101+
project_dir_override: Optional[str] = None,
102+
state: Optional[str] = None,
103+
) -> None:
93104
try_set_dbt_flags()
94105
self.dbt_runner = try_get_dbt_runner()
95106
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
96107
self.project_dir = Path(project_dir_override or default_project_dir())
97108
self.connection = {}
98109
self.project_dict = self.get_project_dict()
99-
self.manifest_obj = self.get_manifest_obj()
100-
self.dbt_user_id = self.manifest_obj.metadata.user_id
101-
self.dbt_version = self.manifest_obj.metadata.dbt_version
102-
self.dbt_project_id = self.manifest_obj.metadata.project_id
110+
self.dev_manifest_obj = self.get_manifest_obj(self.project_dir / MANIFEST_PATH)
111+
self.prod_manifest_obj = None
112+
if state:
113+
self.prod_manifest_obj = self.get_manifest_obj(Path(state))
114+
115+
self.dbt_user_id = self.dev_manifest_obj.metadata.user_id
116+
self.dbt_version = self.dev_manifest_obj.metadata.dbt_version
117+
self.dbt_project_id = self.dev_manifest_obj.metadata.project_id
103118
self.requires_upper = False
104119
self.threads = None
105120
self.unique_columns = self.get_unique_columns()
106121

107-
def get_datadiff_variables(self) -> dict:
108-
doc_url = "https://docs.datafold.com/development_testing/open_source#configure-your-dbt-project"
109-
exception = DataDiffDbtProjectVarsNotFoundError(
110-
f"vars: data_diff: section not found in dbt_project.yml.\n\nTo solve this, please configure your dbt project: \n{doc_url}\n"
122+
def get_datadiff_config(self) -> TDatadiffConfig:
123+
data_diff_vars = self.project_dict.get("vars", {}).get("data_diff", {})
124+
prod_database = data_diff_vars.get("prod_database")
125+
prod_schema = data_diff_vars.get("prod_schema")
126+
prod_custom_schema = data_diff_vars.get("prod_custom_schema")
127+
datasource_id = data_diff_vars.get("datasource_id")
128+
config = TDatadiffConfig(
129+
prod_database=prod_database,
130+
prod_schema=prod_schema,
131+
prod_custom_schema=prod_custom_schema,
132+
datasource_id=datasource_id,
111133
)
112-
vars_dict = get_from_dict_with_raise(self.project_dict, "vars", exception)
113-
return get_from_dict_with_raise(vars_dict, "data_diff", exception)
134+
logger.info(f"config: {config}")
135+
return config
114136

115137
def get_datadiff_model_config(self, model_meta: dict) -> TDatadiffModelConfig:
116138
where_filter = None
@@ -172,7 +194,7 @@ def get_dbt_selection_models(self, dbt_selection: str) -> List[str]:
172194

173195
if results.success and results.result:
174196
model_list = [json.loads(model)["unique_id"] for model in results.result]
175-
models = [self.manifest_obj.nodes.get(x) for x in model_list]
197+
models = [self.dev_manifest_obj.nodes.get(x) for x in model_list]
176198
return models
177199

178200
if not results.result:
@@ -202,15 +224,17 @@ def get_run_results_models(self):
202224
)
203225

204226
success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
205-
models = [self.manifest_obj.nodes.get(x) for x in success_models]
227+
models = [self.dev_manifest_obj.nodes.get(x) for x in success_models]
206228
if not models:
207-
raise DataDiffDbtNoSuccessfulModelsInRunError("Expected > 0 successful models runs from the last dbt command.")
229+
raise DataDiffDbtNoSuccessfulModelsInRunError(
230+
"Expected > 0 successful models runs from the last dbt command."
231+
)
208232

209233
return models
210234

211-
def get_manifest_obj(self):
212-
with open(self.project_dir / MANIFEST_PATH) as manifest:
213-
logger.info(f"Parsing file {MANIFEST_PATH}")
235+
def get_manifest_obj(self, path: Path):
236+
with open(path) as manifest:
237+
logger.info(f"Parsing file {path}")
214238
manifest_dict = json.load(manifest)
215239
manifest_obj = parse_manifest(manifest=manifest_dict)
216240
return manifest_obj
@@ -315,7 +339,9 @@ def set_connection(self):
315339
if (credentials.get("pass") is None and credentials.get("password") is None) or credentials.get(
316340
"method"
317341
) == "iam":
318-
raise DataDiffDbtRedshiftPasswordOnlyError("Only password authentication is currently supported for Redshift.")
342+
raise DataDiffDbtRedshiftPasswordOnlyError(
343+
"Only password authentication is currently supported for Redshift."
344+
)
319345
conn_info = {
320346
"driver": conn_type,
321347
"host": credentials.get("host"),
@@ -386,7 +412,7 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str
386412
return []
387413

388414
def get_unique_columns(self) -> Dict[str, Set[str]]:
389-
manifest = self.manifest_obj
415+
manifest = self.dev_manifest_obj
390416
cols_by_uid = defaultdict(set)
391417
for node in manifest.nodes.values():
392418
try:

data_diff/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,7 @@ class DataDiffDbtCoreNoRunnerError(Exception):
4444

4545
class DataDiffDbtSelectVersionTooLowError(Exception):
4646
"Raised when attempting to use `--select` with a dbt-core version < 1.5."
47+
48+
49+
class DataDiffCustomSchemaNoConfigError(Exception):
50+
"Raised when a model has a custom schema, but there is no prod_custom_schema config. (And not using --state)."

tests/cloud/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)