Skip to content

Commit 0b61fa2

Browse files
committed
feat: add dataset schema support
- Add a tool to retrieve the schema of a dataset - Modify get_entity so that when querying a dataset, it also returns the schema version
1 parent 0beb77a commit 0b61fa2

File tree

4 files changed

+178
-2
lines changed

4 files changed

+178
-2
lines changed

scripts/test_main_tools.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ async def main(urn_or_query: Optional[str]) -> None:
9494
indent=2,
9595
)
9696
)
97+
_divider()
98+
print(f"Getting versioned_dataset: {urn}")
99+
print(
100+
json.dumps(
101+
await _call_tool(
102+
mcp_client,
103+
"get_versioned_dataset",
104+
dataset_urn=urn_or_query,
105+
semantic_version="0.0.0",
106+
),
107+
indent=2,
108+
)
109+
)
97110

98111

99112
if __name__ == "__main__":

src/mcp_server_datahub/gql/entity_details.gql

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,3 +1302,59 @@ query GetEntityLineage($input: SearchAcrossLineageInput!) {
13021302
}
13031303
}
13041304
}
1305+
1306+
1307+
query getSchemaVersionList($input: GetSchemaVersionListInput!) {
1308+
getSchemaVersionList(input: $input) {
1309+
latestVersion {
1310+
semanticVersion
1311+
versionStamp
1312+
__typename
1313+
}
1314+
semanticVersionList {
1315+
semanticVersion
1316+
versionStamp
1317+
__typename
1318+
}
1319+
__typename
1320+
}
1321+
}
1322+
1323+
1324+
query getVersionedDataset($urn: String!, $versionStamp: String) {
1325+
versionedDataset(urn: $urn, versionStamp: $versionStamp) {
1326+
schema {
1327+
fields {
1328+
fieldPath
1329+
jsonPath
1330+
nullable
1331+
description
1332+
type
1333+
nativeDataType
1334+
recursive
1335+
isPartOfKey
1336+
isPartitioningKey
1337+
__typename
1338+
}
1339+
lastObserved
1340+
__typename
1341+
}
1342+
editableSchemaMetadata {
1343+
editableSchemaFieldInfo {
1344+
fieldPath
1345+
description
1346+
globalTags {
1347+
...globalTagsFields
1348+
__typename
1349+
}
1350+
glossaryTerms {
1351+
...glossaryTerms
1352+
__typename
1353+
}
1354+
__typename
1355+
}
1356+
__typename
1357+
}
1358+
__typename
1359+
}
1360+
}

src/mcp_server_datahub/mcp_server.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from datahub.sdk.search_filters import Filter, FilterDsl, load_filters
2626
from datahub.utilities.ordered_set import OrderedSet
2727
from fastmcp import FastMCP
28-
from pydantic import BaseModel
28+
from pydantic import BaseModel, Field
29+
from functools import lru_cache
2930

3031
_P = ParamSpec("_P")
3132
_R = TypeVar("_R")
@@ -173,7 +174,80 @@ def _clean_get_entity_response(raw_response: dict) -> dict:
173174
return response
174175

175176

176-
@mcp.tool(description="Get an entity by its DataHub URN.")
177+
class SemanticVersionStruct(BaseModel):
178+
semantic_version: str = Field(alias="semanticVersion")
179+
version_stamp: str = Field(alias="versionStamp")
180+
181+
182+
class SchemaVersionList(BaseModel):
183+
latest_version: SemanticVersionStruct
184+
versions: list[SemanticVersionStruct]
185+
186+
187+
class DatasetSchemaAPI:
188+
def __init__(self, graph: DataHubGraph) -> None:
189+
self._graph = graph
190+
191+
def get_schema_version_list(self, dataset_urn: str) -> SchemaVersionList | None:
192+
variables = {
193+
"input": {
194+
"datasetUrn": dataset_urn,
195+
}
196+
}
197+
resp = _execute_graphql(
198+
self._graph,
199+
query=entity_details_fragment_gql,
200+
variables=variables,
201+
operation_name="getSchemaVersionList",
202+
)
203+
204+
if not (raw_schema_versions := resp.get("getSchemaVersionList")):
205+
return None
206+
207+
return SchemaVersionList(
208+
latest_version=SemanticVersionStruct.model_validate(
209+
raw_schema_versions.get("latestVersion", {})
210+
),
211+
versions=[
212+
SemanticVersionStruct.model_validate(structs)
213+
for structs in raw_schema_versions.get("semanticVersionList", [])
214+
],
215+
)
216+
217+
def get_versioned_dataset(
218+
self, dataset_urn: str, semantic_version: str
219+
) -> dict[str, Any]:
220+
variables = {
221+
"urn": dataset_urn,
222+
"versionStamp": self._get_version_timestamp(dataset_urn, semantic_version),
223+
}
224+
resp = _execute_graphql(
225+
self._graph,
226+
query=entity_details_fragment_gql,
227+
variables=variables,
228+
operation_name="getVersionedDataset",
229+
)
230+
return _clean_gql_response(resp.get("versionedDataset", {}))
231+
232+
def _get_version_timestamp(self, dataset_urn: str, semantic_version: str):
233+
if not (schema_version_list := self.get_schema_version_list(dataset_urn)):
234+
raise ValueError(f"No schema_version_list found for dataset {dataset_urn}")
235+
236+
version_stamp_mapping = {
237+
struct.semantic_version: struct.version_stamp
238+
for struct in schema_version_list.versions
239+
}
240+
241+
if not (version_stamp := version_stamp_mapping.get(semantic_version)):
242+
raise ValueError(
243+
f"Version '{semantic_version}' not found for dataset '{dataset_urn}'"
244+
)
245+
return version_stamp
246+
247+
248+
@mcp.tool(
249+
description="Get an entity by its DataHub URN. This also provide schema_version_list(latest version, all versions) if available."
250+
)
177251
@async_background
178252
def get_entity(urn: str) -> dict:
179253
client = get_datahub_client()
@@ -193,6 +267,22 @@ def get_entity(urn: str) -> dict:
193267

194268
_inject_urls_for_urns(client._graph, result, [""])
195269

270+
if result.get("urn", "").startswith("urn:li:dataset:"):
271+
schema_api = DatasetSchemaAPI(client._graph)
272+
273+
if schema_version_list := schema_api.get_schema_version_list(urn):
274+
sorted_versions = sorted(
275+
[v.semantic_version for v in schema_version_list.versions]
276+
)
277+
latest_semantic_version = (
278+
schema_version_list.latest_version.semantic_version
279+
)
280+
281+
result["schemaVersionList"] = {
282+
"latestVersion": latest_semantic_version,
283+
"versions": sorted_versions,
284+
}
285+
196286
return _clean_get_entity_response(result)
197287

198288

@@ -441,3 +531,13 @@ def get_lineage(
441531
lineage = lineage_api.get_lineage(asset_lineage_directive)
442532
_inject_urls_for_urns(client._graph, lineage, ["*.searchResults[].entity"])
443533
return lineage
534+
535+
536+
@mcp.tool(description="Get schema from a dataset by its URN and version.")
537+
@async_background
538+
@lru_cache(maxsize=20)
539+
def get_versioned_dataset(dataset_urn: str, semantic_version: str) -> dict:
540+
client = get_datahub_client()
541+
schema_api = DatasetSchemaAPI(client._graph)
542+
543+
return schema_api.get_versioned_dataset(dataset_urn, semantic_version)

tests/test_mcp_server.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
get_dataset_queries,
1111
get_entity,
1212
get_lineage,
13+
get_versioned_dataset,
1314
mcp,
1415
search,
1516
with_datahub_client,
@@ -131,3 +132,9 @@ async def test_get_dataset_queries() -> None:
131132
assert res is not None
132133
assert res.get("queries") is not None
133134
assert len(res.get("queries")) > 0
135+
136+
137+
@pytest.mark.anyio
138+
async def test_get_versioned_datset() -> None:
139+
res = await get_versioned_dataset.fn(_test_urn, "0.0.0")
140+
assert res is not None

0 commit comments

Comments
 (0)