@@ -425,20 +425,27 @@ def test_local_diff(self, mock_diff_tables):
425
425
mock_diff = MagicMock ()
426
426
mock_diff_tables .return_value = mock_diff
427
427
mock_diff .__iter__ .return_value = [1 , 2 , 3 ]
428
+ threads = None
429
+ where = "a_string"
428
430
dev_qualified_list = ["dev_db" , "dev_schema" , "dev_table" ]
429
431
prod_qualified_list = ["prod_db" , "prod_schema" , "prod_table" ]
430
432
expected_keys = ["key" ]
431
- diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_keys , mock_connection , None )
433
+ diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_keys , mock_connection , threads , where )
432
434
with patch ("data_diff.dbt.connect_to_table" , side_effect = [mock_table1 , mock_table2 ]) as mock_connect :
433
435
_local_diff (diff_vars )
434
436
435
437
mock_diff_tables .assert_called_once_with (
436
- mock_table1 , mock_table2 , threaded = True , algorithm = Algorithm .JOINDIFF , extra_columns = ANY
438
+ mock_table1 ,
439
+ mock_table2 ,
440
+ threaded = True ,
441
+ algorithm = Algorithm .JOINDIFF ,
442
+ extra_columns = ANY ,
443
+ where = where ,
437
444
)
438
445
self .assertEqual (len (mock_diff_tables .call_args [1 ]["extra_columns" ]), 2 )
439
446
self .assertEqual (mock_connect .call_count , 2 )
440
- mock_connect .assert_any_call (mock_connection , "." .join (dev_qualified_list ), tuple (expected_keys ), None )
441
- mock_connect .assert_any_call (mock_connection , "." .join (prod_qualified_list ), tuple (expected_keys ), None )
447
+ mock_connect .assert_any_call (mock_connection , "." .join (dev_qualified_list ), tuple (expected_keys ), threads )
448
+ mock_connect .assert_any_call (mock_connection , "." .join (prod_qualified_list ), tuple (expected_keys ), threads )
442
449
mock_diff .get_stats_string .assert_called_once ()
443
450
444
451
@patch ("data_diff.dbt.diff_tables" )
@@ -455,12 +462,14 @@ def test_local_diff_no_diffs(self, mock_diff_tables):
455
462
dev_qualified_list = ["dev_db" , "dev_schema" , "dev_table" ]
456
463
prod_qualified_list = ["prod_db" , "prod_schema" , "prod_table" ]
457
464
expected_keys = ["primary_key_column" ]
458
- diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_keys , mock_connection , None )
465
+ threads = None
466
+ where = "a_string"
467
+ diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_keys , mock_connection , threads , where )
459
468
with patch ("data_diff.dbt.connect_to_table" , side_effect = [mock_table1 , mock_table2 ]) as mock_connect :
460
469
_local_diff (diff_vars )
461
470
462
471
mock_diff_tables .assert_called_once_with (
463
- mock_table1 , mock_table2 , threaded = True , algorithm = Algorithm .JOINDIFF , extra_columns = ANY
472
+ mock_table1 , mock_table2 , threaded = True , algorithm = Algorithm .JOINDIFF , extra_columns = ANY , where = where
464
473
)
465
474
self .assertEqual (len (mock_diff_tables .call_args [1 ]["extra_columns" ]), 2 )
466
475
self .assertEqual (mock_connect .call_count , 2 )
@@ -479,7 +488,10 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
479
488
prod_qualified_list = ["prod_db" , "prod_schema" , "prod_table" ]
480
489
expected_datasource_id = 1
481
490
expected_primary_keys = ["primary_key_column" ]
482
- diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_primary_keys , None , None )
491
+ connection = None
492
+ threads = None
493
+ where = "a_string"
494
+ diff_vars = DiffVars (dev_qualified_list , prod_qualified_list , expected_primary_keys , connection , threads , where )
483
495
_cloud_diff (diff_vars , expected_datasource_id , api = mock_api )
484
496
485
497
mock_api .create_data_diff .assert_called_once ()
@@ -491,6 +503,8 @@ def test_cloud_diff(self, mock_api, mock_os_environ, mock_print):
491
503
self .assertEqual (payload .table1 , prod_qualified_list )
492
504
self .assertEqual (payload .table2 , dev_qualified_list )
493
505
self .assertEqual (payload .pk_columns , expected_primary_keys )
506
+ self .assertEqual (payload .filter1 , where )
507
+ self .assertEqual (payload .filter2 , where )
494
508
495
509
@patch ("data_diff.dbt._initialize_api" )
496
510
@patch ("data_diff.dbt._get_diff_vars" )
@@ -512,11 +526,14 @@ def test_diff_is_cloud(
512
526
api_key = "a_api_key"
513
527
api = DatafoldAPI (api_key = api_key , host = host )
514
528
mock_initialize_api .return_value = api
529
+ connection = None
530
+ threads = None
531
+ where = "a_string"
515
532
516
533
mock_dbt_parser .return_value = mock_dbt_parser_inst
517
534
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
518
535
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
519
- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
536
+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
520
537
mock_get_diff_vars .return_value = expected_diff_vars
521
538
dbt_diff (is_cloud = True )
522
539
mock_dbt_parser_inst .get_models .assert_called_once ()
@@ -547,11 +564,14 @@ def test_diff_is_cloud_no_ds_id(
547
564
api_key = "a_api_key"
548
565
api = DatafoldAPI (api_key = api_key , host = host )
549
566
mock_initialize_api .return_value = api
567
+ connection = None
568
+ threads = None
569
+ where = "a_string"
550
570
551
571
mock_dbt_parser .return_value = mock_dbt_parser_inst
552
572
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
553
573
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
554
- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
574
+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
555
575
mock_get_diff_vars .return_value = expected_diff_vars
556
576
557
577
with self .assertRaises (ValueError ):
@@ -579,7 +599,10 @@ def test_diff_is_not_cloud(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
579
599
}
580
600
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
581
601
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
582
- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
602
+ connection = None
603
+ threads = None
604
+ where = "a_string"
605
+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
583
606
mock_get_diff_vars .return_value = expected_diff_vars
584
607
dbt_diff (is_cloud = False )
585
608
@@ -606,7 +629,10 @@ def test_diff_no_prod_configs(
606
629
607
630
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
608
631
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
609
- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
632
+ connection = None
633
+ threads = None
634
+ where = "a_string"
635
+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
610
636
mock_get_diff_vars .return_value = expected_diff_vars
611
637
with self .assertRaises (ValueError ):
612
638
dbt_diff (is_cloud = False )
@@ -633,7 +659,10 @@ def test_diff_only_prod_db(self, mock_print, mock_dbt_parser, mock_cloud_diff, m
633
659
}
634
660
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
635
661
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
636
- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
662
+ connection = None
663
+ threads = None
664
+ where = "a_string"
665
+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
637
666
mock_get_diff_vars .return_value = expected_diff_vars
638
667
dbt_diff (is_cloud = False )
639
668
@@ -661,7 +690,10 @@ def test_diff_only_prod_schema(
661
690
662
691
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
663
692
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
664
- expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], None , None )
693
+ connection = None
694
+ threads = None
695
+ where = "a_string"
696
+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], ["pks" ], connection , threads , where )
665
697
mock_get_diff_vars .return_value = expected_diff_vars
666
698
with self .assertRaises (ValueError ):
667
699
dbt_diff (is_cloud = False )
@@ -697,7 +729,10 @@ def test_diff_is_cloud_no_pks(
697
729
698
730
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
699
731
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
700
- expected_diff_vars = DiffVars (["dev" ], ["prod" ], [], None , None )
732
+ connection = None
733
+ threads = None
734
+ where = "a_string"
735
+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], [], connection , threads , where )
701
736
mock_get_diff_vars .return_value = expected_diff_vars
702
737
dbt_diff (is_cloud = True )
703
738
@@ -727,8 +762,10 @@ def test_diff_not_is_cloud_no_pks(
727
762
728
763
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
729
764
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
730
-
731
- expected_diff_vars = DiffVars (["dev" ], ["prod" ], [], None , None )
765
+ connection = None
766
+ threads = None
767
+ where = "a_string"
768
+ expected_diff_vars = DiffVars (["dev" ], ["prod" ], [], connection , threads , where )
732
769
mock_get_diff_vars .return_value = expected_diff_vars
733
770
dbt_diff (is_cloud = False )
734
771
mock_dbt_parser_inst .get_models .assert_called_once ()
@@ -749,6 +786,7 @@ def test_get_diff_vars_replace_custom_schema(self):
749
786
mock_dbt_parser = Mock ()
750
787
mock_dbt_parser .get_pk_from_model .return_value = primary_keys
751
788
mock_dbt_parser .requires_upper = False
789
+ mock_model .meta = None
752
790
753
791
diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod_<custom_schema>" , mock_model )
754
792
@@ -773,6 +811,7 @@ def test_get_diff_vars_static_custom_schema(self):
773
811
mock_dbt_parser = Mock ()
774
812
mock_dbt_parser .get_pk_from_model .return_value = primary_keys
775
813
mock_dbt_parser .requires_upper = False
814
+ mock_model .meta = None
776
815
777
816
diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod" , mock_model )
778
817
@@ -796,6 +835,7 @@ def test_get_diff_vars_no_custom_schema_on_model(self):
796
835
mock_dbt_parser = Mock ()
797
836
mock_dbt_parser .get_pk_from_model .return_value = primary_keys
798
837
mock_dbt_parser .requires_upper = False
838
+ mock_model .meta = None
799
839
800
840
diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod" , mock_model )
801
841
@@ -817,6 +857,7 @@ def test_get_diff_vars_match_dev_schema(self):
817
857
mock_dbt_parser = Mock ()
818
858
mock_dbt_parser .get_pk_from_model .return_value = primary_keys
819
859
mock_dbt_parser .requires_upper = False
860
+ mock_model .meta = None
820
861
821
862
diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
822
863
@@ -844,3 +885,75 @@ def test_get_diff_custom_schema_no_config_exception(self):
844
885
_get_diff_vars (mock_dbt_parser , prod_database , prod_schema , None , mock_model )
845
886
846
887
mock_dbt_parser .get_pk_from_model .assert_called_once ()
888
+
889
def test_get_diff_vars_meta_where(self):
    """A filter nested under meta["datafold"]["datadiff"] is surfaced as ``where_filter``.

    The model carries no custom schema, and neither a prod schema nor a
    custom-schema pattern is passed, so both paths are checked against the
    model's own schema and alias.
    """
    expected_filter = "a filter"
    pk_columns = ["a_primary_key"]
    prod_db = "a_prod_db"

    # Parser stub: only the primary-key lookup and the upper-casing flag are pinned.
    parser = Mock(requires_upper=False)
    parser.get_pk_from_model.return_value = pk_columns

    # Model stub configured through Mock's keyword arguments.
    model = Mock(
        database="a_dev_db",
        schema_="a_schema",
        alias="a_model_name",
        meta={"datafold": {"datadiff": {"filter": expected_filter}}},
    )
    model.config.schema_ = None

    diff_vars = _get_diff_vars(parser, prod_db, None, None, model)

    expected_dev_path = [model.database, model.schema_, model.alias]
    expected_prod_path = [prod_db, model.schema_, model.alias]
    assert diff_vars.dev_path == expected_dev_path
    assert diff_vars.prod_path == expected_prod_path
    assert diff_vars.primary_keys == pk_columns
    assert diff_vars.connection == parser.connection
    assert diff_vars.threads == parser.threads
    self.assertEqual(diff_vars.where_filter, expected_filter)
    parser.get_pk_from_model.assert_called_once()
912
+
913
def test_get_diff_vars_meta_unrelated(self):
    """A ``meta`` dict without the datafold/datadiff keys yields ``where_filter`` of None.

    All other DiffVars fields are still expected to be populated from the
    model and parser stubs.
    """
    expected_filter = None
    pk_columns = ["a_primary_key"]
    prod_db = "a_prod_db"

    # Parser stub: only the primary-key lookup and the upper-casing flag are pinned.
    parser = Mock(requires_upper=False)
    parser.get_pk_from_model.return_value = pk_columns

    # Model stub whose meta holds only an unrelated entry.
    model = Mock(
        database="a_dev_db",
        schema_="a_schema",
        alias="a_model_name",
        meta={"key": "value"},
    )
    model.config.schema_ = None

    diff_vars = _get_diff_vars(parser, prod_db, None, None, model)

    expected_dev_path = [model.database, model.schema_, model.alias]
    expected_prod_path = [prod_db, model.schema_, model.alias]
    assert diff_vars.dev_path == expected_dev_path
    assert diff_vars.prod_path == expected_prod_path
    assert diff_vars.primary_keys == pk_columns
    assert diff_vars.connection == parser.connection
    assert diff_vars.threads == parser.threads
    self.assertEqual(diff_vars.where_filter, expected_filter)
    parser.get_pk_from_model.assert_called_once()
936
+
937
def test_get_diff_vars_meta_none(self):
    """A model whose ``meta`` is None yields ``where_filter`` of None.

    All other DiffVars fields are still expected to be populated from the
    model and parser stubs.
    """
    expected_filter = None
    pk_columns = ["a_primary_key"]
    prod_db = "a_prod_db"

    # Parser stub: only the primary-key lookup and the upper-casing flag are pinned.
    parser = Mock(requires_upper=False)
    parser.get_pk_from_model.return_value = pk_columns

    # Model stub with no meta at all.
    model = Mock(
        database="a_dev_db",
        schema_="a_schema",
        alias="a_model_name",
        meta=None,
    )
    model.config.schema_ = None

    diff_vars = _get_diff_vars(parser, prod_db, None, None, model)

    expected_dev_path = [model.database, model.schema_, model.alias]
    expected_prod_path = [prod_db, model.schema_, model.alias]
    assert diff_vars.dev_path == expected_dev_path
    assert diff_vars.prod_path == expected_prod_path
    assert diff_vars.primary_keys == pk_columns
    assert diff_vars.connection == parser.connection
    assert diff_vars.threads == parser.threads
    self.assertEqual(diff_vars.where_filter, expected_filter)
    parser.get_pk_from_model.assert_called_once()
0 commit comments