Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@

logger = logging.getLogger(__name__)

# SQL dialect names accepted for the INPUT_DIALECT setting. All entries are
# lower-case because the configured value is lower-cased before the membership
# check. The tuple is also interpolated verbatim into the "Invalid dialect"
# error message, so its order is user-visible.
VALID_DIALECTS = (
"athena",
"bigquery",
"databricks",
"postgres",
"redshift",
"snowflake",
"spark",
"trino",
)


@dataclass
class Config:
Expand Down Expand Up @@ -57,7 +68,11 @@ def is_valid_field(cls, field_name: str) -> bool:

dialect = os.getenv("INPUT_DIALECT", None)
if dialect is not None:
env_vars["dialect"] = dialect
if dialect.lower() not in VALID_DIALECTS:
raise ValueError(
f"Invalid dialect: {dialect}. Valid dialects are: {VALID_DIALECTS}"
)
env_vars["dialect"] = dialect.lower()

dry_run = os.getenv("INPUT_DRY_RUN", "false").lower() == "true"
env_vars["dry_run"] = dry_run
Expand Down
3 changes: 3 additions & 0 deletions src/interfaces/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

if TYPE_CHECKING: # pragma: no cover
from src.models.node import Node
from src.config import Config


class LineageServiceProtocol(Protocol):
config: "Config"

def get_node_lineage(self, nodes: List["Node"]) -> Set[str]: ...

def get_column_lineage(self, node_id: str, column_name: str) -> Set[str]: ...
Expand Down
18 changes: 17 additions & 1 deletion src/models/column_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,22 @@ class ColumnTracker:
_tracked_columns: Set[str] = field(default_factory=set)
_impacted_ids: Set[str] = field(default_factory=set)

def _column_name_for_dialect(self, column_name: str) -> str:
    """
    Get the column name in the form expected for the configured dialect.

    Snowflake upper-cases identifiers, so column names are upper-cased
    before being used in a lineage lookup when the dialect is Snowflake.
    For every other dialect (or when no dialect is configured) the name is
    returned unchanged.

    Args:
        column_name: The original column name

    Returns:
        str: The column name for the current dialect
    """
    dialect = self._lineage_service.config.dialect

    # Compare case-insensitively so a config that was not built through the
    # normalising env-var path (e.g. "Snowflake") still gets upper-casing;
    # the isinstance guard keeps an unset (None) dialect from raising.
    if isinstance(dialect, str) and dialect.strip().lower() == "snowflake":
        return column_name.upper()

    # TODO: Any other modifications?
    return column_name

def track_node_columns(self, node: "Node") -> Set[str]:
"""
Track columns for a node and identify impacted downstream nodes.
Expand All @@ -56,7 +72,7 @@ def track_node_columns(self, node: "Node") -> Set[str]:
)
impacted_ids.update(
self._lineage_service.get_column_lineage(
node.unique_id, column_name
node.unique_id, self._column_name_for_dialect(column_name)
)
)
self._tracked_columns.add(node_column)
Expand Down
3 changes: 1 addition & 2 deletions src/services/discovery_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def get_column_lineage(
variables = {
"environmentId": environment_id,
"nodeUniqueId": node_id,
# TODO: This is a hack because Snowflake uppercases everything
"filters": {"columnName": column_name.upper()},
"filters": {"columnName": column_name},
}

lineage = self.config.dbtc_client.metadata.query(
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ def mock_dbt_runner() -> DbtRunnerProtocol:


@pytest.fixture
def mock_lineage_service() -> LineageServiceProtocol:
def mock_lineage_service(mock_config: Config) -> LineageServiceProtocol:
"""Create a mock lineage service."""
service = MagicMock(spec=LineageServiceProtocol)
service.config = mock_config

# Setup default return values
service.get_column_lineage.return_value = set()
Expand Down
29 changes: 22 additions & 7 deletions tests/models/test_column_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ def test_track_node_columns_new_columns(mock_lineage_service, mock_node):
impacted_ids = tracker.track_node_columns(mock_node)

# Verify the results
expected_tracked_columns = {
"model.my_project.test_model.column1",
"model.my_project.test_model.column2",
}
expected_impacted_ids = {
"model.my_project.downstream_model1",
"model.my_project.downstream_model2",
}
expected_tracked_columns = {
"model.my_project.test_model.column1",
"model.my_project.test_model.column2",
}

assert tracker._tracked_columns == expected_tracked_columns
assert tracker._impacted_ids == expected_impacted_ids
Expand All @@ -51,10 +51,10 @@ def test_track_node_columns_new_columns(mock_lineage_service, mock_node):
# Verify lineage service was called correctly
assert mock_lineage_service.get_column_lineage.call_count == 2
mock_lineage_service.get_column_lineage.assert_any_call(
"model.my_project.test_model", "column1"
"model.my_project.test_model", "COLUMN1"
)
mock_lineage_service.get_column_lineage.assert_any_call(
"model.my_project.test_model", "column2"
"model.my_project.test_model", "COLUMN2"
)


Expand Down Expand Up @@ -85,7 +85,7 @@ def test_track_node_columns_already_tracked(mock_lineage_service, mock_node):

# Verify lineage service was called only once (for column2)
mock_lineage_service.get_column_lineage.assert_called_once_with(
"model.my_project.test_model", "column2"
"model.my_project.test_model", "COLUMN2"
)


Expand All @@ -100,3 +100,18 @@ def test_impacted_ids_property(mock_lineage_service):
assert tracker.impacted_ids == expected_ids
# Ensure we get a copy of the set, not the original
assert tracker.impacted_ids is not tracker._impacted_ids


def test_column_name_for_dialect(mock_lineage_service):
    """Column names are upper-cased for Snowflake and untouched otherwise."""
    tracker = ColumnTracker(mock_lineage_service)

    # Each row: (configured dialect, raw column name, expected result).
    scenarios = (
        # Snowflake dialect (should uppercase)
        ("snowflake", "test_column", "TEST_COLUMN"),
        ("snowflake", "MixedCase", "MIXEDCASE"),
        # Other dialect (should return unchanged)
        ("bigquery", "test_column", "test_column"),
        ("bigquery", "MixedCase", "MixedCase"),
    )

    for dialect, raw_name, expected in scenarios:
        mock_lineage_service.config.dialect = dialect
        assert tracker._column_name_for_dialect(raw_name) == expected
20 changes: 20 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,23 @@ def test_set_fields_from_dbtc_client_invalid_response(mock_config):
mock_config._set_fields_from_dbtc_client()

assert "An error occurred retrieving your job's data" in str(exc_info.value)


def test_config_invalid_dialect():
    """Config.from_env raises ValueError when INPUT_DIALECT is not recognised."""
    fake_environment = {
        "INPUT_DBT_CLOUD_HOST": "cloud.getdbt.com",
        "INPUT_DBT_CLOUD_SERVICE_TOKEN": "test_token",
        "INPUT_DBT_CLOUD_TOKEN_NAME": "cloud-cli-6d65",
        "INPUT_DBT_CLOUD_TOKEN_VALUE": "test_token_value",
        "INPUT_DBT_CLOUD_ACCOUNT_ID": "43786",
        "INPUT_DBT_CLOUD_JOB_ID": "567183",
        "INPUT_DIALECT": "invalid_dialect",
    }

    # Replace the whole process environment for the duration of the call so
    # only the variables above are visible to Config.from_env().
    environment_patch = patch.dict("os.environ", fake_environment, clear=True)
    with environment_patch, pytest.raises(ValueError) as exc_info:
        Config.from_env()

    message = str(exc_info.value)
    assert "Invalid dialect: invalid_dialect" in message
    assert "Valid dialects are:" in message
Loading