Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit fdf0071

Browse files
authored
Merge pull request #522 from dlawin/issue_518_1
support datadiff meta filter
2 parents a5cd84d + e7f75a9 commit fdf0071

File tree

3 files changed

+152
-18
lines changed

3 files changed

+152
-18
lines changed

data_diff/cloud/datafold_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ class TCloudApiDataDiff(pydantic.BaseModel):
101101
table1: List[str]
102102
table2: List[str]
103103
pk_columns: List[str]
104+
filter1: Optional[str] = None
105+
filter2: Optional[str] = None
104106

105107

106108
class TSummaryResultPrimaryKeyStats(pydantic.BaseModel):

data_diff/dbt.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class DiffVars:
7373
primary_keys: List[str]
7474
connection: Dict[str, str]
7575
threads: Optional[int]
76+
where_filter: Optional[str] = None
7677

7778

7879
def dbt_diff(
@@ -191,7 +192,16 @@ def _get_diff_vars(
191192
dev_qualified_list = [dev_database, dev_schema, model.alias]
192193
prod_qualified_list = [prod_database, prod_schema, model.alias]
193194

194-
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, dbt_parser.connection, dbt_parser.threads)
195+
where_filter = None
196+
if model.meta:
197+
try:
198+
where_filter = model.meta["datafold"]["datadiff"]["filter"]
199+
except KeyError:
200+
pass
201+
202+
return DiffVars(
203+
dev_qualified_list, prod_qualified_list, primary_keys, dbt_parser.connection, dbt_parser.threads, where_filter
204+
)
195205

196206

197207
def _local_diff(diff_vars: DiffVars) -> None:
@@ -228,7 +238,14 @@ def _local_diff(diff_vars: DiffVars) -> None:
228238
mutual_set = mutual_set - set(diff_vars.primary_keys)
229239
extra_columns = tuple(mutual_set)
230240

231-
diff = diff_tables(table1, table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=extra_columns)
241+
diff = diff_tables(
242+
table1,
243+
table2,
244+
threaded=True,
245+
algorithm=Algorithm.JOINDIFF,
246+
extra_columns=extra_columns,
247+
where=diff_vars.where_filter,
248+
)
232249

233250
if list(diff):
234251
diff_output_str += f"{column_diffs_str}{diff.get_stats_string(is_dbt=True)} \n"
@@ -277,6 +294,8 @@ def _cloud_diff(diff_vars: DiffVars, datasource_id: int, api: DatafoldAPI) -> No
277294
table1=diff_vars.prod_path,
278295
table2=diff_vars.dev_path,
279296
pk_columns=diff_vars.primary_keys,
297+
filter1=diff_vars.where_filter,
298+
filter2=diff_vars.where_filter,
280299
)
281300

282301
if is_tracking_enabled():

tests/test_dbt.py

Lines changed: 129 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -425,20 +425,27 @@ def test_local_diff(self, mock_diff_tables):
425425
mock_diff = MagicMock()
426426
mock_diff_tables.return_value = mock_diff
427427
mock_diff.__iter__.return_value = [1, 2, 3]
428+
threads = None
429+
where = "a_string"
428430
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
429431
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
430432
expected_keys = ["key"]
431-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, mock_connection, None)
433+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, mock_connection, threads, where)
432434
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
433435
_local_diff(diff_vars)
434436

435437
mock_diff_tables.assert_called_once_with(
436-
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
438+
mock_table1,
439+
mock_table2,
440+
threaded=True,
441+
algorithm=Algorithm.JOINDIFF,
442+
extra_columns=ANY,
443+
where=where,
437444
)
438445
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
439446
self.assertEqual(mock_connect.call_count, 2)
440-
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), None)
441-
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), None)
447+
mock_connect.assert_any_call(mock_connection, ".".join(dev_qualified_list), tuple(expected_keys), threads)
448+
mock_connect.assert_any_call(mock_connection, ".".join(prod_qualified_list), tuple(expected_keys), threads)
442449
mock_diff.get_stats_string.assert_called_once()
443450

444451
@patch("data_diff.dbt.diff_tables")
@@ -455,12 +462,14 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
455462
dev_qualified_list = ["dev_db", "dev_schema", "dev_table"]
456463
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
457464
expected_keys = ["primary_key_column"]
458-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, mock_connection, None)
465+
threads = None
466+
where = "a_string"
467+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_keys, mock_connection, threads, where)
459468
with patch("data_diff.dbt.connect_to_table", side_effect=[mock_table1, mock_table2]) as mock_connect:
460469
_local_diff(diff_vars)
461470

462471
mock_diff_tables.assert_called_once_with(
463-
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY
472+
mock_table1, mock_table2, threaded=True, algorithm=Algorithm.JOINDIFF, extra_columns=ANY, where=where
464473
)
465474
self.assertEqual(len(mock_diff_tables.call_args[1]["extra_columns"]), 2)
466475
self.assertEqual(mock_connect.call_count, 2)
@@ -479,7 +488,10 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
479488
prod_qualified_list = ["prod_db", "prod_schema", "prod_table"]
480489
expected_datasource_id = 1
481490
expected_primary_keys = ["primary_key_column"]
482-
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_primary_keys, None, None)
491+
connection = None
492+
threads = None
493+
where = "a_string"
494+
diff_vars = DiffVars(dev_qualified_list, prod_qualified_list, expected_primary_keys, connection, threads, where)
483495
_cloud_diff(diff_vars, expected_datasource_id, api=mock_api)
484496

485497
mock_api.create_data_diff.assert_called_once()
@@ -491,6 +503,8 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
491503
self.assertEqual(payload.table1, prod_qualified_list)
492504
self.assertEqual(payload.table2, dev_qualified_list)
493505
self.assertEqual(payload.pk_columns, expected_primary_keys)
506+
self.assertEqual(payload.filter1, where)
507+
self.assertEqual(payload.filter2, where)
494508

495509
@patch("data_diff.dbt._initialize_api")
496510
@patch("data_diff.dbt._get_diff_vars")
@@ -512,11 +526,14 @@ def test_diff_is_cloud(
512526
api_key = "a_api_key"
513527
api = DatafoldAPI(api_key=api_key, host=host)
514528
mock_initialize_api.return_value = api
529+
connection = None
530+
threads = None
531+
where = "a_string"
515532

516533
mock_dbt_parser.return_value = mock_dbt_parser_inst
517534
mock_dbt_parser_inst.get_models.return_value = [mock_model]
518535
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
519-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
536+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
520537
mock_get_diff_vars.return_value = expected_diff_vars
521538
dbt_diff(is_cloud=True)
522539
mock_dbt_parser_inst.get_models.assert_called_once()
@@ -547,11 +564,14 @@ def test_diff_is_cloud_no_ds_id(
547564
api_key = "a_api_key"
548565
api = DatafoldAPI(api_key=api_key, host=host)
549566
mock_initialize_api.return_value = api
567+
connection = None
568+
threads = None
569+
where = "a_string"
550570

551571
mock_dbt_parser.return_value = mock_dbt_parser_inst
552572
mock_dbt_parser_inst.get_models.return_value = [mock_model]
553573
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
554-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
574+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
555575
mock_get_diff_vars.return_value = expected_diff_vars
556576

557577
with self.assertRaises(ValueError):
@@ -579,7 +599,10 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
579599
}
580600
mock_dbt_parser_inst.get_models.return_value = [mock_model]
581601
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
582-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
602+
connection = None
603+
threads = None
604+
where = "a_string"
605+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
583606
mock_get_diff_vars.return_value = expected_diff_vars
584607
dbt_diff(is_cloud=False)
585608

@@ -606,7 +629,10 @@ def test_diff_no_prod_configs(
606629

607630
mock_dbt_parser_inst.get_models.return_value = [mock_model]
608631
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
609-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
632+
connection = None
633+
threads = None
634+
where = "a_string"
635+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
610636
mock_get_diff_vars.return_value = expected_diff_vars
611637
with self.assertRaises(ValueError):
612638
dbt_diff(is_cloud=False)
@@ -633,7 +659,10 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
633659
}
634660
mock_dbt_parser_inst.get_models.return_value = [mock_model]
635661
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
636-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
662+
connection = None
663+
threads = None
664+
where = "a_string"
665+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
637666
mock_get_diff_vars.return_value = expected_diff_vars
638667
dbt_diff(is_cloud=False)
639668

@@ -661,7 +690,10 @@ def test_diff_only_prod_schema(
661690

662691
mock_dbt_parser_inst.get_models.return_value = [mock_model]
663692
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
664-
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], None, None)
693+
connection = None
694+
threads = None
695+
where = "a_string"
696+
expected_diff_vars = DiffVars(["dev"], ["prod"], ["pks"], connection, threads, where)
665697
mock_get_diff_vars.return_value = expected_diff_vars
666698
with self.assertRaises(ValueError):
667699
dbt_diff(is_cloud=False)
@@ -697,7 +729,10 @@ def test_diff_is_cloud_no_pks(
697729

698730
mock_dbt_parser_inst.get_models.return_value = [mock_model]
699731
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
700-
expected_diff_vars = DiffVars(["dev"], ["prod"], [], None, None)
732+
connection = None
733+
threads = None
734+
where = "a_string"
735+
expected_diff_vars = DiffVars(["dev"], ["prod"], [], connection, threads, where)
701736
mock_get_diff_vars.return_value = expected_diff_vars
702737
dbt_diff(is_cloud=True)
703738

@@ -727,8 +762,10 @@ def test_diff_not_is_cloud_no_pks(
727762

728763
mock_dbt_parser_inst.get_models.return_value = [mock_model]
729764
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
730-
731-
expected_diff_vars = DiffVars(["dev"], ["prod"], [], None, None)
765+
connection = None
766+
threads = None
767+
where = "a_string"
768+
expected_diff_vars = DiffVars(["dev"], ["prod"], [], connection, threads, where)
732769
mock_get_diff_vars.return_value = expected_diff_vars
733770
dbt_diff(is_cloud=False)
734771
mock_dbt_parser_inst.get_models.assert_called_once()
@@ -749,6 +786,7 @@ def test_get_diff_vars_replace_custom_schema(self):
749786
mock_dbt_parser = Mock()
750787
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
751788
mock_dbt_parser.requires_upper = False
789+
mock_model.meta = None
752790

753791
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod_<custom_schema>", mock_model)
754792

@@ -773,6 +811,7 @@ def test_get_diff_vars_static_custom_schema(self):
773811
mock_dbt_parser = Mock()
774812
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
775813
mock_dbt_parser.requires_upper = False
814+
mock_model.meta = None
776815

777816
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model)
778817

@@ -796,6 +835,7 @@ def test_get_diff_vars_no_custom_schema_on_model(self):
796835
mock_dbt_parser = Mock()
797836
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
798837
mock_dbt_parser.requires_upper = False
838+
mock_model.meta = None
799839

800840
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model)
801841

@@ -817,6 +857,7 @@ def test_get_diff_vars_match_dev_schema(self):
817857
mock_dbt_parser = Mock()
818858
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
819859
mock_dbt_parser.requires_upper = False
860+
mock_model.meta = None
820861

821862
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)
822863

@@ -844,3 +885,75 @@ def test_get_diff_custom_schema_no_config_exception(self):
844885
_get_diff_vars(mock_dbt_parser, prod_database, prod_schema, None, mock_model)
845886

846887
mock_dbt_parser.get_pk_from_model.assert_called_once()
888+
889+
def test_get_diff_vars_meta_where(self):
890+
mock_model = Mock()
891+
prod_database = "a_prod_db"
892+
primary_keys = ["a_primary_key"]
893+
mock_model.database = "a_dev_db"
894+
mock_model.schema_ = "a_schema"
895+
mock_model.config.schema_ = None
896+
mock_model.alias = "a_model_name"
897+
mock_dbt_parser = Mock()
898+
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
899+
mock_dbt_parser.requires_upper = False
900+
where = "a filter"
901+
mock_model.meta = {"datafold": {"datadiff": {"filter": where}}}
902+
903+
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)
904+
905+
assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
906+
assert diff_vars.prod_path == [prod_database, mock_model.schema_, mock_model.alias]
907+
assert diff_vars.primary_keys == primary_keys
908+
assert diff_vars.connection == mock_dbt_parser.connection
909+
assert diff_vars.threads == mock_dbt_parser.threads
910+
self.assertEqual(diff_vars.where_filter, where)
911+
mock_dbt_parser.get_pk_from_model.assert_called_once()
912+
913+
def test_get_diff_vars_meta_unrelated(self):
914+
mock_model = Mock()
915+
prod_database = "a_prod_db"
916+
primary_keys = ["a_primary_key"]
917+
mock_model.database = "a_dev_db"
918+
mock_model.schema_ = "a_schema"
919+
mock_model.config.schema_ = None
920+
mock_model.alias = "a_model_name"
921+
mock_dbt_parser = Mock()
922+
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
923+
mock_dbt_parser.requires_upper = False
924+
where = None
925+
mock_model.meta = {"key": "value"}
926+
927+
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)
928+
929+
assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
930+
assert diff_vars.prod_path == [prod_database, mock_model.schema_, mock_model.alias]
931+
assert diff_vars.primary_keys == primary_keys
932+
assert diff_vars.connection == mock_dbt_parser.connection
933+
assert diff_vars.threads == mock_dbt_parser.threads
934+
self.assertEqual(diff_vars.where_filter, where)
935+
mock_dbt_parser.get_pk_from_model.assert_called_once()
936+
937+
def test_get_diff_vars_meta_none(self):
938+
mock_model = Mock()
939+
prod_database = "a_prod_db"
940+
primary_keys = ["a_primary_key"]
941+
mock_model.database = "a_dev_db"
942+
mock_model.schema_ = "a_schema"
943+
mock_model.config.schema_ = None
944+
mock_model.alias = "a_model_name"
945+
mock_dbt_parser = Mock()
946+
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
947+
mock_dbt_parser.requires_upper = False
948+
where = None
949+
mock_model.meta = None
950+
951+
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)
952+
953+
assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
954+
assert diff_vars.prod_path == [prod_database, mock_model.schema_, mock_model.alias]
955+
assert diff_vars.primary_keys == primary_keys
956+
assert diff_vars.connection == mock_dbt_parser.connection
957+
assert diff_vars.threads == mock_dbt_parser.threads
958+
self.assertEqual(diff_vars.where_filter, where)
959+
mock_dbt_parser.get_pk_from_model.assert_called_once()

0 commit comments

Comments
 (0)