Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 1d1873f

Browse files
committed
add --state flag feature to dbt
1 parent 975efa4 commit 1d1873f

File tree

8 files changed

+408
-256
lines changed

8 files changed

+408
-256
lines changed

data_diff/__main__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from copy import deepcopy
22
from datetime import datetime
3+
import os
34
import sys
45
import time
56
import json
@@ -233,7 +234,14 @@ def write_usage(self, prog: str, args: str = "", prefix: Optional[str] = None) -
233234
"-s",
234235
default=None,
235236
metavar="PATH",
236-
help="select dbt resources to compare using dbt selection syntax",
237+
help="select dbt resources to compare using dbt selection syntax.",
238+
)
239+
@click.option(
240+
"--state",
241+
"-s",
242+
default=None,
243+
metavar="PATH",
244+
help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.",
237245
)
238246
def main(conf, run, **kw):
239247
if kw["table2"] is None and kw["database2"]:
@@ -266,12 +274,16 @@ def main(conf, run, **kw):
266274
logging.basicConfig(level=logging.WARNING, format=LOG_FORMAT, datefmt=DATE_FORMAT)
267275

268276
try:
277+
state = kw.pop("state", None)
278+
if state:
279+
state = os.path.expanduser(state)
269280
if kw["dbt"]:
270281
dbt_diff(
271282
profiles_dir_override=kw["dbt_profiles_dir"],
272283
project_dir_override=kw["dbt_project_dir"],
273284
is_cloud=kw["cloud"],
274285
dbt_selection=kw["select"],
286+
state=state,
275287
)
276288
else:
277289
return _data_diff(**kw)

data_diff/cloud/data_source.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def _validate_temp_schema(temp_schema: str):
5252

5353

5454
def _get_temp_schema(dbt_parser: DbtParser, db_type: str) -> Optional[str]:
55-
diff_vars = dbt_parser.get_datadiff_variables()
56-
config_prod_database = diff_vars.get("prod_database")
57-
config_prod_schema = diff_vars.get("prod_schema")
55+
config = dbt_parser.get_datadiff_config()
56+
config_prod_database = config.prod_database
57+
config_prod_schema = config.prod_schema
5858
if config_prod_database is not None and config_prod_schema is not None:
5959
temp_schema = f"{config_prod_database}.{config_prod_schema}"
6060
if db_type == "snowflake":

data_diff/dbt.py

Lines changed: 72 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import os
22
import time
33
import webbrowser
4-
from typing import List, Optional, Dict
4+
from typing import List, Optional, Dict, Tuple, Union
55
import keyring
6-
76
import pydantic
87
import rich
98
from rich.prompt import Confirm
109

10+
from data_diff.errors import DataDiffCustomSchemaNoConfigError, DataDiffDbtProjectVarsNotFoundError
11+
1112
from . import connect_to_table, diff_tables, Algorithm
1213
from .cloud import DatafoldAPI, TCloudApiDataDiff, TCloudApiOrgMeta, get_or_create_data_source
13-
from .dbt_parser import DbtParser, PROJECT_FILE
14+
from .dbt_parser import DbtParser, PROJECT_FILE, TDatadiffConfig
1415
from .tracking import (
1516
set_entrypoint_name,
1617
set_dbt_user_id,
@@ -52,24 +53,23 @@ def dbt_diff(
5253
project_dir_override: Optional[str] = None,
5354
is_cloud: bool = False,
5455
dbt_selection: Optional[str] = None,
56+
state: Optional[str] = None,
5557
) -> None:
5658
print_version_info()
5759
diff_threads = []
5860
set_entrypoint_name("CLI-dbt")
59-
dbt_parser = DbtParser(profiles_dir_override, project_dir_override)
61+
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, state)
6062
models = dbt_parser.get_models(dbt_selection)
61-
datadiff_variables = dbt_parser.get_datadiff_variables()
62-
config_prod_database = datadiff_variables.get("prod_database")
63-
config_prod_schema = datadiff_variables.get("prod_schema")
64-
config_prod_custom_schema = datadiff_variables.get("prod_custom_schema")
65-
datasource_id = datadiff_variables.get("datasource_id")
63+
config = dbt_parser.get_datadiff_config()
64+
6665
set_dbt_user_id(dbt_parser.dbt_user_id)
6766
set_dbt_version(dbt_parser.dbt_version)
6867
set_dbt_project_id(dbt_parser.dbt_project_id)
6968

70-
if datadiff_variables.get("custom_schemas") is not None:
71-
logger.warning(
72-
"vars: data_diff: custom_schemas: is no longer used and can be removed.\nTo utilize custom schemas, see the documentation here: https://docs.datafold.com/development_testing/open_source"
69+
if not state and not (config.prod_database or config.prod_schema):
70+
doc_url = "https://docs.datafold.com/development_testing/open_source#configure-your-dbt-project"
71+
raise DataDiffDbtProjectVarsNotFoundError(
72+
f"""vars: data_diff: section not found in dbt_project.yml.\n\nTo solve this, please configure your dbt project: \n{doc_url}\n\nOr specify a production manifest using the `--state` flag."""
7373
)
7474

7575
if is_cloud:
@@ -79,13 +79,13 @@ def dbt_diff(
7979
return
8080
org_meta = api.get_org_meta()
8181

82-
if datasource_id is None:
82+
if config.datasource_id is None:
8383
rich.print("[red]Data source ID not found in dbt_project.yml")
8484
is_create_data_source = Confirm.ask("Would you like to create a new data source?")
8585
if is_create_data_source:
86-
datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser)
86+
config.datasource_id = get_or_create_data_source(api=api, dbt_parser=dbt_parser)
8787
rich.print(f'To use the data source in next runs, please, update your "{PROJECT_FILE}" with a block:')
88-
rich.print(f"[green]vars:\n data_diff:\n datasource_id: {datasource_id}\n")
88+
rich.print(f"[green]vars:\n data_diff:\n datasource_id: {config.datasource_id}\n")
8989
rich.print(
9090
"Read more about Datafold vars in docs: "
9191
"https://docs.datafold.com/os_diff/dbt_integration/#configure-a-data-source\n"
@@ -96,21 +96,29 @@ def dbt_diff(
9696
"\nvars:\n data_diff:\n datasource_id: 1234"
9797
)
9898

99-
data_source = api.get_data_source(datasource_id)
99+
data_source = api.get_data_source(config.datasource_id)
100100
dbt_parser.set_casing_policy_for(connection_type=data_source.type)
101101
rich.print("[green][bold]\nDiffs in progress...[/][/]\n")
102102

103103
else:
104104
dbt_parser.set_connection()
105105

106106
for model in models:
107-
diff_vars = _get_diff_vars(
108-
dbt_parser, config_prod_database, config_prod_schema, config_prod_custom_schema, model
109-
)
107+
diff_vars = _get_diff_vars(dbt_parser, config, model)
108+
109+
# we won't always have a prod path when using state
110+
# when the model DNE in prod manifest, skip the model diff
111+
if (
112+
state and len(diff_vars.prod_path) < 2
113+
): # < 2 because some providers like databricks can legitimately have *only* 2
114+
diff_output_str = _diff_output_base(".".join(diff_vars.dev_path), ".".join(diff_vars.prod_path))
115+
diff_output_str += "[green]New model: nothing to diff![/] \n"
116+
rich.print(diff_output_str)
117+
continue
110118

111119
if diff_vars.primary_keys:
112120
if is_cloud:
113-
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, api, org_meta)
121+
diff_thread = run_as_daemon(_cloud_diff, diff_vars, config.datasource_id, api, org_meta)
114122
diff_threads.append(diff_thread)
115123
else:
116124
_local_diff(diff_vars)
@@ -128,41 +136,19 @@ def dbt_diff(
128136

129137
def _get_diff_vars(
130138
dbt_parser: "DbtParser",
131-
config_prod_database: Optional[str],
132-
config_prod_schema: Optional[str],
133-
config_prod_custom_schema: Optional[str],
139+
config: TDatadiffConfig,
134140
model,
135141
) -> TDiffVars:
136142
dev_database = model.database
137143
dev_schema = model.schema_
138144

139145
primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")
140146

141-
# "custom" dbt config database
142-
if model.config.database:
143-
prod_database = model.config.database
144-
elif config_prod_database:
145-
prod_database = config_prod_database
147+
# prod path is constructed via configuration or the prod manifest via --state
148+
if dbt_parser.prod_manifest_obj:
149+
prod_database, prod_schema = _get_prod_path_from_manifest(model, dbt_parser.prod_manifest_obj)
146150
else:
147-
prod_database = dev_database
148-
149-
# prod schema name differs from dev schema name
150-
if config_prod_schema:
151-
custom_schema = model.config.schema_
152-
153-
# the model has a custom schema config(schema='some_schema')
154-
if custom_schema:
155-
if not config_prod_custom_schema:
156-
raise ValueError(
157-
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value!\n"
158-
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
159-
)
160-
prod_schema = config_prod_custom_schema.replace("<custom_schema>", custom_schema)
161-
# no custom schema, use the default
162-
else:
163-
prod_schema = config_prod_schema
164-
else:
165-
prod_schema = dev_schema
151+
prod_database, prod_schema = _get_prod_path_from_config(config, model, dev_database, dev_schema)
166152

167153
if dbt_parser.requires_upper:
168154
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias] if x]
@@ -186,6 +172,45 @@ def _get_diff_vars(
186172
)
187173

188174

175+
def _get_prod_path_from_config(config, model, dev_database, dev_schema) -> Tuple[str, str]:
176+
# "custom" dbt config database
177+
if model.config.database:
178+
prod_database = model.config.database
179+
elif config.prod_database:
180+
prod_database = config.prod_database
181+
else:
182+
prod_database = dev_database
183+
184+
# prod schema name differs from dev schema name
185+
if config.prod_schema:
186+
custom_schema = model.config.schema_
187+
188+
# the model has a custom schema config(schema='some_schema')
189+
if custom_schema:
190+
if not config.prod_custom_schema:
191+
raise DataDiffCustomSchemaNoConfigError(
192+
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value or utilize the `--state` flag!\n\n"
193+
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
194+
)
195+
prod_schema = config.prod_custom_schema.replace("<custom_schema>", custom_schema)
196+
# no custom schema, use the default
197+
else:
198+
prod_schema = config.prod_schema
199+
else:
200+
prod_schema = dev_schema
201+
return prod_database, prod_schema
202+
203+
204+
def _get_prod_path_from_manifest(model, prod_manifest) -> Union[Tuple[str, str], Tuple[None, None]]:
205+
prod_database = None
206+
prod_schema = None
207+
prod_model = prod_manifest.nodes.get(model.unique_id, None)
208+
if prod_model:
209+
prod_database = prod_model.database
210+
prod_schema = prod_model.schema_
211+
return prod_database, prod_schema
212+
213+
189214
def _local_diff(diff_vars: TDiffVars) -> None:
190215
dev_qualified_str = ".".join(diff_vars.dev_path)
191216
prod_qualified_str = ".".join(diff_vars.prod_path)

data_diff/dbt_parser.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
DataDiffDbtCoreNoRunnerError,
1717
DataDiffDbtNoSuccessfulModelsInRunError,
1818
DataDiffDbtProfileNotFoundError,
19-
DataDiffDbtProjectVarsNotFoundError,
2019
DataDiffDbtRedshiftPasswordOnlyError,
2120
DataDiffDbtRunResultsVersionError,
2221
DataDiffDbtSelectNoMatchingModelsError,
@@ -88,29 +87,52 @@ class TDatadiffModelConfig(pydantic.BaseModel):
8887
exclude_columns: List[str] = []
8988

9089

90+
class TDatadiffConfig(pydantic.BaseModel):
91+
prod_database: Optional[str] = None
92+
prod_schema: Optional[str] = None
93+
prod_custom_schema: Optional[str] = None
94+
datasource_id: Optional[int] = None
95+
96+
9197
class DbtParser:
92-
def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
98+
def __init__(
99+
self,
100+
profiles_dir_override: Optional[str] = None,
101+
project_dir_override: Optional[str] = None,
102+
state: Optional[str] = None,
103+
) -> None:
93104
try_set_dbt_flags()
94105
self.dbt_runner = try_get_dbt_runner()
95106
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
96107
self.project_dir = Path(project_dir_override or default_project_dir())
97108
self.connection = {}
98109
self.project_dict = self.get_project_dict()
99-
self.manifest_obj = self.get_manifest_obj()
100-
self.dbt_user_id = self.manifest_obj.metadata.user_id
101-
self.dbt_version = self.manifest_obj.metadata.dbt_version
102-
self.dbt_project_id = self.manifest_obj.metadata.project_id
110+
self.dev_manifest_obj = self.get_manifest_obj(self.project_dir / MANIFEST_PATH)
111+
self.prod_manifest_obj = None
112+
if state:
113+
self.prod_manifest_obj = self.get_manifest_obj(Path(state))
114+
115+
self.dbt_user_id = self.dev_manifest_obj.metadata.user_id
116+
self.dbt_version = self.dev_manifest_obj.metadata.dbt_version
117+
self.dbt_project_id = self.dev_manifest_obj.metadata.project_id
103118
self.requires_upper = False
104119
self.threads = None
105120
self.unique_columns = self.get_unique_columns()
106121

107-
def get_datadiff_variables(self) -> dict:
108-
doc_url = "https://docs.datafold.com/development_testing/open_source#configure-your-dbt-project"
109-
exception = DataDiffDbtProjectVarsNotFoundError(
110-
f"vars: data_diff: section not found in dbt_project.yml.\n\nTo solve this, please configure your dbt project: \n{doc_url}\n"
122+
def get_datadiff_config(self) -> TDatadiffConfig:
123+
data_diff_vars = self.project_dict.get("vars", {}).get("data_diff", {})
124+
prod_database = data_diff_vars.get("prod_database")
125+
prod_schema = data_diff_vars.get("prod_schema")
126+
prod_custom_schema = data_diff_vars.get("prod_custom_schema")
127+
datasource_id = data_diff_vars.get("datasource_id")
128+
config = TDatadiffConfig(
129+
prod_database=prod_database,
130+
prod_schema=prod_schema,
131+
prod_custom_schema=prod_custom_schema,
132+
datasource_id=datasource_id,
111133
)
112-
vars_dict = get_from_dict_with_raise(self.project_dict, "vars", exception)
113-
return get_from_dict_with_raise(vars_dict, "data_diff", exception)
134+
logger.info(f"config: {config}")
135+
return config
114136

115137
def get_datadiff_model_config(self, model_meta: dict) -> TDatadiffModelConfig:
116138
where_filter = None
@@ -172,7 +194,7 @@ def get_dbt_selection_models(self, dbt_selection: str) -> List[str]:
172194

173195
if results.success and results.result:
174196
model_list = [json.loads(model)["unique_id"] for model in results.result]
175-
models = [self.manifest_obj.nodes.get(x) for x in model_list]
197+
models = [self.dev_manifest_obj.nodes.get(x) for x in model_list]
176198
return models
177199

178200
if not results.result:
@@ -202,15 +224,17 @@ def get_run_results_models(self):
202224
)
203225

204226
success_models = [x.unique_id for x in run_results_obj.results if x.status.name == "success"]
205-
models = [self.manifest_obj.nodes.get(x) for x in success_models]
227+
models = [self.dev_manifest_obj.nodes.get(x) for x in success_models]
206228
if not models:
207-
raise DataDiffDbtNoSuccessfulModelsInRunError("Expected > 0 successful models runs from the last dbt command.")
229+
raise DataDiffDbtNoSuccessfulModelsInRunError(
230+
"Expected > 0 successful models runs from the last dbt command."
231+
)
208232

209233
return models
210234

211-
def get_manifest_obj(self):
212-
with open(self.project_dir / MANIFEST_PATH) as manifest:
213-
logger.info(f"Parsing file {MANIFEST_PATH}")
235+
def get_manifest_obj(self, path: Path):
236+
with open(path) as manifest:
237+
logger.info(f"Parsing file {path}")
214238
manifest_dict = json.load(manifest)
215239
manifest_obj = parse_manifest(manifest=manifest_dict)
216240
return manifest_obj
@@ -315,7 +339,9 @@ def set_connection(self):
315339
if (credentials.get("pass") is None and credentials.get("password") is None) or credentials.get(
316340
"method"
317341
) == "iam":
318-
raise DataDiffDbtRedshiftPasswordOnlyError("Only password authentication is currently supported for Redshift.")
342+
raise DataDiffDbtRedshiftPasswordOnlyError(
343+
"Only password authentication is currently supported for Redshift."
344+
)
319345
conn_info = {
320346
"driver": conn_type,
321347
"host": credentials.get("host"),
@@ -386,7 +412,7 @@ def get_pk_from_model(self, node, unique_columns: dict, pk_tag: str) -> List[str
386412
return []
387413

388414
def get_unique_columns(self) -> Dict[str, Set[str]]:
389-
manifest = self.manifest_obj
415+
manifest = self.dev_manifest_obj
390416
cols_by_uid = defaultdict(set)
391417
for node in manifest.nodes.values():
392418
try:

data_diff/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,7 @@ class DataDiffDbtCoreNoRunnerError(Exception):
4444

4545
class DataDiffDbtSelectVersionTooLowError(Exception):
4646
"Raised when attempting to use `--select` with a dbt-core version < 1.5."
47+
48+
49+
class DataDiffCustomSchemaNoConfigError(Exception):
50+
"Raised when a model has a custom schema, but there is no prod_custom_schema config. (And not using --state)."

tests/cloud/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)