Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 113 additions & 22 deletions lazyllm/tools/rag/store/vector/milvus_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,21 @@
# NOTE(review): the consumers of these three tunables are outside this excerpt; the names
# suggest an upsert batch size, a pagination step and an index-creation retry cap - confirm.
MILVUS_UPSERT_BATCH_SIZE = 500
MILVUS_PAGINATION_OFFSET = 1000
MILVUS_INDEX_MAX_RETRY = 3

# Per-index-type fallback settings: the metric used when the caller gives none, and the
# build params merged into whatever the caller supplied (see `_ensure_params_defaults`).
# Every entry owns its own `params` dict so that no two index types share a mutable value.
# NOTE(review): `nlist` under DISKANN and AUTOINDEX does not look like a documented build
# parameter for those index types - confirm that Milvus tolerates or ignores it.
MILVUS_INDEX_TYPE_DEFAULTS = {
    # Float-vector indexes
    'HNSW': {
        'metric_type': 'COSINE',
        'params': {'M': 16, 'efConstruction': 200},
    },
    'IVF_FLAT': {
        'metric_type': 'L2',
        'params': {'nlist': 1024},
    },
    'IVF_SQ8': {
        'metric_type': 'L2',
        'params': {'nlist': 1024},
    },
    'IVF_PQ': {
        'metric_type': 'L2',
        'params': {'nlist': 1024, 'm': 8, 'nbits': 8},
    },
    'FLAT': {
        'metric_type': 'L2',
        'params': {},
    },
    # GPU variants
    'GPU_IVF_FLAT': {
        'metric_type': 'L2',
        'params': {'nlist': 1024},
    },
    'GPU_IVF_SQ8': {
        'metric_type': 'L2',
        'params': {'nlist': 1024},
    },
    'GPU_IVF_PQ': {
        'metric_type': 'L2',
        'params': {'nlist': 1024, 'm': 8, 'nbits': 8},
    },
    'DISKANN': {
        'metric_type': 'L2',
        'params': {'nlist': 1024},
    },
    # Binary-vector indexes
    'BIN_FLAT': {
        'metric_type': 'HAMMING',
        'params': {},
    },
    'BIN_IVF_FLAT': {
        'metric_type': 'HAMMING',
        'params': {'nlist': 1024},
    },
    # Sparse-vector index
    'SPARSE_INVERTED_INDEX': {
        'metric_type': 'IP',
        'params': {'inverted_index_algo': 'DAAT_MAXSCORE'},
    },
    'AUTOINDEX': {
        'metric_type': 'COSINE',
        'params': {'nlist': 128},
    },
}

class _ClientPool:
def __init__(self, maker, max_size: int = 8):
Expand Down Expand Up @@ -246,19 +260,25 @@ def _create_collection(self, client, collection_name: str, embed_kwargs: Dict[st
field_list = copy.deepcopy(self._constant_fields)
index_params = client.prepare_index_params()
original_index_kwargs = copy.deepcopy(self._index_kwargs)

# Pre-process index_kwargs to create a lookup dictionary for O(1) access
index_kwargs_lookup = {}
if isinstance(original_index_kwargs, list):
for item in original_index_kwargs:
self._ensure_params_defaults(item)
embed_key = item.get('embed_key', None)
if not embed_key:
raise ValueError(f'cannot find `embed_key` in `index_kwargs` of `{item}`')
index_kwargs_lookup[embed_key] = item.copy()
index_kwargs_lookup[embed_key].pop('embed_key', None)

for k, kws in embed_kwargs.items():
embed_field_name = self._gen_embed_key(k)
field_list.append(pymilvus.FieldSchema(name=embed_field_name, **kws))

if isinstance(original_index_kwargs, list):
for item in original_index_kwargs:
embed_key = item.get('embed_key', None)
if not embed_key:
raise ValueError(f'cannot find `embed_key` in `index_kwargs` of `{item}`')
if embed_key == k:
index_kwarg = item.copy()
index_kwarg.pop('embed_key', None)
index_params.add_index(field_name=embed_field_name, **index_kwarg)
break
if k in index_kwargs_lookup:
index_params.add_index(field_name=embed_field_name, **index_kwargs_lookup[k])
elif isinstance(original_index_kwargs, dict):
index_params.add_index(field_name=embed_field_name, **original_index_kwargs)
schema = pymilvus.CollectionSchema(fields=field_list, auto_id=False, enable_dynamic_field=False)
Expand All @@ -278,24 +298,95 @@ def _create_collection(self, client, collection_name: str, embed_kwargs: Dict[st
except Exception:
LOG.error(f'[Milvus Store] failed to parse invalid index type from error: {msg}')
raise
self._replace_index_type_to_autoindex(wrong_index_type)
self._ensure_valid_index(self._index_kwargs)
LOG.warning(f'[Milvus Store] Unsupported index type: {wrong_index_type}. '
f'Fallback to AUTOINDEX and retry (try #{retry + 1}).')
self._create_collection(client, collection_name, embed_kwargs, retry=retry + 1)
else:
raise e

def _replace_index_type_to_autoindex(self, index_type: str):
if index_type == 'AUTOINDEX':
raise ValueError(f'[Milvus Store - replace_index_type_to_autoindex] Invalid index type: {index_type}')
if isinstance(self._index_kwargs, list):
for item in self._index_kwargs:
if item.get('index_type') == index_type:
item['index_type'] = 'AUTOINDEX'
elif isinstance(self._index_kwargs, dict):
if self._index_kwargs.get('index_type') == index_type:
self._index_kwargs['index_type'] = 'AUTOINDEX'
return
def _ensure_valid_index(self, index_params: Union[list, dict]):
    '''
    Coerce index settings, in place, to a combination listed in the local compatibility table.

    For each index item the field data type is looked up in `self._embed_datatypes` by the
    item's `embed_key`. Then, in order:
      1. an unknown data type falls back to the first data type of the table;
      2. an index type that is missing or not listed for that data type falls back to the
         first index type listed for it;
      3. a metric type that is missing or not listed for that index type falls back to the
         first metric listed for it;
      4. `self._ensure_params_defaults` fills in any missing build params.
    Every fallback is reported with `LOG.warning`.

    Args:
        index_params: a single index item (dict) or a list of them. Any other type is only
            reported through `LOG.error`.
    '''
    embed_index_map = {
        DataType.FLOAT_VECTOR: {
            'AUTOINDEX': ['L2', 'IP', 'COSINE'], 'HNSW': ['L2', 'COSINE', 'IP'],
            'IVF_FLAT': ['L2', 'IP', 'COSINE'], 'IVF_SQ8': ['L2', 'IP'], 'IVF_PQ': ['L2', 'IP'],
            # 'L2' is appended (not prepended) so the existing fallback metric stays 'IP'.
            # Without it, a FLAT index using its own default metric from
            # MILVUS_INDEX_TYPE_DEFAULTS ('L2') was reported as unsupported and rewritten.
            'FLAT': ['IP', 'COSINE', 'L2'], 'DISKANN': ['L2']
        },
        DataType.SPARSE_FLOAT_VECTOR: {'SPARSE_INVERTED_INDEX': ['IP'], 'SPARSE_WAND': ['IP']},
        DataType.VARCHAR: {'INVERTED_INDEX': [None]},
        DataType.STRING: {'INVERTED_INDEX': [None]},
        DataType.ARRAY: {'INVERTED_INDEX': [None]},
        DataType.INT32: {'INVERTED_INDEX': [None]},
        DataType.INT64: {'INVERTED_INDEX': [None]},
        DataType.FLOAT: {'INVERTED_INDEX': [None]},
        DataType.BOOLEAN: {'INVERTED_INDEX': [None]},
    }
    fallback_dtype = next(iter(embed_index_map))

    def _replace_index_type(index_item: dict):
        embed_key = index_item.get('embed_key')
        dtype = self._embed_datatypes.get(embed_key)
        if dtype not in embed_index_map:
            LOG.warning(f'[Milvus Store] {embed_key}: Unsupported data type: {dtype}. '
                        f'Fallback to {fallback_dtype}.')
            dtype = fallback_dtype
            # A dict-style config carries no `embed_key`; do not register a `None` field.
            if embed_key is not None:
                self._embed_datatypes[embed_key] = dtype
        supported_indexes = embed_index_map[dtype]

        # A missing or empty value is handled like an unsupported one instead of
        # crashing on `None.upper()`.
        raw_index_type = index_item.get('index_type')
        index_type = str(raw_index_type).upper() if raw_index_type else None
        if index_type not in supported_indexes:
            fallback_index = next(iter(supported_indexes))
            LOG.warning(f'[Milvus Store] {DataType(dtype).name}: Unsupported index type: {index_type}. '
                        f'Fallback to {fallback_index}.')
            index_type = fallback_index
        index_item['index_type'] = index_type

        raw_metric_type = index_item.get('metric_type')
        metric_type = str(raw_metric_type).upper() if raw_metric_type else None
        supported_metrics = supported_indexes[index_type]
        if metric_type not in supported_metrics:
            LOG.warning(f'[Milvus Store] {index_type}: Unsupported metric type: {metric_type}. '
                        f'Fallback to {supported_metrics[0]}.')
            index_item['metric_type'] = supported_metrics[0]
        elif metric_type is not None:
            # Keep the validated, upper-cased spelling rather than the caller's casing.
            index_item['metric_type'] = metric_type
        self._ensure_params_defaults(index_item)

    if isinstance(index_params, list):
        for index_item in index_params:
            _replace_index_type(index_item)
    elif isinstance(index_params, dict):
        _replace_index_type(index_params)
    else:
        LOG.error(f'Expected index params type: list or dict, but got {type(index_params)}')

def _ensure_params_defaults(self, index_item: dict):
    '''
    Fill in the missing fields (index_type casing, metric_type, params) of one index item.

    The item is edited in place. Values supplied by the caller are never overridden:
      - `index_type` is only upper-cased;
      - `metric_type` is added only when absent and a default exists for the index type;
      - `params` gets the missing default keys when it is a dict, and is replaced by a
        copy of the defaults when it is absent, `None` or not a dict.
    Defaults come from `MILVUS_INDEX_TYPE_DEFAULTS`; an index type without an entry there
    contributes no defaults.

    A missing or empty `index_type` is reported through `LOG.error` and then treated as
    "no defaults available", so the item still ends up with a `params` dict. Previously
    this path raised `KeyError` right after logging.

    Args:
        index_item: the index settings to complete. Non-dict values are ignored.
    '''
    if not isinstance(index_item, dict):
        return

    raw_index_type = index_item.get('index_type')
    if raw_index_type:
        normalized_type = str(raw_index_type).upper()
        index_item['index_type'] = normalized_type
        defaults = MILVUS_INDEX_TYPE_DEFAULTS.get(normalized_type, {})
    else:
        LOG.error(f'cannot find `index_type` in `index_kwargs` of `{index_item}`')
        defaults = {}

    # metric_type default fill (do not override user)
    if 'metric_type' not in index_item and 'metric_type' in defaults:
        index_item['metric_type'] = defaults['metric_type']

    default_params = defaults.get('params', {})
    user_params = index_item.get('params')
    if isinstance(user_params, dict):
        # Only add the keys the caller did not set.
        for key, value in default_params.items():
            user_params.setdefault(key, value)
    else:
        # Absent, None or a non-dict value: start from a private copy of the defaults so
        # the shared table is never mutated through this item.
        index_item['params'] = dict(default_params)

def _serialize_data(self, d: dict) -> dict:
# only keep primary_key, embedding and global_meta
Expand Down
Loading