Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Enhancement or New Feature
body: Add structured content output to get_lineage tool with LineageGraph model
time: 2026-03-13T12:58:06.544448-04:00
50 changes: 47 additions & 3 deletions src/dbt_mcp/discovery/tools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import json
import logging
from dataclasses import dataclass
from typing import Annotated

from mcp.server.fastmcp import FastMCP
from pydantic import Field
from mcp.types import CallToolResult, TextContent
from pydantic import BaseModel, Field

from dbt_mcp.config.config_providers import ConfigProvider, DiscoveryConfig
from dbt_mcp.discovery.client import (
Expand Down Expand Up @@ -233,22 +236,63 @@ async def get_model_performance(
return results


class LineageNode(BaseModel):
unique_id: str
name: str
resource_type: str


class LineageEdge(BaseModel):
source: str
target: str


class LineageGraph(BaseModel):
type: str = "lineage_graph"
root_id: str
nodes: list[LineageNode]
edges: list[LineageEdge]


@dbt_mcp_tool(
description=get_prompt("discovery/get_lineage"),
title="Get Lineage",
read_only_hint=True,
destructive_hint=False,
idempotent_hint=True,
structured_output=True,
)
async def get_lineage(
context: DiscoveryToolContext,
unique_id: str = UNIQUE_ID_REQUIRED_FIELD,
types: list[LineageResourceType] | None = TYPES_FIELD,
depth: int = DEPTH_FIELD,
) -> list[dict]:
return await context.lineage_fetcher.fetch_lineage(
) -> Annotated[CallToolResult, LineageGraph]:
nodes = await context.lineage_fetcher.fetch_lineage(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps fetch_lineage should return the pydantic model? Less dicts and more typed objects in general are better.

unique_id=unique_id, types=types, depth=depth
)
node_ids = {n["uniqueId"] for n in nodes}
graph = LineageGraph(
root_id=unique_id,
nodes=[
LineageNode(
unique_id=n["uniqueId"],
name=n["name"],
resource_type=n["resourceType"],
)
for n in nodes
],
edges=[
LineageEdge(source=parent_id, target=n["uniqueId"])
for n in nodes
for parent_id in n.get("parentIds", [])
if parent_id in node_ids
],
)
return CallToolResult(
content=[TextContent(type="text", text=json.dumps(nodes))],
structuredContent=graph.model_dump(mode="json"),
)


@dbt_mcp_tool(
Expand Down
206 changes: 206 additions & 0 deletions tests/unit/discovery/test_get_lineage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
import json
from unittest.mock import AsyncMock, Mock

import pytest
from mcp.types import CallToolResult, TextContent

from dbt_mcp.discovery.tools import (
DiscoveryToolContext,
get_lineage as get_lineage_tool,
)

# Access the underlying function from the ToolDefinition
get_lineage = get_lineage_tool.fn


@pytest.fixture
def mock_discovery_tool_context():
"""Mock DiscoveryToolContext for testing."""
context = Mock(spec=DiscoveryToolContext)
context.lineage_fetcher = AsyncMock()
return context


SAMPLE_NODES = [
{
"uniqueId": "source.test.raw_customers",
"name": "raw_customers",
"resourceType": "Source",
"parentIds": [],
},
{
"uniqueId": "model.test.customers",
"name": "customers",
"resourceType": "Model",
"parentIds": ["source.test.raw_customers"],
},
{
"uniqueId": "model.test.customer_metrics",
"name": "customer_metrics",
"resourceType": "Model",
"parentIds": ["model.test.customers"],
},
]


async def test_get_lineage_returns_call_tool_result(mock_discovery_tool_context):
"""Test that get_lineage returns a CallToolResult."""
mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = (
SAMPLE_NODES
)

result = await get_lineage(
context=mock_discovery_tool_context,
unique_id="model.test.customers",
types=None,
depth=5,
)

assert isinstance(result, CallToolResult)


async def test_get_lineage_text_content_contains_raw_nodes(mock_discovery_tool_context):
"""Test that the text content contains the raw node data as JSON."""
mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = (
SAMPLE_NODES
)

result = await get_lineage(
context=mock_discovery_tool_context,
unique_id="model.test.customers",
types=None,
depth=5,
)

assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
parsed = json.loads(result.content[0].text)
assert parsed == SAMPLE_NODES


async def test_get_lineage_structured_content_has_correct_graph(
mock_discovery_tool_context,
):
"""Test that structuredContent contains a well-formed LineageGraph."""
mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = (
SAMPLE_NODES
)

result = await get_lineage(
context=mock_discovery_tool_context,
unique_id="model.test.customers",
types=None,
depth=5,
)

sc = result.structuredContent
assert sc["type"] == "lineage_graph"
assert sc["root_id"] == "model.test.customers"
assert len(sc["nodes"]) == 3
assert len(sc["edges"]) == 2

node_ids = {n["unique_id"] for n in sc["nodes"]}
assert node_ids == {
"source.test.raw_customers",
"model.test.customers",
"model.test.customer_metrics",
}

edges = {(e["source"], e["target"]) for e in sc["edges"]}
assert edges == {
("source.test.raw_customers", "model.test.customers"),
("model.test.customers", "model.test.customer_metrics"),
}


async def test_get_lineage_filters_edges_to_known_nodes(mock_discovery_tool_context):
"""Test that edges referencing nodes outside the graph are excluded."""
nodes_with_external_parent = [
{
"uniqueId": "model.test.customers",
"name": "customers",
"resourceType": "Model",
"parentIds": ["source.test.raw_customers", "model.other.unknown"],
},
{
"uniqueId": "source.test.raw_customers",
"name": "raw_customers",
"resourceType": "Source",
"parentIds": [],
},
]
mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = (
nodes_with_external_parent
)

result = await get_lineage(
context=mock_discovery_tool_context,
unique_id="model.test.customers",
types=None,
depth=5,
)

sc = result.structuredContent
# Only the edge from raw_customers should exist; model.other.unknown is not in the graph
assert len(sc["edges"]) == 1
assert sc["edges"][0]["source"] == "source.test.raw_customers"
assert sc["edges"][0]["target"] == "model.test.customers"


async def test_get_lineage_empty_result(mock_discovery_tool_context):
"""Test that an empty fetcher result produces an empty graph."""
mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = []

result = await get_lineage(
context=mock_discovery_tool_context,
unique_id="model.test.nonexistent",
types=None,
depth=5,
)

sc = result.structuredContent
assert sc["nodes"] == []
assert sc["edges"] == []
assert sc["root_id"] == "model.test.nonexistent"


async def test_get_lineage_passes_parameters_to_fetcher(mock_discovery_tool_context):
"""Test that parameters are correctly forwarded to the fetcher."""
mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = []

await get_lineage(
context=mock_discovery_tool_context,
unique_id="model.test.customers",
types=["Model", "Source"],
depth=3,
)

mock_discovery_tool_context.lineage_fetcher.fetch_lineage.assert_called_once_with(
unique_id="model.test.customers",
types=["Model", "Source"],
depth=3,
)


async def test_get_lineage_node_without_parent_ids(mock_discovery_tool_context):
"""Test handling of nodes that lack a parentIds field."""
nodes = [
{
"uniqueId": "model.test.orphan",
"name": "orphan",
"resourceType": "Model",
# no parentIds key
},
]
mock_discovery_tool_context.lineage_fetcher.fetch_lineage.return_value = nodes

result = await get_lineage(
context=mock_discovery_tool_context,
unique_id="model.test.orphan",
types=None,
depth=5,
)

sc = result.structuredContent
assert len(sc["nodes"]) == 1
assert sc["edges"] == []
Loading