diff --git a/.changes/unreleased/Enhancement or New Feature-20260313-125806.yaml b/.changes/unreleased/Enhancement or New Feature-20260313-125806.yaml new file mode 100644 index 00000000..ebff38b9 --- /dev/null +++ b/.changes/unreleased/Enhancement or New Feature-20260313-125806.yaml @@ -0,0 +1,3 @@ +kind: Enhancement or New Feature +body: Add structured content output to get_lineage tool with LineageGraph model +time: 2026-03-13T12:58:06.544448-04:00 diff --git a/src/dbt_mcp/discovery/tools.py b/src/dbt_mcp/discovery/tools.py index 6d8400d5..0525234c 100644 --- a/src/dbt_mcp/discovery/tools.py +++ b/src/dbt_mcp/discovery/tools.py @@ -1,8 +1,11 @@ +import json import logging from dataclasses import dataclass +from typing import Annotated from mcp.server.fastmcp import FastMCP -from pydantic import Field +from mcp.types import CallToolResult, TextContent +from pydantic import BaseModel, Field from dbt_mcp.config.config_providers import ConfigProvider, DiscoveryConfig from dbt_mcp.discovery.client import ( @@ -233,22 +236,63 @@ async def get_model_performance( return results +class LineageNode(BaseModel): + unique_id: str + name: str + resource_type: str + + +class LineageEdge(BaseModel): + source: str + target: str + + +class LineageGraph(BaseModel): + type: str = "lineage_graph" + root_id: str + nodes: list[LineageNode] + edges: list[LineageEdge] + + @dbt_mcp_tool( description=get_prompt("discovery/get_lineage"), title="Get Lineage", read_only_hint=True, destructive_hint=False, idempotent_hint=True, + structured_output=True, ) async def get_lineage( context: DiscoveryToolContext, unique_id: str = UNIQUE_ID_REQUIRED_FIELD, types: list[LineageResourceType] | None = TYPES_FIELD, depth: int = DEPTH_FIELD, -) -> list[dict]: - return await context.lineage_fetcher.fetch_lineage( +) -> Annotated[CallToolResult, LineageGraph]: + nodes = await context.lineage_fetcher.fetch_lineage( unique_id=unique_id, types=types, depth=depth ) + node_ids = {n["uniqueId"] for n in nodes} + graph = LineageGraph( + root_id=unique_id, + nodes=[ + LineageNode( + unique_id=n["uniqueId"], + name=n["name"], + resource_type=n["resourceType"], + ) + for n in nodes + ], + edges=[ + LineageEdge(source=parent_id, target=n["uniqueId"]) + for n in nodes + for parent_id in n.get("parentIds", []) + if parent_id in node_ids + ], + ) + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(nodes))], + structuredContent=graph.model_dump(mode="json"), + ) @dbt_mcp_tool( diff --git a/tests/unit/discovery/test_get_lineage.py b/tests/unit/discovery/test_get_lineage.py new file mode 100644 index 00000000..05d615e4 --- /dev/null +++ b/tests/unit/discovery/test_get_lineage.py @@ -0,0 +1,206 @@ +import json +from unittest.mock import AsyncMock, Mock + +import pytest +from mcp.types import CallToolResult, TextContent + +from dbt_mcp.discovery.tools import ( + DiscoveryToolContext, + get_lineage as get_lineage_tool, +) + +# Access the underlying function from the ToolDefinition +get_lineage = get_lineage_tool.fn + + +@pytest.fixture +def mock_discovery_tool_context(): + """Mock DiscoveryToolContext for testing.""" + context = Mock(spec=DiscoveryToolContext) + context.lineage_fetcher = AsyncMock() + return context + + +SAMPLE_NODES = [ + { + "uniqueId": "source.test.raw_customers", + "name": "raw_customers", + "resourceType": "Source", + "parentIds": [], + }, + { + "uniqueId": "model.test.customers", + "name": "customers", + "resourceType": "Model", + "parentIds": ["source.test.raw_customers"], + }, + { + "uniqueId": "model.test.customer_metrics", + "name": "customer_metrics", + "resourceType": "Model", + "parentIds": ["model.test.customers"], + }, +] + + +async def test_get_lineage_returns_call_tool_result(mock_discovery_tool_context): + """Test that get_lineage returns a CallToolResult.""" + mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = ( + SAMPLE_NODES + ) + + result = await get_lineage( + context=mock_discovery_tool_context, + unique_id="model.test.customers", + types=None, + depth=5, + ) + + assert isinstance(result, CallToolResult) + + +async def test_get_lineage_text_content_contains_raw_nodes(mock_discovery_tool_context): + """Test that the text content contains the raw node data as JSON.""" + mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = ( + SAMPLE_NODES + ) + + result = await get_lineage( + context=mock_discovery_tool_context, + unique_id="model.test.customers", + types=None, + depth=5, + ) + + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + parsed = json.loads(result.content[0].text) + assert parsed == SAMPLE_NODES + + +async def test_get_lineage_structured_content_has_correct_graph( + mock_discovery_tool_context, +): + """Test that structuredContent contains a well-formed LineageGraph.""" + mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = ( + SAMPLE_NODES + ) + + result = await get_lineage( + context=mock_discovery_tool_context, + unique_id="model.test.customers", + types=None, + depth=5, + ) + + sc = result.structuredContent + assert sc["type"] == "lineage_graph" + assert sc["root_id"] == "model.test.customers" + assert len(sc["nodes"]) == 3 + assert len(sc["edges"]) == 2 + + node_ids = {n["unique_id"] for n in sc["nodes"]} + assert node_ids == { + "source.test.raw_customers", + "model.test.customers", + "model.test.customer_metrics", + } + + edges = {(e["source"], e["target"]) for e in sc["edges"]} + assert edges == { + ("source.test.raw_customers", "model.test.customers"), + ("model.test.customers", "model.test.customer_metrics"), + } + + +async def test_get_lineage_filters_edges_to_known_nodes(mock_discovery_tool_context): + """Test that edges referencing nodes outside the graph are excluded.""" + nodes_with_external_parent = [ + { + "uniqueId": "model.test.customers", + "name": "customers", + "resourceType": "Model", + "parentIds": ["source.test.raw_customers", "model.other.unknown"], + }, + { + "uniqueId": "source.test.raw_customers", + "name": "raw_customers", + "resourceType": "Source", + "parentIds": [], + }, + ] + mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = ( + nodes_with_external_parent + ) + + result = await get_lineage( + context=mock_discovery_tool_context, + unique_id="model.test.customers", + types=None, + depth=5, + ) + + sc = result.structuredContent + # Only the edge from raw_customers should exist; model.other.unknown is not in the graph + assert len(sc["edges"]) == 1 + assert sc["edges"][0]["source"] == "source.test.raw_customers" + assert sc["edges"][0]["target"] == "model.test.customers" + + +async def test_get_lineage_empty_result(mock_discovery_tool_context): + """Test that an empty fetcher result produces an empty graph.""" + mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = [] + + result = await get_lineage( + context=mock_discovery_tool_context, + unique_id="model.test.nonexistent", + types=None, + depth=5, + ) + + sc = result.structuredContent + assert sc["nodes"] == [] + assert sc["edges"] == [] + assert sc["root_id"] == "model.test.nonexistent" + + +async def test_get_lineage_passes_parameters_to_fetcher(mock_discovery_tool_context): + """Test that parameters are correctly forwarded to the fetcher.""" + mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = [] + + await get_lineage( + context=mock_discovery_tool_context, + unique_id="model.test.customers", + types=["Model", "Source"], + depth=3, + ) + + mock_discovery_tool_context.lineage_fetcher.fetch_lineage.assert_called_once_with( + unique_id="model.test.customers", + types=["Model", "Source"], + depth=3, + ) + + +async def test_get_lineage_node_without_parent_ids(mock_discovery_tool_context): + """Test handling of nodes that lack a parentIds field.""" + nodes = [ + { + "uniqueId": "model.test.orphan", + "name": "orphan", + "resourceType": "Model", + # no parentIds key + }, + ] + mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = nodes + + result = await get_lineage( + context=mock_discovery_tool_context, + unique_id="model.test.orphan", + types=None, + depth=5, + ) + + sc = result.structuredContent + assert len(sc["nodes"]) == 1 + assert sc["edges"] == []