Skip to content

Commit 4f049d8

Browse files
committed
Add SchemaFromExistingGraphExtractor component
Parses the result from get_structured_schema and returns a GraphSchema object
1 parent fa7cc94 commit 4f049d8

File tree

2 files changed

+117
-2
lines changed

2 files changed

+117
-2
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
"""This example demonstrates how to use the SchemaFromExistingGraphExtractor component
2+
to automatically extract a schema from an existing Neo4j database.
3+
"""
4+
5+
import asyncio
6+
7+
import neo4j
8+
9+
from neo4j_graphrag.experimental.components.schema import (
10+
SchemaFromExistingGraphExtractor,
11+
GraphSchema,
12+
)
13+
14+
15+
URI = "neo4j+s://demo.neo4jlabs.com"
16+
AUTH = ("recommendations", "recommendations")
17+
DATABASE = "recommendations"
18+
INDEX = "moviePlotsEmbedding"
19+
20+
21+
async def main() -> None:
22+
"""Run the example."""
23+
24+
with neo4j.GraphDatabase.driver(
25+
URI,
26+
auth=AUTH,
27+
) as driver:
28+
extractor = SchemaFromExistingGraphExtractor(driver)
29+
schema: GraphSchema = await extractor.run()
30+
# schema.store_as_json("my_schema.json")
31+
print(schema)
32+
33+
34+
if __name__ == "__main__":
35+
asyncio.run(main())

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from __future__ import annotations
1616

1717
import json
18+
19+
import neo4j
1820
import logging
1921
import warnings
2022
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence, Callable
@@ -44,6 +46,7 @@
4446
from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
4547
from neo4j_graphrag.llm import LLMInterface
4648
from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
49+
from neo4j_graphrag.schema import get_structured_schema
4750

4851

4952
class PropertyType(BaseModel):
@@ -294,7 +297,12 @@ def from_file(
294297
raise SchemaValidationError(str(e)) from e
295298

296299

297-
class SchemaBuilder(Component):
300+
class BaseSchemaBuilder(Component):
301+
async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
302+
raise NotImplementedError()
303+
304+
305+
class SchemaBuilder(BaseSchemaBuilder):
298306
"""
299307
A builder class for constructing GraphSchema objects from given entities,
300308
relations, and their interrelationships defined in a potential schema.
@@ -412,7 +420,7 @@ async def run(
412420
)
413421

414422

415-
class SchemaFromTextExtractor(Component):
423+
class SchemaFromTextExtractor(BaseSchemaBuilder):
416424
"""
417425
A component for constructing GraphSchema objects from the output of an LLM after
418426
automatic schema extraction from text.
@@ -620,3 +628,75 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
620628
"patterns": extracted_patterns,
621629
}
622630
)
631+
632+
633+
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
634+
"""A class to build a GraphSchema object from an existing graph."""
635+
636+
def __init__(self, driver: neo4j.Driver) -> None:
637+
self.driver = driver
638+
639+
async def run(self, **kwargs: Any) -> GraphSchema:
640+
structured_schema = get_structured_schema(self.driver)
641+
node_labels = set(structured_schema["node_props"].keys())
642+
node_types = [
643+
{
644+
"label": key,
645+
"properties": [
646+
{
647+
"name": p["property"],
648+
"type": p["type"],
649+
}
650+
for p in properties
651+
],
652+
}
653+
for key, properties in structured_schema["node_props"].items()
654+
]
655+
rel_labels = set(structured_schema["rel_props"].keys())
656+
relationship_types = [
657+
{
658+
"label": key,
659+
"properties": [
660+
{
661+
"name": p["property"],
662+
"type": p["type"],
663+
}
664+
for p in properties
665+
],
666+
}
667+
for key, properties in structured_schema["rel_props"].items()
668+
]
669+
patterns = [
670+
(s["start"], s["type"], s["end"])
671+
for s in structured_schema["relationships"]
672+
]
673+
# deal with nodes and relationships without properties
674+
for source, rel, target in patterns:
675+
if source not in node_labels:
676+
node_labels.add(source)
677+
node_types.append(
678+
{
679+
"label": source,
680+
}
681+
)
682+
if target not in node_labels:
683+
node_labels.add(target)
684+
node_types.append(
685+
{
686+
"label": target,
687+
}
688+
)
689+
if rel not in rel_labels:
690+
rel_labels.add(rel)
691+
relationship_types.append(
692+
{
693+
"label": rel,
694+
}
695+
)
696+
return GraphSchema.model_validate(
697+
{
698+
"node_types": node_types,
699+
"relationship_types": relationship_types,
700+
"patterns": patterns,
701+
}
702+
)

0 commit comments

Comments
 (0)