Skip to content

Commit 70c2b73

Browse files
committed
Extract required properties from existing constraints
1 parent 4f049d8 commit 70c2b73

File tree

1 file changed

+75
-4
lines changed
  • src/neo4j_graphrag/experimental/components

1 file changed

+75
-4
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -631,13 +631,78 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
631631

632632

633633
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
634-
"""A class to build a GraphSchema object from an existing graph."""
634+
"""A class to build a GraphSchema object from an existing graph.
635635
636-
def __init__(self, driver: neo4j.Driver) -> None:
636+
Uses the get_structured_schema function to extract existing node labels,
637+
relationship types, properties and existence constraints.
638+
639+
By default, the built schema does not allow any additional item (property,
640+
node label, relationship type or pattern).
641+
642+
Args:
643+
driver (neo4j.Driver): connection to the neo4j database.
644+
additional_properties (bool, default False): see GraphSchema
645+
additional_node_types (bool, default False): see GraphSchema
646+
additional_relationship_types (bool, default False): see GraphSchema
647+
additional_patterns (bool, default False): see GraphSchema
648+
neo4j_database (Optional[str]): name of the neo4j database to use
649+
"""
650+
651+
def __init__(
652+
self,
653+
driver: neo4j.Driver,
654+
additional_properties: bool = False,
655+
additional_node_types: bool = False,
656+
additional_relationship_types: bool = False,
657+
additional_patterns: bool = False,
658+
neo4j_database: Optional[str] = None,
659+
) -> None:
637660
self.driver = driver
661+
self.database = neo4j_database
662+
663+
self.additional_properties = additional_properties
664+
self.additional_node_types = additional_node_types
665+
self.additional_relationship_types = additional_relationship_types
666+
self.additional_patterns = additional_patterns
667+
668+
@staticmethod
669+
def _extract_required_properties(
670+
structured_schema: dict[str, Any],
671+
) -> list[tuple[str, str]]:
672+
"""Extract a list of (node label (or rel type), property name) for which
673+
an "EXISTENCE" or "KEY" constraint is defined in the DB.
674+
675+
Args:
676+
677+
structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.
678+
679+
Returns:
680+
681+
list of tuples of (node label (or rel type), property name)
682+
683+
"""
684+
schema_metadata = structured_schema.get("metadata", {})
685+
existence_constraint = [] # list of (node label, property name)
686+
for constraint in schema_metadata.get("constraints", []):
687+
if constraint["type"] in (
688+
"NODE_PROPERTY_EXISTENCE",
689+
"NODE_KEY",
690+
"RELATIONSHIP_PROPERTY_EXISTENCE",
691+
"RELATIONSHIP_KEY",
692+
):
693+
properties = constraint["properties"]
694+
labels = constraint["labelsOrTypes"]
695+
# note: existence constraints only apply to a single property
696+
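# and a single label. NOTE(review): KEY constraints may cover several properties, but only the first one is kept here - confirm this is intended.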
697+
prop = properties[0]
698+
lab = labels[0]
699+
existence_constraint.append((lab, prop))
700+
return existence_constraint
701+
702+
async def run(self) -> GraphSchema:
703+
structured_schema = get_structured_schema(self.driver, database=self.database)
704+
existence_constraint = self._extract_required_properties(structured_schema)
638705

639-
async def run(self, **kwargs: Any) -> GraphSchema:
640-
structured_schema = get_structured_schema(self.driver)
641706
node_labels = set(structured_schema["node_props"].keys())
642707
node_types = [
643708
{
@@ -646,9 +711,11 @@ async def run(self, **kwargs: Any) -> GraphSchema:
646711
{
647712
"name": p["property"],
648713
"type": p["type"],
714+
"required": (key, p["property"]) in existence_constraint,
649715
}
650716
for p in properties
651717
],
718+
"additional_properties": self.additional_properties,
652719
}
653720
for key, properties in structured_schema["node_props"].items()
654721
]
@@ -660,6 +727,7 @@ async def run(self, **kwargs: Any) -> GraphSchema:
660727
{
661728
"name": p["property"],
662729
"type": p["type"],
730+
"required": (key, p["property"]) in existence_constraint,
663731
}
664732
for p in properties
665733
],
@@ -698,5 +766,8 @@ async def run(self, **kwargs: Any) -> GraphSchema:
698766
"node_types": node_types,
699767
"relationship_types": relationship_types,
700768
"patterns": patterns,
769+
"additional_node_types": self.additional_node_types,
770+
"additional_relationship_types": self.additional_relationship_types,
771+
"additional_patterns": self.additional_patterns,
701772
}
702773
)

0 commit comments

Comments
 (0)