diff --git a/core/dbt/artifacts/resources/v1/components.py b/core/dbt/artifacts/resources/v1/components.py index e25fd49797f..9890ff90a27 100644 --- a/core/dbt/artifacts/resources/v1/components.py +++ b/core/dbt/artifacts/resources/v1/components.py @@ -258,6 +258,7 @@ class ParsedResource(ParsedResourceMandatory): unrendered_config_call_dict: Dict[str, Any] = field(default_factory=dict) relation_name: Optional[str] = None raw_code: str = "" + vars: Dict[str, Any] = field(default_factory=dict) doc_blocks: List[str] = field(default_factory=list) def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None): diff --git a/core/dbt/artifacts/resources/v1/exposure.py b/core/dbt/artifacts/resources/v1/exposure.py index b3002f0ec9d..18bf44c039c 100644 --- a/core/dbt/artifacts/resources/v1/exposure.py +++ b/core/dbt/artifacts/resources/v1/exposure.py @@ -43,6 +43,7 @@ class Exposure(GraphResource): tags: List[str] = field(default_factory=list) config: ExposureConfig = field(default_factory=ExposureConfig) unrendered_config: Dict[str, Any] = field(default_factory=dict) + vars: Dict[str, Any] = field(default_factory=dict) url: Optional[str] = None depends_on: DependsOn = field(default_factory=DependsOn) refs: List[RefArgs] = field(default_factory=list) diff --git a/core/dbt/artifacts/resources/v1/source_definition.py b/core/dbt/artifacts/resources/v1/source_definition.py index 31784a017c0..4f56541334f 100644 --- a/core/dbt/artifacts/resources/v1/source_definition.py +++ b/core/dbt/artifacts/resources/v1/source_definition.py @@ -75,6 +75,7 @@ class SourceDefinition(ParsedSourceMandatory): config: SourceConfig = field(default_factory=SourceConfig) patch_path: Optional[str] = None unrendered_config: Dict[str, Any] = field(default_factory=dict) + vars: Dict[str, Any] = field(default_factory=dict) relation_name: Optional[str] = None created_at: float = field(default_factory=lambda: time.time()) unrendered_database: Optional[str] = None diff --git a/core/dbt/context/configured.py b/core/dbt/context/configured.py index 730316674d8..bf069d64e5d 100644 --- a/core/dbt/context/configured.py +++ b/core/dbt/context/configured.py @@ -31,23 +31,35 @@ def __init__(self, package_name: str): self.resource_type = NodeType.Model +class SchemaYamlVars: + def __init__(self): + self.env_vars = {} + self.vars = {} + + class ConfiguredVar(Var): def __init__( self, context: Dict[str, Any], config: AdapterRequiredConfig, project_name: str, + schema_yaml_vars: Optional[SchemaYamlVars] = None, ): super().__init__(context, config.cli_vars) self._config = config self._project_name = project_name + self.schema_yaml_vars = schema_yaml_vars def __call__(self, var_name, default=Var._VAR_NOTSET): my_config = self._config.load_dependencies()[self._project_name] + var_found = False + var_value = None + # cli vars > active project > local project if var_name in self._config.cli_vars: - return self._config.cli_vars[var_name] + var_found = True + var_value = self._config.cli_vars[var_name] adapter_type = self._config.credentials.type lookup = FQNLookup(self._project_name) @@ -58,19 +70,21 @@ def __call__(self, var_name, default=Var._VAR_NOTSET): all_vars.add(my_config.vars.vars_for(lookup, adapter_type)) all_vars.add(active_vars) - if var_name in all_vars: - return all_vars[var_name] + if not var_found and var_name in all_vars: + var_found = True + var_value = all_vars[var_name] - if default is not Var._VAR_NOTSET: - return default - - return self.get_missing_var(var_name) + if not var_found and default is not Var._VAR_NOTSET: + var_found = True + var_value = default + if not var_found: + return self.get_missing_var(var_name) + else: + if self.schema_yaml_vars: + self.schema_yaml_vars.vars[var_name] = var_value -class SchemaYamlVars: - def __init__(self): - self.env_vars = {} - self.vars = {} + return var_value class SchemaYamlContext(ConfiguredContext): @@ -82,7 +96,7 @@ def __init__(self, config, project_name: str, schema_yaml_vars: Optional[SchemaY @contextproperty() def var(self) -> ConfiguredVar: - return ConfiguredVar(self._ctx, self.config, self._project_name) + return ConfiguredVar(self._ctx, self.config, self._project_name, self.schema_yaml_vars) @contextmember() def env_var(self, var: str, default: Optional[str] = None) -> str: diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index d0f3c55851f..2ed08646a6f 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -967,6 +967,14 @@ def get_missing_var(self, var_name): # in the parser, just always return None. return None + def __call__(self, var_name: str, default: Any = ModelConfiguredVar._VAR_NOTSET) -> Any: + var_value = super().__call__(var_name, default) + + if self._node and hasattr(self._node, "vars"): + self._node.vars[var_name] = var_value + + return var_value + class RuntimeVar(ModelConfiguredVar): pass diff --git a/core/dbt/contracts/files.py b/core/dbt/contracts/files.py index 17c5d18d519..976de077e5f 100644 --- a/core/dbt/contracts/files.py +++ b/core/dbt/contracts/files.py @@ -220,6 +220,7 @@ class SchemaSourceFile(BaseSourceFile): unrendered_configs: Dict[str, Any] = field(default_factory=dict) unrendered_databases: Dict[str, Any] = field(default_factory=dict) unrendered_schemas: Dict[str, Any] = field(default_factory=dict) + vars: Dict[str, Any] = field(default_factory=dict) pp_dict: Optional[Dict[str, Any]] = None pp_test_index: Optional[Dict[str, Any]] = None @@ -360,6 +361,22 @@ def delete_from_unrendered_configs(self, yaml_key, name): if not self.unrendered_configs[yaml_key]: del self.unrendered_configs[yaml_key] + def add_vars(self, vars: Dict[str, Any], yaml_key: str, name: str) -> None: + if yaml_key not in self.vars: + self.vars[yaml_key] = {} + + if name not in self.vars[yaml_key]: + self.vars[yaml_key][name] = vars + + def get_vars(self, yaml_key: str, name: str) -> Dict[str, Any]: + if yaml_key not in self.vars: + return {} + + if name not in self.vars[yaml_key]: + return {} + + return self.vars[yaml_key][name] + def add_env_var(self, var, yaml_key, name): if yaml_key not in self.env_vars: self.env_vars[yaml_key] = {} diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py index 560952287c0..239ce3f9751 100644 --- a/core/dbt/contracts/graph/nodes.py +++ b/core/dbt/contracts/graph/nodes.py @@ -378,6 +378,9 @@ def same_contract(self, old, adapter_type=None) -> bool: # This would only apply to seeds return True + def same_vars(self, old) -> bool: + return self.vars == old.vars + def same_contents(self, old, adapter_type) -> bool: if old is None: return False @@ -385,12 +388,20 @@ def same_contents(self, old, adapter_type) -> bool: # Need to ensure that same_contract is called because it # could throw an error same_contract = self.same_contract(old, adapter_type) + + # Legacy behaviour + if not get_flags().state_modified_compare_vars: + same_vars = True + else: + same_vars = self.same_vars(old) + return ( self.same_body(old) and self.same_config(old) and self.same_persisted_description(old) and self.same_fqn(old) and self.same_database_representation(old) + and same_vars and same_contract and True ) @@ -1353,6 +1364,9 @@ def same_config(self, old: "SourceDefinition") -> bool: old.unrendered_config, ) + def same_vars(self, other: "SourceDefinition") -> bool: + return self.vars == other.vars + def same_contents(self, old: Optional["SourceDefinition"]) -> bool: # existing when it didn't before is a change! if old is None: @@ -1366,6 +1380,12 @@ def same_contents(self, old: Optional["SourceDefinition"]) -> bool: # freshness changes are changes, I guess # metadata/tags changes are not "changes" # patching/description changes are not "changes" + # Legacy behaviour + if not get_flags().state_modified_compare_vars: + same_vars = True + else: + same_vars = self.same_vars(old) + return ( self.same_database_representation(old) and self.same_fqn(old) @@ -1373,6 +1393,7 @@ def same_contents(self, old: Optional["SourceDefinition"]) -> bool: and self.same_quoting(old) and self.same_freshness(old) and self.same_external(old) + and same_vars and True ) @@ -1469,12 +1490,21 @@ def same_config(self, old: "Exposure") -> bool: old.unrendered_config, ) + def same_vars(self, old: "Exposure") -> bool: + return self.vars == old.vars + def same_contents(self, old: Optional["Exposure"]) -> bool: # existing when it didn't before is a change! # metadata/tags changes are not "changes" if old is None: return True + # Legacy behaviour + if not get_flags().state_modified_compare_vars: + same_vars = True + else: + same_vars = self.same_vars(old) + return ( self.same_fqn(old) and self.same_exposure_type(old) @@ -1485,6 +1515,7 @@ def same_contents(self, old: Optional["Exposure"]) -> bool: and self.same_label(old) and self.same_depends_on(old) and self.same_config(old) + and same_vars and True ) @@ -1765,6 +1796,7 @@ class ParsedNodePatch(ParsedPatch): latest_version: Optional[NodeVersion] constraints: List[Dict[str, Any]] deprecation_date: Optional[datetime] + vars: Dict[str, Any] time_spine: Optional[TimeSpine] = None semantic_model: Union[UnparsedSemanticModelConfig, bool, None] = None metrics: Optional[List[UnparsedMetricV2]] = None diff --git a/core/dbt/graph/selector_methods.py b/core/dbt/graph/selector_methods.py index d97a468bfef..9a5ddbd5fcc 100644 --- a/core/dbt/graph/selector_methods.py +++ b/core/dbt/graph/selector_methods.py @@ -781,6 +781,7 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu "modified.relation": self.check_modified_factory("same_database_representation"), "modified.macros": self.check_modified_macros, "modified.contract": self.check_modified_contract("same_contract", adapter_type), + "modified.vars": self.check_modified_factory("same_vars"), } if selector in state_checks: checker = state_checks[selector] diff --git a/core/dbt/parser/schema_yaml_readers.py b/core/dbt/parser/schema_yaml_readers.py index 2b8b15f4426..fd6d6cb2a90 100644 --- a/core/dbt/parser/schema_yaml_readers.py +++ b/core/dbt/parser/schema_yaml_readers.py @@ -116,6 +116,9 @@ def parse_exposure(self, unparsed: UnparsedExposure) -> None: unique_id = f"{NodeType.Exposure}.{package_name}.{unparsed.name}" path = self.yaml.path.relative_path + assert isinstance(self.yaml.file, SchemaSourceFile) + exposure_vars = self.yaml.file.get_vars(self.key, unparsed.name) + fqn = self.schema_parser.get_fqn_prefix(path) fqn.append(unparsed.name) @@ -164,6 +167,7 @@ def parse_exposure(self, unparsed: UnparsedExposure) -> None: maturity=unparsed.maturity, config=config, unrendered_config=unrendered_config, + vars=exposure_vars, ) ctx = generate_parse_exposure( parsed, diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index e2b03d8b612..2723ddb483a 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -493,10 +493,14 @@ def get_key_dicts(self) -> Iterable[Dict[str, Any]]: if self.schema_yaml_vars.env_vars: self.schema_parser.manifest.env_vars.update(self.schema_yaml_vars.env_vars) - for var in self.schema_yaml_vars.env_vars.keys(): - schema_file.add_env_var(var, self.key, entry["name"]) + for env_var in self.schema_yaml_vars.env_vars.keys(): + schema_file.add_env_var(env_var, self.key, entry["name"]) self.schema_yaml_vars.env_vars = {} + if self.schema_yaml_vars.vars: + schema_file.add_vars(self.schema_yaml_vars.vars, self.key, entry["name"]) + self.schema_yaml_vars.vars = {} + yield entry def render_entry(self, dct): @@ -803,6 +807,9 @@ def _get_node_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> Pa derived_semantics = block.target.derived_semantics agg_time_dimension = block.target.agg_time_dimension primary_entity = block.target.primary_entity + assert isinstance(self.yaml.file, SchemaSourceFile) + source_file: SchemaSourceFile = self.yaml.file + return ParsedNodePatch( name=block.target.name, original_file_path=block.target.original_file_path, @@ -818,6 +825,7 @@ def _get_node_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> Pa latest_version=None, constraints=block.target.constraints, deprecation_date=deprecation_date, + vars=source_file.get_vars(block.target.yaml_key, block.target.name), time_spine=time_spine, semantic_model=semantic_model, metrics=metrics, @@ -931,6 +939,8 @@ def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None: node.description = patch.description node.columns = patch.columns node.name = patch.name + # Prefer node-level vars to vars from patch + node.vars = {**patch.vars, **node.vars} if not isinstance(node, ModelNode): for attr in ["latest_version", "access", "version", "constraints"]: @@ -1080,6 +1090,7 @@ def parse_patch(self, block: TargetBlock[UnparsedModelUpdate], refs: ParserRef) latest_version=latest_version, constraints=unparsed_version.constraints or target.constraints, deprecation_date=unparsed_version.deprecation_date, + vars=source_file.get_vars(block.target.yaml_key, block.target.name), ) # Node patched before config because config patching depends on model name, # which may have been updated in the version patch @@ -1330,6 +1341,9 @@ def _get_node_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> Pa target = block.target assert isinstance(target, UnparsedFunctionUpdate) + assert isinstance(self.yaml.file, SchemaSourceFile) + source_file: SchemaSourceFile = self.yaml.file + return ParsedFunctionPatch( name=target.name, original_file_path=target.original_file_path, @@ -1345,6 +1359,7 @@ def _get_node_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> Pa latest_version=None, constraints=target.constraints, deprecation_date=None, + vars=source_file.get_vars(target.yaml_key, target.name), time_spine=None, arguments=target.arguments, returns=target.returns, diff --git a/core/dbt/parser/sources.py b/core/dbt/parser/sources.py index a058e17e9b3..6878a3cba2b 100644 --- a/core/dbt/parser/sources.py +++ b/core/dbt/parser/sources.py @@ -12,6 +12,7 @@ ContextConfigGenerator, UnrenderedConfigGenerator, ) +from dbt.contracts.files import SchemaSourceFile from dbt.contracts.graph.manifest import Manifest, SourceKey from dbt.contracts.graph.nodes import ( GenericTestNode, @@ -146,6 +147,13 @@ def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition: rendered=False, ) + schema_file = self.manifest.files.get(target.file_id) + source_vars = ( + schema_file.get_vars("sources", source.name) + if isinstance(schema_file, SchemaSourceFile) + else {} + ) + if not isinstance(config, SourceConfig): raise DbtInternalError( f"Calculated a {type(config)} for a source, but expected a SourceConfig" @@ -181,6 +189,7 @@ def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition: tags=config.tags, config=config, unrendered_config=unrendered_config, + vars=source_vars, ) if ( diff --git a/tests/unit/contracts/graph/test_manifest.py b/tests/unit/contracts/graph/test_manifest.py index d5d8ada5130..2264135079d 100644 --- a/tests/unit/contracts/graph/test_manifest.py +++ b/tests/unit/contracts/graph/test_manifest.py @@ -100,6 +100,7 @@ "defer_relation", "time_spine", "batch", + "vars", } ) diff --git a/tests/unit/contracts/graph/test_nodes.py b/tests/unit/contracts/graph/test_nodes.py index 2a19c878c7d..9e87cdefd6e 100644 --- a/tests/unit/contracts/graph/test_nodes.py +++ b/tests/unit/contracts/graph/test_nodes.py @@ -207,6 +207,7 @@ def basic_compiled_dict(): }, "unrendered_config": {}, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "access": "protected", "constraints": [], @@ -531,6 +532,7 @@ def basic_compiled_schema_test_dict(): "severity": "warn", }, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "doc_blocks": [], } diff --git a/tests/unit/contracts/graph/test_nodes_parsed.py b/tests/unit/contracts/graph/test_nodes_parsed.py index 83bca46036b..d8a53f37156 100644 --- a/tests/unit/contracts/graph/test_nodes_parsed.py +++ b/tests/unit/contracts/graph/test_nodes_parsed.py @@ -206,6 +206,7 @@ def base_parsed_model_dict(): }, "unrendered_config": {}, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "access": AccessType.Protected.value, "constraints": [], @@ -263,6 +264,7 @@ def minimal_parsed_model_dict(): "checksum": "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", }, "unrendered_config": {}, + "vars": {}, } @@ -333,6 +335,7 @@ def complex_parsed_model_dict(): "post_hook": ['insert into blah(a, b) select "1", 1'], }, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "access": AccessType.Protected.value, "constraints": [], @@ -544,6 +547,7 @@ def basic_parsed_seed_dict(): "checksum": {"name": "path", "checksum": "seeds/seed.csv"}, "unrendered_config": {}, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "doc_blocks": [], } @@ -653,6 +657,7 @@ def complex_parsed_seed_dict(): "persist_docs": {"relation": True, "columns": True}, }, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "doc_blocks": [], } @@ -855,6 +860,7 @@ def base_parsed_hook_dict(): }, "unrendered_config": {}, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "doc_blocks": [], } @@ -956,6 +962,7 @@ def complex_parsed_hook_dict(): "materialized": "table", }, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "doc_blocks": [], } @@ -1103,6 +1110,7 @@ def basic_parsed_schema_test_dict(): }, "unrendered_config": {}, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "doc_blocks": [], } @@ -1196,6 +1204,7 @@ def complex_parsed_schema_test_dict(): }, "unrendered_config": {"materialized": "table", "severity": "WARN"}, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "doc_blocks": [], } @@ -1593,6 +1602,7 @@ def basic_timestamp_snapshot_dict(): "target_schema": "some_snapshot_schema", }, "unrendered_config_call_dict": {}, + "vars": {}, "config_call_dict": {}, "doc_blocks": [], } @@ -1701,6 +1711,7 @@ def basic_check_snapshot_dict(): }, "unrendered_config_call_dict": {}, "config_call_dict": {}, + "vars": {}, "doc_blocks": [], } @@ -1922,6 +1933,7 @@ def basic_parsed_source_definition_dict(): "meta": {}, }, "unrendered_config": {}, + "vars": {}, "doc_blocks": [], } @@ -1961,6 +1973,7 @@ def complex_parsed_source_definition_dict(): "freshness": {"warn_after": {"period": "hour", "count": 1}, "error_after": {}}, "loaded_at_field": "loaded_at", "unrendered_config": {}, + "vars": {}, "doc_blocks": [], } @@ -2143,6 +2156,7 @@ def basic_parsed_exposure_dict(): "meta": {}, }, "unrendered_config": {}, + "vars": {}, } @@ -2200,6 +2214,7 @@ def complex_parsed_exposure_dict(): "meta": {}, }, "unrendered_config": {}, + "vars": {}, } diff --git a/tests/unit/parser/test_parser.py b/tests/unit/parser/test_parser.py index ec26d85aa06..623c6355f5d 100644 --- a/tests/unit/parser/test_parser.py +++ b/tests/unit/parser/test_parser.py @@ -579,6 +579,7 @@ def test__read_basic_source(self): @mock.patch("dbt.parser.sources.get_adapter") def test_parse_source_custom_freshness_at_source(self, _): block = self.file_block_for(SOURCE_CUSTOM_FRESHNESS_AT_SOURCE, "test_one.yml") + self.parser.manifest.files[block.file.file_id] = block.file dct = yaml_from_file(block.file, validate=True) self.parser.parse_file(block, dct) unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] @@ -590,6 +591,7 @@ def test_parse_source_custom_freshness_at_source_field_at_table(self, _): block = self.file_block_for( SOURCE_CUSTOM_FRESHNESS_AT_SOURCE_FIELD_AT_TABLE, "test_one.yml" ) + self.parser.manifest.files[block.file.file_id] = block.file dct = yaml_from_file(block.file, validate=True) self.parser.parse_file(block, dct) unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] @@ -602,6 +604,7 @@ def test_parse_source_field_at_source_custom_freshness_at_table(self, _): block = self.file_block_for( SOURCE_FIELD_AT_SOURCE_CUSTOM_FRESHNESS_AT_TABLE, "test_one.yml" ) + self.parser.manifest.files[block.file.file_id] = block.file dct = yaml_from_file(block.file, validate=True) self.parser.parse_file(block, dct) unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] @@ -611,6 +614,7 @@ def test_parse_source_field_at_source_custom_freshness_at_table(self, _): @mock.patch("dbt.parser.sources.get_adapter") def test_parse_source_field_at_custom_freshness_both_at_table_fails(self, _): block = self.file_block_for(SOURCE_FIELD_AT_CUSTOM_FRESHNESS_BOTH_AT_TABLE, "test_one.yml") + self.parser.manifest.files[block.file.file_id] = block.file dct = yaml_from_file(block.file, validate=True) self.parser.parse_file(block, dct) unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] @@ -620,6 +624,7 @@ def test_parse_source_field_at_custom_freshness_both_at_table_fails(self, _): @mock.patch("dbt.parser.sources.get_adapter") def test_parse_source_resulting_node_freshness_matches_config_freshness(self, _): block = self.file_block_for(SOURCE_FRESHNESS_AT_TABLE_AND_CONFIG, "test_one.yml") + self.parser.manifest.files[block.file.file_id] = block.file dct = yaml_from_file(block.file, validate=True) self.parser.parse_file(block, dct) unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] @@ -635,6 +640,7 @@ def test_parse_source_field_at_custom_freshness_both_at_source_fails(self, _): block = self.file_block_for( SOURCE_FIELD_AT_CUSTOM_FRESHNESS_BOTH_AT_SOURCE, "test_one.yml" ) + self.parser.manifest.files[block.file.file_id] = block.file dct = yaml_from_file(block.file, validate=True) self.parser.parse_file(block, dct) unpatched_src_default = self.parser.manifest.sources["source.snowplow.my_source.my_table"] @@ -657,6 +663,7 @@ def test__parse_basic_source(self): @mock.patch("dbt.parser.sources.get_adapter") def test__parse_basic_source_meta(self, mock_get_adapter): block = self.file_block_for(MULTIPLE_TABLE_SOURCE_META, "test_one.yml") + self.parser.manifest.files[block.file.file_id] = block.file dct = yaml_from_file(block.file, validate=True) self.parser.parse_file(block, dct) self.assert_has_manifest_lengths(self.parser.manifest, sources=2)