|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
17 | 17 | import json
|
| 18 | + |
| 19 | +import neo4j |
18 | 20 | import logging
|
19 | 21 | import warnings
|
20 | 22 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence, Callable
|
|
44 | 46 | from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate
|
45 | 47 | from neo4j_graphrag.llm import LLMInterface
|
46 | 48 | from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat
|
| 49 | +from neo4j_graphrag.schema import get_structured_schema |
47 | 50 |
|
48 | 51 |
|
49 | 52 | class PropertyType(BaseModel):
|
@@ -294,7 +297,12 @@ def from_file(
|
294 | 297 | raise SchemaValidationError(str(e)) from e
|
295 | 298 |
|
296 | 299 |
|
297 |
| -class SchemaBuilder(Component): |
class BaseSchemaBuilder(Component):
    """Common base for components whose ``run`` returns a ``GraphSchema``.

    The base implementation is only a placeholder: calling ``run`` on it
    always raises ``NotImplementedError``. Subclasses provide the real
    ``run`` implementation.
    """

    async def run(self, *args: Any, **kwargs: Any) -> GraphSchema:
        """Placeholder entry point; must be overridden by subclasses.

        Args:
            *args: Accepted for signature compatibility; unused here.
            **kwargs: Accepted for signature compatibility; unused here.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError
| 303 | + |
| 304 | + |
| 305 | +class SchemaBuilder(BaseSchemaBuilder): |
298 | 306 | """
|
299 | 307 | A builder class for constructing GraphSchema objects from given entities,
|
300 | 308 | relations, and their interrelationships defined in a potential schema.
|
@@ -412,7 +420,7 @@ async def run(
|
412 | 420 | )
|
413 | 421 |
|
414 | 422 |
|
415 |
| -class SchemaFromTextExtractor(Component): |
| 423 | +class SchemaFromTextExtractor(BaseSchemaBuilder): |
416 | 424 | """
|
417 | 425 | A component for constructing GraphSchema objects from the output of an LLM after
|
418 | 426 | automatic schema extraction from text.
|
@@ -620,3 +628,75 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
|
620 | 628 | "patterns": extracted_patterns,
|
621 | 629 | }
|
622 | 630 | )
|
| 631 | + |
| 632 | + |
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
    """Build a ``GraphSchema`` describing a graph that already exists.

    The schema is derived from the mapping returned by
    ``get_structured_schema(driver)``, from which three entries are read:

    * ``"node_props"``: node label -> list of property descriptions,
    * ``"rel_props"``: relationship type -> list of property descriptions,
    * ``"relationships"``: entries with ``"start"``, ``"type"`` and ``"end"``.

    Each property description contributes its ``"property"`` and ``"type"``
    values to the resulting schema.
    """

    def __init__(self, driver: neo4j.Driver) -> None:
        """Store the driver used to inspect the graph.

        Args:
            driver: Driver handed to ``get_structured_schema`` on each run.
        """
        self.driver = driver

    async def run(self, **kwargs: Any) -> GraphSchema:
        """Inspect the graph and return the matching ``GraphSchema``.

        Args:
            **kwargs: Accepted but not used by this implementation.

        Returns:
            The ``GraphSchema`` validated from the collected node types,
            relationship types and patterns.
        """

        def as_type_entries(
            props_by_label: Dict[str, Any],
        ) -> List[Dict[str, Any]]:
            # One entry per label, keeping only the name and type of each property.
            entries: List[Dict[str, Any]] = []
            for label, props in props_by_label.items():
                converted = [
                    {"name": prop["property"], "type": prop["type"]}
                    for prop in props
                ]
                entries.append({"label": label, "properties": converted})
            return entries

        # NOTE(review): invoked without ``await`` inside a coroutine, so this is
        # presumably a blocking call - confirm against get_structured_schema.
        schema = get_structured_schema(self.driver)

        node_props = schema["node_props"]
        known_nodes = set(node_props.keys())
        node_types = as_type_entries(node_props)

        rel_props = schema["rel_props"]
        known_rels = set(rel_props.keys())
        relationship_types = as_type_entries(rel_props)

        patterns = []
        for entry in schema["relationships"]:
            patterns.append((entry["start"], entry["type"], entry["end"]))

        # Labels and relationship types that have no properties only show up
        # in the patterns; register them with a bare, property-less entry.
        for start, rel_type, end in patterns:
            for endpoint in (start, end):
                if endpoint in known_nodes:
                    continue
                known_nodes.add(endpoint)
                node_types.append({"label": endpoint})
            if rel_type in known_rels:
                continue
            known_rels.add(rel_type)
            relationship_types.append({"label": rel_type})

        payload = {
            "node_types": node_types,
            "relationship_types": relationship_types,
            "patterns": patterns,
        }
        return GraphSchema.model_validate(payload)
0 commit comments