@@ -631,13 +631,78 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
631
631
632
632
633
633
class SchemaFromExistingGraphExtractor (BaseSchemaBuilder ):
634
- """A class to build a GraphSchema object from an existing graph."""
634
+ """A class to build a GraphSchema object from an existing graph.
635
635
636
- def __init__ (self , driver : neo4j .Driver ) -> None :
636
+ Uses the get_structured_schema function to extract existing node labels,
637
+ relationship types, properties and existence constraints.
638
+
639
+ By default, the built schema does not allow any additional item (property,
640
+ node label, relationship type or pattern).
641
+
642
+ Args:
643
+ driver (neo4j.Driver): connection to the neo4j database.
644
+ additional_properties (bool, default False): see GraphSchema
645
+ additional_node_types (bool, default False): see GraphSchema
646
+ additional_relationship_types (bool, default False): see GraphSchema:
647
+ additional_patterns (bool, default False): see GraphSchema:
648
+ neo4j_database (Optional | str): name of the neo4j database to use
649
+ """
650
+
651
+ def __init__ (
652
+ self ,
653
+ driver : neo4j .Driver ,
654
+ additional_properties : bool = False ,
655
+ additional_node_types : bool = False ,
656
+ additional_relationship_types : bool = False ,
657
+ additional_patterns : bool = False ,
658
+ neo4j_database : Optional [str ] = None ,
659
+ ) -> None :
637
660
self .driver = driver
661
+ self .database = neo4j_database
662
+
663
+ self .additional_properties = additional_properties
664
+ self .additional_node_types = additional_node_types
665
+ self .additional_relationship_types = additional_relationship_types
666
+ self .additional_patterns = additional_patterns
667
+
668
+ @staticmethod
669
+ def _extract_required_properties (
670
+ structured_schema : dict [str , Any ],
671
+ ) -> list [tuple [str , str ]]:
672
+ """Extract a list of (node label (or rel type), property name) for which
673
+ an "EXISTENCE" or "KEY" constraint is defined in the DB.
674
+
675
+ Args:
676
+
677
+ structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.
678
+
679
+ Returns:
680
+
681
+ list of tuples of (node label (or rel type), property name)
682
+
683
+ """
684
+ schema_metadata = structured_schema .get ("metadata" , {})
685
+ existence_constraint = [] # list of (node label, property name)
686
+ for constraint in schema_metadata .get ("constraints" , []):
687
+ if constraint ["type" ] in (
688
+ "NODE_PROPERTY_EXISTENCE" ,
689
+ "NODE_KEY" ,
690
+ "RELATIONSHIP_PROPERTY_EXISTENCE" ,
691
+ "RELATIONSHIP_KEY" ,
692
+ ):
693
+ properties = constraint ["properties" ]
694
+ labels = constraint ["labelsOrTypes" ]
695
+ # note: existence constraint only apply to a single property
696
+ # and a single label
697
+ prop = properties [0 ]
698
+ lab = labels [0 ]
699
+ existence_constraint .append ((lab , prop ))
700
+ return existence_constraint
701
+
702
+ async def run (self ) -> GraphSchema :
703
+ structured_schema = get_structured_schema (self .driver , database = self .database )
704
+ existence_constraint = self ._extract_required_properties (structured_schema )
638
705
639
- async def run (self , ** kwargs : Any ) -> GraphSchema :
640
- structured_schema = get_structured_schema (self .driver )
641
706
node_labels = set (structured_schema ["node_props" ].keys ())
642
707
node_types = [
643
708
{
@@ -646,9 +711,11 @@ async def run(self, **kwargs: Any) -> GraphSchema:
646
711
{
647
712
"name" : p ["property" ],
648
713
"type" : p ["type" ],
714
+ "required" : (key , p ["property" ]) in existence_constraint ,
649
715
}
650
716
for p in properties
651
717
],
718
+ "additional_properties" : self .additional_properties ,
652
719
}
653
720
for key , properties in structured_schema ["node_props" ].items ()
654
721
]
@@ -660,6 +727,7 @@ async def run(self, **kwargs: Any) -> GraphSchema:
660
727
{
661
728
"name" : p ["property" ],
662
729
"type" : p ["type" ],
730
+ "required" : (key , p ["property" ]) in existence_constraint ,
663
731
}
664
732
for p in properties
665
733
],
@@ -698,5 +766,8 @@ async def run(self, **kwargs: Any) -> GraphSchema:
698
766
"node_types" : node_types ,
699
767
"relationship_types" : relationship_types ,
700
768
"patterns" : patterns ,
769
+ "additional_node_types" : self .additional_node_types ,
770
+ "additional_relationship_types" : self .additional_relationship_types ,
771
+ "additional_patterns" : self .additional_patterns ,
701
772
}
702
773
)
0 commit comments