Skip to content

Commit f77538e

Browse files
prasmussen15claude
andcommitted
refactor: add get_between_nodes_bidirectional to avoid doubling queries
Adds a new EntityEdge.get_between_nodes_bidirectional method that uses undirected Cypher matching (-) instead of directed (->), finding edges in both directions with a single query. This preserves the existing directional semantics of get_between_nodes for other callers. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2a1a941 commit f77538e

File tree

2 files changed

+37
-21
lines changed

2 files changed

+37
-21
lines changed

graphiti_core/edges.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,35 @@ async def get_between_nodes(
435435

436436
return edges
437437

438+
@classmethod
439+
async def get_between_nodes_bidirectional(
440+
cls, driver: GraphDriver, node_uuid_a: str, node_uuid_b: str
441+
):
442+
match_query = """
443+
MATCH (n:Entity {uuid: $node_uuid_a})-[e:RELATES_TO]-(m:Entity {uuid: $node_uuid_b})
444+
"""
445+
if driver.provider == GraphProvider.KUZU:
446+
match_query = """
447+
MATCH (n:Entity {uuid: $node_uuid_a})
448+
-[:RELATES_TO]-(e:RelatesToNode_)
449+
-[:RELATES_TO]-(m:Entity {uuid: $node_uuid_b})
450+
"""
451+
452+
records, _, _ = await driver.execute_query(
453+
match_query
454+
+ """
455+
RETURN
456+
"""
457+
+ get_entity_edge_return_query(driver.provider),
458+
node_uuid_a=node_uuid_a,
459+
node_uuid_b=node_uuid_b,
460+
routing_='r',
461+
)
462+
463+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
464+
465+
return edges
466+
438467
@classmethod
439468
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
440469
if driver.graph_operations_interface:

graphiti_core/utils/maintenance/edge_operations.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -273,27 +273,14 @@ async def resolve_extracted_edges(
273273
embedder = clients.embedder
274274
await create_entity_edge_embeddings(embedder, extracted_edges)
275275

276-
all_edge_queries = [
277-
EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
278-
for edge in extracted_edges
279-
] + [
280-
EntityEdge.get_between_nodes(driver, edge.target_node_uuid, edge.source_node_uuid)
281-
for edge in extracted_edges
282-
]
283-
all_results: list[list[EntityEdge]] = await semaphore_gather(*all_edge_queries)
284-
n = len(extracted_edges)
285-
forward_edges_list = all_results[:n]
286-
inverse_edges_list = all_results[n:]
287-
288-
valid_edges_list: list[list[EntityEdge]] = []
289-
for forward_edges, inverse_edges in zip(forward_edges_list, inverse_edges_list, strict=True):
290-
seen_uuids: set[str] = set()
291-
combined: list[EntityEdge] = []
292-
for edge in [*forward_edges, *inverse_edges]:
293-
if edge.uuid not in seen_uuids:
294-
seen_uuids.add(edge.uuid)
295-
combined.append(edge)
296-
valid_edges_list.append(combined)
276+
valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
277+
*[
278+
EntityEdge.get_between_nodes_bidirectional(
279+
driver, edge.source_node_uuid, edge.target_node_uuid
280+
)
281+
for edge in extracted_edges
282+
]
283+
)
297284

298285
related_edges_results: list[SearchResults] = await semaphore_gather(
299286
*[

0 commit comments

Comments
 (0)