1515from  __future__ import  annotations 
1616
1717import  asyncio 
18+ import  inspect 
1819import  logging 
1920from  abc  import  abstractmethod 
2021from  typing  import  Any , Dict , Literal , Optional , Tuple 
2829    Neo4jRelationship ,
2930)
3031from  neo4j_graphrag .experimental .pipeline .component  import  Component , DataModel 
31- from  neo4j_graphrag .indexes  import  (
32-     async_upsert_vector ,
33-     async_upsert_vector_on_relationship ,
34-     upsert_vector ,
35-     upsert_vector_on_relationship ,
36- )
3732from  neo4j_graphrag .neo4j_queries  import  UPSERT_NODE_QUERY , UPSERT_RELATIONSHIP_QUERY 
3833
3934logger  =  logging .getLogger (__name__ )
@@ -102,15 +97,26 @@ def __init__(
10297        self .neo4j_database  =  neo4j_database 
10398        self .max_concurrency  =  max_concurrency 
10499
100+     def  _db_setup (self ) ->  None :
101+         # create index on __Entity__.id 
102+         self .driver .execute_query (
103+             "CREATE INDEX __entity__id IF NOT EXISTS  FOR (n:__Entity__) ON (n.id)" 
104+         )
105+ 
106+     async  def  _async_db_setup (self ) ->  None :
107+         # create index on __Entity__.id 
108+         await  self .driver .execute_query (
109+             "CREATE INDEX __entity__id IF NOT EXISTS  FOR (n:__Entity__) ON (n.id)" 
110+         )
111+ 
105112    def  _get_node_query (self , node : Neo4jNode ) ->  Tuple [str , Dict [str , Any ]]:
106113        # Create the initial node 
107-         parameters  =  {"id" : node .id }
108-         if  node .properties :
109-             parameters .update (node .properties )
110-         properties  =  (
111-             "{"  +  ", " .join (f"{ key } { key }   for  key  in  parameters .keys ()) +  "}" 
112-         )
113-         query  =  UPSERT_NODE_QUERY .format (label = node .label , properties = properties )
114+         parameters  =  {
115+             "id" : node .id ,
116+             "properties" : node .properties  or  {},
117+             "embeddings" : node .embedding_properties ,
118+         }
119+         query  =  UPSERT_NODE_QUERY .format (label = node .label )
114120        return  query , parameters 
115121
116122    def  _upsert_node (self , node : Neo4jNode ) ->  None :
@@ -120,18 +126,7 @@ def _upsert_node(self, node: Neo4jNode) -> None:
120126            node (Neo4jNode): The node to upsert into the database. 
121127        """ 
122128        query , parameters  =  self ._get_node_query (node )
123-         result  =  self .driver .execute_query (query , parameters_ = parameters )
124-         node_id  =  result .records [0 ]["elementID(n)" ]
125-         # Add the embedding properties to the node 
126-         if  node .embedding_properties :
127-             for  prop , vector  in  node .embedding_properties .items ():
128-                 upsert_vector (
129-                     driver = self .driver ,
130-                     node_id = node_id ,
131-                     embedding_property = prop ,
132-                     vector = vector ,
133-                     neo4j_database = self .neo4j_database ,
134-                 )
129+         self .driver .execute_query (query , parameters_ = parameters )
135130
136131    async  def  _async_upsert_node (
137132        self ,
@@ -145,35 +140,18 @@ async def _async_upsert_node(
145140        """ 
146141        async  with  sem :
147142            query , parameters  =  self ._get_node_query (node )
148-             result  =  await  self .driver .execute_query (query , parameters_ = parameters )
149-             node_id  =  result .records [0 ]["elementID(n)" ]
150-             # Add the embedding properties to the node 
151-             if  node .embedding_properties :
152-                 for  prop , vector  in  node .embedding_properties .items ():
153-                     await  async_upsert_vector (
154-                         driver = self .driver ,
155-                         node_id = node_id ,
156-                         embedding_property = prop ,
157-                         vector = vector ,
158-                         neo4j_database = self .neo4j_database ,
159-                     )
143+             await  self .driver .execute_query (query , parameters_ = parameters )
160144
161145    def  _get_rel_query (self , rel : Neo4jRelationship ) ->  Tuple [str , Dict [str , Any ]]:
162146        # Create the initial relationship 
163147        parameters  =  {
164148            "start_node_id" : rel .start_node_id ,
165149            "end_node_id" : rel .end_node_id ,
150+             "properties" : rel .properties  or  {},
151+             "embeddings" : rel .embedding_properties ,
166152        }
167-         if  rel .properties :
168-             properties  =  (
169-                 "{"  +  ", " .join (f"{ key } { key }   for  key  in  rel .properties .keys ()) +  "}" 
170-             )
171-             parameters .update (rel .properties )
172-         else :
173-             properties  =  "{}" 
174153        query  =  UPSERT_RELATIONSHIP_QUERY .format (
175154            type = rel .type ,
176-             properties = properties ,
177155        )
178156        return  query , parameters 
179157
@@ -184,18 +162,7 @@ def _upsert_relationship(self, rel: Neo4jRelationship) -> None:
184162            rel (Neo4jRelationship): The relationship to upsert into the database. 
185163        """ 
186164        query , parameters  =  self ._get_rel_query (rel )
187-         result  =  self .driver .execute_query (query , parameters_ = parameters )
188-         rel_id  =  result .records [0 ]["elementID(r)" ]
189-         # Add the embedding properties to the relationship 
190-         if  rel .embedding_properties :
191-             for  prop , vector  in  rel .embedding_properties .items ():
192-                 upsert_vector_on_relationship (
193-                     driver = self .driver ,
194-                     rel_id = rel_id ,
195-                     embedding_property = prop ,
196-                     vector = vector ,
197-                     neo4j_database = self .neo4j_database ,
198-                 )
165+         self .driver .execute_query (query , parameters_ = parameters )
199166
200167    async  def  _async_upsert_relationship (
201168        self , rel : Neo4jRelationship , sem : asyncio .Semaphore 
@@ -207,18 +174,7 @@ async def _async_upsert_relationship(
207174        """ 
208175        async  with  sem :
209176            query , parameters  =  self ._get_rel_query (rel )
210-             result  =  await  self .driver .execute_query (query , parameters_ = parameters )
211-             rel_id  =  result .records [0 ]["elementID(r)" ]
212-             # Add the embedding properties to the relationship 
213-             if  rel .embedding_properties :
214-                 for  prop , vector  in  rel .embedding_properties .items ():
215-                     await  async_upsert_vector_on_relationship (
216-                         driver = self .driver ,
217-                         rel_id = rel_id ,
218-                         embedding_property = prop ,
219-                         vector = vector ,
220-                         neo4j_database = self .neo4j_database ,
221-                     )
177+             await  self .driver .execute_query (query , parameters_ = parameters )
222178
223179    @validate_call  
224180    async  def  run (self , graph : Neo4jGraph ) ->  KGWriterModel :
@@ -228,7 +184,8 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
228184            graph (Neo4jGraph): The knowledge graph to upsert into the database. 
229185        """ 
230186        try :
231-             if  isinstance (self .driver , neo4j .AsyncDriver ):
187+             if  inspect .iscoroutinefunction (self .driver .execute_query ):
188+                 await  self ._async_db_setup ()
232189                sem  =  asyncio .Semaphore (self .max_concurrency )
233190                node_tasks  =  [
234191                    self ._async_upsert_node (node , sem ) for  node  in  graph .nodes 
@@ -241,6 +198,8 @@ async def run(self, graph: Neo4jGraph) -> KGWriterModel:
241198                ]
242199                await  asyncio .gather (* rel_tasks )
243200            else :
201+                 self ._db_setup ()
202+ 
244203                for  node  in  graph .nodes :
245204                    self ._upsert_node (node )
246205
0 commit comments