Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/store/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def is_group_active(self, group: str) -> bool:
return group in self._activated_groups

def is_group_empty(self, group: str) -> bool:
return not self.impl.get(self._gen_collection_name(group), {})
return not self.impl.get(self._gen_collection_name(group), {}, limit=10)

def update_nodes(self, nodes: List[DocNode]): # noqa: C901
if not nodes:
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/store/hybrid/hybrid_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -
res_segments = self.segment_store.get(collection_name=collection_name, criteria=criteria, **kwargs)
if not res_segments: return []
uids = [item.get('uid') for item in res_segments]
res_vectors = self.vector_store.get(collection_name=collection_name, criteria={'uid': uids})
res_vectors = self.vector_store.get(collection_name=collection_name, criteria={'uid': uids}, **kwargs)

data = {}
for item in res_segments:
Expand Down
65 changes: 49 additions & 16 deletions lazyllm/tools/rag/store/vector/milvus_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def delete(self, collection_name: str, criteria: Optional[dict] = None, **kwargs
return False

@override
def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]:
def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]: # noqa: C901
try:
with self._client_context() as client:
if not client.has_collection(collection_name):
Expand All @@ -183,29 +183,62 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -
col_desc = client.describe_collection(collection_name=collection_name)
field_names = [field.get('name') for field in col_desc.get('fields', [])
if field.get('name').startswith(EMBED_PREFIX)]
if criteria and self._primary_key in criteria:
res = client.get(collection_name=collection_name, ids=criteria[self._primary_key])
query_kwargs = self._construct_criteria(criteria) if criteria else {}
if version.parse(pymilvus.__version__) < version.parse('2.4.11'):
# For older versions, batch query manually
res = self._batch_query_legacy(client, collection_name, field_names, query_kwargs)
else:
filters = self._construct_criteria(criteria) if criteria else {}
if version.parse(pymilvus.__version__) >= version.parse('2.4.11'):
iterator = client.query_iterator(collection_name=collection_name,
batch_size=MILVUS_PAGINATION_OFFSET,
output_fields=field_names, **filters)
res = []
while True:
result = iterator.next()
if not result:
iterator.close()
break
res += result
if criteria and self._primary_key in criteria:
ids = criteria[self._primary_key]
if isinstance(ids, str):
ids = [ids]
query_kwargs = {'filter': f'{self._primary_key} in {ids}'}
# return all fields
field_names = None
else:
res = client.query(collection_name=collection_name, output_fields=field_names, **filters)
query_kwargs.update(**kwargs)

iterator = client.query_iterator(collection_name=collection_name,
batch_size=MILVUS_PAGINATION_OFFSET,
output_fields=field_names, **query_kwargs)
res = []
while True:
result = iterator.next()
if not result:
iterator.close()
break
res += result
return [self._deserialize_data(r) for r in res]
except Exception as e:
LOG.error(f'[Milvus Store - get] error: {e}')
LOG.error(traceback.format_exc())
return []

def _batch_query_legacy(self, client, collection_name: str, field_names: List[str], kwargs: dict) -> List[dict]:
    """Paginate through a Milvus collection using offset/limit queries.

    Fallback path for pymilvus versions that lack ``query_iterator``:
    repeatedly calls ``client.query`` one page at a time, advancing the
    offset until an empty or short page signals the end of the results.

    Args:
        client: an open Milvus client handle.
        collection_name: collection to read from.
        field_names: output fields to request from Milvus.
        kwargs: pre-built query arguments (e.g. a ``filter`` expression);
            its ``offset``/``limit`` entries, if any, are overridden per page.

    Returns:
        All matching rows, as raw Milvus result dicts.

    Raises:
        Exception: re-raised after logging if any page query fails.
    """
    # NOTE(review): Milvus historically caps offset+limit for query() —
    # confirm the server-side pagination window covers the expected data size.
    page_size = MILVUS_PAGINATION_OFFSET
    page_start = 0
    collected: List[dict] = []

    while True:
        try:
            # Per-page copy so the caller's kwargs are never mutated;
            # later keys win, pinning this page's offset/limit.
            page_args = {**kwargs, 'offset': page_start, 'limit': page_size}
            page = client.query(collection_name=collection_name, output_fields=field_names, **page_args)
        except Exception as e:
            LOG.error(f'[Milvus Store - _batch_query_legacy] error: {e}')
            raise
        if not page:
            break
        collected.extend(page)
        if len(page) < page_size:
            # Short page: nothing left to fetch.
            break
        page_start += page_size
    return collected

def _set_constants(self):
self._type2milvus = {
DataType.VARCHAR: pymilvus.DataType.VARCHAR,
Expand Down
42 changes: 42 additions & 0 deletions tests/basic_tests/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
import tempfile
import unittest
import copy
import lazyllm
from lazyllm.tools.rag.store import (MapStore, ChromadbStore, MilvusStore,
SenseCoreStore, BUILDIN_GLOBAL_META_DESC, HybridStore)
Expand Down Expand Up @@ -508,6 +509,47 @@ def test_search_with_filters(self):
embed_key='vec_dense', topk=1, filters={RAG_KB_ID: ['kb1']})
self.assertEqual(len(res), 0)

def test_get_massive_data(self):
    """get() must page through result sets larger than one iterator batch."""
    MASSIVE_DATA_SIZE = 20000
    records = []
    wanted_uids = []
    for idx in range(MASSIVE_DATA_SIZE):
        record = copy.deepcopy(data[0])
        record['uid'] = f'uid_{idx}'
        record['doc_id'] = 'doc_common'
        records.append(record)
        wanted_uids.append(f'uid_{idx}')

    self.store.upsert(self.collections[0], records)

    # test client.query_iterator in get api
    res = self.store.get(collection_name=self.collections[0])
    self.assertEqual(len(res), MASSIVE_DATA_SIZE)

    SEARCH_DATA_SIZE = 9999
    res = self.store.get(collection_name=self.collections[0],
                         criteria={'uid': wanted_uids[:SEARCH_DATA_SIZE]})
    self.assertEqual(len(res), SEARCH_DATA_SIZE)

def test_batch_query_legacy(self):
    """The offset/limit fallback must return every stored row, filtered or not."""
    TOTAL = 10000
    records = []
    for idx in range(TOTAL):
        record = copy.deepcopy(data[0])
        record['uid'] = f'uid_{idx}'
        record['doc_id'] = 'doc_common'
        records.append(record)

    with self.store._client_context() as client:
        self.store.upsert(self.collections[0], records)

        # Unfiltered scan: every upserted row comes back.
        res = self.store._batch_query_legacy(client, self.collections[0], field_names=['uid'], kwargs={})
        self.assertEqual(len(res), len(records))

        # Filtered scan: all rows share doc_id, so the count is unchanged.
        filters = self.store._construct_criteria({"doc_id": "doc_common"})
        res = self.store._batch_query_legacy(client, self.collections[0], field_names=['uid'], kwargs=filters)
        self.assertEqual(len(res), len(records))

@pytest.mark.skip(reason=('local test for milvus standalone, please set up a milvus standalone server'
' and set the uri to the server'))
def test_milvus_standalone(self):
Expand Down
Loading