Skip to content

Commit 214edd4

Browse files
committed
feat(milvus): add backfill utility for graph DB to vector store sync
Add backfill_vector_store() that reads all entity nodes, entity edges, episodic nodes, and community nodes from the graph DB and batch-upserts them into the vector store. Supports group_id filtering and configurable batch size. Skips records without embeddings. 7 new unit tests covering empty graph, batching, filtering, and all four collection types.
1 parent a0fb9f4 commit 214edd4

File tree

2 files changed

+622
-0
lines changed

2 files changed

+622
-0
lines changed
Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
"""Backfill utility to populate a vector store from an existing graph database."""
2+
3+
import logging
4+
from typing import Any
5+
6+
from graphiti_core.driver.driver import GraphDriver
7+
from graphiti_core.vector_store.client import VectorStoreClient
8+
from graphiti_core.vector_store.milvus_utils import (
9+
COLLECTION_COMMUNITY_NODES,
10+
COLLECTION_ENTITY_EDGES,
11+
COLLECTION_ENTITY_NODES,
12+
COLLECTION_EPISODIC_NODES,
13+
community_node_to_milvus_dict,
14+
entity_edge_to_milvus_dict,
15+
entity_node_to_milvus_dict,
16+
episodic_node_to_milvus_dict,
17+
)
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
async def backfill_vector_store(
    driver: GraphDriver,
    vector_store: VectorStoreClient,
    group_ids: list[str] | None = None,
    batch_size: int = 100,
) -> dict[str, int]:
    """Backfill a vector store from an existing graph database.

    Reads all entity nodes, entity edges, episodic nodes, and community nodes
    from the graph DB and upserts them into the vector store.

    Parameters
    ----------
    driver : GraphDriver
        The graph database driver to read from.
    vector_store : VectorStoreClient
        The vector store client to write to.
    group_ids : list[str] | None
        Optional list of group IDs to filter by. If None, syncs all data.
    batch_size : int
        Number of records to process per batch.

    Returns
    -------
    dict[str, int]
        Counts of synced records per collection type.
    """
    await vector_store.ensure_ready()

    where_clause = ''
    query_params: dict[str, Any] = {}
    if group_ids is not None:
        where_clause = 'WHERE n.group_id IN $group_ids'
        query_params['group_ids'] = group_ids

    # The relationship query binds the edge to `r`, so the filter alias
    # must be rewritten before it is interpolated into that query.
    edge_where_clause = where_clause.replace('n.group_id', 'r.group_id')

    # Dict literal values are evaluated in order, so the four collections are
    # synced in the same sequence as before: nodes, edges, episodes, communities.
    counts: dict[str, int] = {
        'entity_nodes': await _sync_entity_nodes(
            driver, vector_store, where_clause, query_params, batch_size
        ),
        'entity_edges': await _sync_entity_edges(
            driver, vector_store, edge_where_clause, query_params, batch_size
        ),
        'episodic_nodes': await _sync_episodic_nodes(
            driver, vector_store, where_clause, query_params, batch_size
        ),
        'community_nodes': await _sync_community_nodes(
            driver, vector_store, where_clause, query_params, batch_size
        ),
    }

    logger.info(f'Backfill complete: {counts}')
    return counts
86+
87+
88+
async def _sync_entity_nodes(
    driver: GraphDriver,
    vector_store: VectorStoreClient,
    group_filter: str,
    params: dict[str, Any],
    batch_size: int,
) -> int:
    """Sync entity nodes from graph DB to vector store.

    Parameters
    ----------
    driver : GraphDriver
        The graph database driver to read from.
    vector_store : VectorStoreClient
        The vector store client to write to.
    group_filter : str
        Optional `WHERE` clause (empty string for no filtering).
    params : dict[str, Any]
        Query parameters backing `group_filter`.
    batch_size : int
        Maximum number of records per upsert call.

    Returns
    -------
    int
        Number of entity nodes upserted.
    """
    # Local import mirrors the original module's pattern (avoids an import cycle
    # with graphiti_core.nodes at module load time — TODO confirm).
    from graphiti_core.nodes import get_entity_node_from_record

    records, _, _ = await driver.execute_query(
        f"""
        MATCH (n:Entity)
        {group_filter}
        RETURN
            n.uuid AS uuid,
            n.name AS name,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary,
            n.name_embedding AS name_embedding,
            labels(n) AS labels,
            properties(n) AS attributes
        """,
        **params,
        routing_='r',
    )

    count = 0
    col = vector_store.collection_name(COLLECTION_ENTITY_NODES)
    batch: list[dict[str, Any]] = []

    for record in records:
        embedding = record.get('name_embedding')
        # Check the embedding BEFORE building the node object — consistent with
        # the other sync helpers, and avoids constructing nodes that are
        # immediately discarded.
        if embedding is None:
            logger.debug(f'Skipping entity node {record.get("uuid")}: no embedding')
            continue

        node = get_entity_node_from_record(record, driver.provider)
        node.name_embedding = embedding

        batch.append(entity_node_to_milvus_dict(node))
        if len(batch) >= batch_size:
            await vector_store.upsert(collection_name=col, data=batch)
            count += len(batch)
            batch = []

    # Flush any remaining partial batch.
    if batch:
        await vector_store.upsert(collection_name=col, data=batch)
        count += len(batch)

    logger.info(f'Synced {count} entity nodes')
    return count
141+
142+
143+
async def _sync_entity_edges(
    driver: GraphDriver,
    vector_store: VectorStoreClient,
    group_filter: str,
    params: dict[str, Any],
    batch_size: int,
) -> int:
    """Sync entity edges from graph DB to vector store.

    Reads every RELATES_TO relationship (optionally filtered by group), skips
    edges without a fact embedding, and upserts the rest in batches of at most
    `batch_size`. Returns the number of edges upserted.
    """
    from graphiti_core.edges import EntityEdge

    records, _, _ = await driver.execute_query(
        f"""
        MATCH (src)-[r:RELATES_TO]->(tgt)
        {group_filter}
        RETURN
            r.uuid AS uuid,
            r.group_id AS group_id,
            src.uuid AS source_node_uuid,
            tgt.uuid AS target_node_uuid,
            r.name AS name,
            r.fact AS fact,
            r.fact_embedding AS fact_embedding,
            r.episodes AS episodes,
            r.created_at AS created_at,
            r.expired_at AS expired_at,
            r.valid_at AS valid_at,
            r.invalid_at AS invalid_at
        """,
        **params,
        routing_='r',
    )

    col = vector_store.collection_name(COLLECTION_ENTITY_EDGES)

    # Convert every eligible record first, then upsert in fixed-size chunks —
    # same upsert call sequence and counts as an incremental flush.
    rows: list[dict[str, Any]] = []
    for record in records:
        fact_embedding = record.get('fact_embedding')
        if fact_embedding is None:
            logger.debug(f'Skipping edge {record.get("uuid")}: no embedding')
            continue

        edge = EntityEdge(
            uuid=record['uuid'],
            group_id=record['group_id'],
            source_node_uuid=record['source_node_uuid'],
            target_node_uuid=record['target_node_uuid'],
            name=record.get('name', ''),
            fact=record.get('fact', ''),
            fact_embedding=fact_embedding,
            episodes=record.get('episodes') or [],
            created_at=record['created_at'],
            expired_at=record.get('expired_at'),
            valid_at=record.get('valid_at'),
            invalid_at=record.get('invalid_at'),
        )
        rows.append(entity_edge_to_milvus_dict(edge))

    total = 0
    for start in range(0, len(rows), batch_size):
        chunk = rows[start : start + batch_size]
        await vector_store.upsert(collection_name=col, data=chunk)
        total += len(chunk)

    logger.info(f'Synced {total} entity edges')
    return total
212+
213+
214+
async def _sync_episodic_nodes(
    driver: GraphDriver,
    vector_store: VectorStoreClient,
    group_filter: str,
    params: dict[str, Any],
    batch_size: int,
) -> int:
    """Sync episodic nodes from graph DB to vector store.

    Every episodic node is synced — unlike entity/community nodes, no embedding
    field is read here, so there is nothing to skip. Returns the number of
    episodic nodes upserted.
    """
    from graphiti_core.nodes import EpisodicNode

    records, _, _ = await driver.execute_query(
        f"""
        MATCH (n:Episodic)
        {group_filter}
        RETURN
            n.uuid AS uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.content AS content,
            n.source AS source,
            n.source_description AS source_description,
            n.created_at AS created_at,
            n.valid_at AS valid_at,
            n.entity_edges AS entity_edges
        """,
        **params,
        routing_='r',
    )

    col = vector_store.collection_name(COLLECTION_EPISODIC_NODES)

    # Build all payloads up front, then upsert in fixed-size chunks — identical
    # upsert batches to an incremental flush.
    rows: list[dict[str, Any]] = []
    for record in records:
        episode = EpisodicNode(
            uuid=record['uuid'],
            group_id=record['group_id'],
            name=record.get('name', ''),
            content=record.get('content', ''),
            source=record.get('source', 'text'),
            source_description=record.get('source_description', ''),
            created_at=record['created_at'],
            # Fall back to created_at when the node predates valid_at tracking.
            valid_at=record.get('valid_at') or record['created_at'],
            entity_edges=record.get('entity_edges') or [],
        )
        rows.append(episodic_node_to_milvus_dict(episode))

    total = 0
    for start in range(0, len(rows), batch_size):
        chunk = rows[start : start + batch_size]
        await vector_store.upsert(collection_name=col, data=chunk)
        total += len(chunk)

    logger.info(f'Synced {total} episodic nodes')
    return total
272+
273+
274+
async def _sync_community_nodes(
    driver: GraphDriver,
    vector_store: VectorStoreClient,
    group_filter: str,
    params: dict[str, Any],
    batch_size: int,
) -> int:
    """Sync community nodes from graph DB to vector store.

    Nodes without a name embedding are skipped (logged at debug level).
    Returns the number of community nodes upserted.
    """
    from graphiti_core.nodes import CommunityNode

    records, _, _ = await driver.execute_query(
        f"""
        MATCH (n:Community)
        {group_filter}
        RETURN
            n.uuid AS uuid,
            n.group_id AS group_id,
            n.name AS name,
            n.summary AS summary,
            n.created_at AS created_at,
            n.name_embedding AS name_embedding
        """,
        **params,
        routing_='r',
    )

    col = vector_store.collection_name(COLLECTION_COMMUNITY_NODES)

    # Materialize all eligible payloads, then upsert in fixed-size chunks —
    # the upsert call sequence matches an incremental batch-and-flush loop.
    rows: list[dict[str, Any]] = []
    for record in records:
        name_embedding = record.get('name_embedding')
        if name_embedding is None:
            logger.debug(f'Skipping community node {record.get("uuid")}: no embedding')
            continue

        community = CommunityNode(
            uuid=record['uuid'],
            group_id=record['group_id'],
            name=record.get('name', ''),
            summary=record.get('summary', ''),
            created_at=record['created_at'],
            name_embedding=name_embedding,
        )
        rows.append(community_node_to_milvus_dict(community))

    total = 0
    for start in range(0, len(rows), batch_size):
        chunk = rows[start : start + batch_size]
        await vector_store.upsert(collection_name=col, data=chunk)
        total += len(chunk)

    logger.info(f'Synced {total} community nodes')
    return total

0 commit comments

Comments
 (0)