From 2838a623d69144e1b797ea9fea5d67e86e843280 Mon Sep 17 00:00:00 2001 From: chenhao0205 Date: Fri, 26 Sep 2025 10:08:45 +0800 Subject: [PATCH] enable milvus default params injection --- .../tools/rag/store/vector/milvus_store.py | 135 +++++++++++++++--- 1 file changed, 113 insertions(+), 22 deletions(-) diff --git a/lazyllm/tools/rag/store/vector/milvus_store.py b/lazyllm/tools/rag/store/vector/milvus_store.py index d401c9151..92dec9fe3 100644 --- a/lazyllm/tools/rag/store/vector/milvus_store.py +++ b/lazyllm/tools/rag/store/vector/milvus_store.py @@ -20,7 +20,21 @@ MILVUS_UPSERT_BATCH_SIZE = 500 MILVUS_PAGINATION_OFFSET = 1000 MILVUS_INDEX_MAX_RETRY = 3 - +MILVUS_INDEX_TYPE_DEFAULTS = { + 'HNSW': {'metric_type': 'COSINE', 'params': {'M': 16, 'efConstruction': 200}}, + 'IVF_FLAT': {'metric_type': 'L2', 'params': {'nlist': 1024}}, + 'IVF_SQ8': {'metric_type': 'L2', 'params': {'nlist': 1024}}, + 'IVF_PQ': {'metric_type': 'L2', 'params': {'nlist': 1024, 'm': 8, 'nbits': 8}}, + 'FLAT': {'metric_type': 'L2', 'params': {}}, + 'GPU_IVF_FLAT': {'metric_type': 'L2', 'params': {'nlist': 1024}}, + 'GPU_IVF_SQ8': {'metric_type': 'L2', 'params': {'nlist': 1024}}, + 'GPU_IVF_PQ': {'metric_type': 'L2', 'params': {'nlist': 1024, 'm': 8, 'nbits': 8}}, + 'DISKANN': {'metric_type': 'L2', 'params': {'nlist': 1024}}, + 'BIN_FLAT': {'metric_type': 'HAMMING', 'params': {}}, + 'BIN_IVF_FLAT': {'metric_type': 'HAMMING', 'params': {'nlist': 1024}}, + 'SPARSE_INVERTED_INDEX': {'metric_type': 'IP', 'params': {'inverted_index_algo': 'DAAT_MAXSCORE'}}, + 'AUTOINDEX': {'metric_type': 'COSINE', 'params': {'nlist': 128}}, +} class _ClientPool: def __init__(self, maker, max_size: int = 8): @@ -246,19 +260,25 @@ def _create_collection(self, client, collection_name: str, embed_kwargs: Dict[st field_list = copy.deepcopy(self._constant_fields) index_params = client.prepare_index_params() original_index_kwargs = copy.deepcopy(self._index_kwargs) + + # Pre-process index_kwargs to create a lookup dictionary for O(1) access + index_kwargs_lookup = {} + if isinstance(original_index_kwargs, list): + for item in original_index_kwargs: + self._ensure_params_defaults(item) + embed_key = item.get('embed_key', None) + if not embed_key: + raise ValueError(f'cannot find `embed_key` in `index_kwargs` of `{item}`') + index_kwargs_lookup[embed_key] = item.copy() + index_kwargs_lookup[embed_key].pop('embed_key', None) + for k, kws in embed_kwargs.items(): embed_field_name = self._gen_embed_key(k) field_list.append(pymilvus.FieldSchema(name=embed_field_name, **kws)) + if isinstance(original_index_kwargs, list): - for item in original_index_kwargs: - embed_key = item.get('embed_key', None) - if not embed_key: - raise ValueError(f'cannot find `embed_key` in `index_kwargs` of `{item}`') - if embed_key == k: - index_kwarg = item.copy() - index_kwarg.pop('embed_key', None) - index_params.add_index(field_name=embed_field_name, **index_kwarg) - break + if k in index_kwargs_lookup: + index_params.add_index(field_name=embed_field_name, **index_kwargs_lookup[k]) elif isinstance(original_index_kwargs, dict): index_params.add_index(field_name=embed_field_name, **original_index_kwargs) schema = pymilvus.CollectionSchema(fields=field_list, auto_id=False, enable_dynamic_field=False) @@ -278,24 +298,95 @@ def _create_collection(self, client, collection_name: str, embed_kwargs: Dict[st except Exception: LOG.error(f'[Milvus Store] failed to parse invalid index type from error: {msg}') raise - self._replace_index_type_to_autoindex(wrong_index_type) + self._ensure_valid_index(self._index_kwargs) LOG.warning(f'[Milvus Store] Unsupported index type: {wrong_index_type}. ' f'Fallback to AUTOINDEX and retry (try #{retry + 1}).') self._create_collection(client, collection_name, embed_kwargs, retry=retry + 1) else: raise e - def _replace_index_type_to_autoindex(self, index_type: str): - if index_type == 'AUTOINDEX': - raise ValueError(f'[Milvus Store - replace_index_type_to_autoindex] Invalid index type: {index_type}') - if isinstance(self._index_kwargs, list): - for item in self._index_kwargs: - if item.get('index_type') == index_type: - item['index_type'] = 'AUTOINDEX' - elif isinstance(self._index_kwargs, dict): - if self._index_kwargs.get('index_type') == index_type: - self._index_kwargs['index_type'] = 'AUTOINDEX' - return + def _ensure_valid_index(self, index_params: Union[list, dict]): + embed_index_map = { + DataType.FLOAT_VECTOR: { + 'AUTOINDEX': ['L2', 'IP', 'COSINE'], 'HNSW': ['L2', 'COSINE', 'IP'], + 'IVF_FLAT': ['L2', 'IP', 'COSINE'], 'IVF_SQ8': ['L2', 'IP'], 'IVF_PQ': ['L2', 'IP'], + 'FLAT': ['IP', 'COSINE'], 'DISKANN': ['L2'] + }, + DataType.SPARSE_FLOAT_VECTOR: {'SPARSE_INVERTED_INDEX': ['IP'], 'SPARSE_WAND': ['IP']}, + DataType.VARCHAR: {'INVERTED_INDEX': [None]}, + DataType.STRING: {'INVERTED_INDEX': [None]}, + DataType.ARRAY: {'INVERTED_INDEX': [None]}, + DataType.INT32: {'INVERTED_INDEX': [None]}, + DataType.INT64: {'INVERTED_INDEX': [None]}, + DataType.FLOAT: {'INVERTED_INDEX': [None]}, + DataType.BOOLEAN: {'INVERTED_INDEX': [None]}, + } + + def _replace_index_type(index_item: dict): + embed_key = index_item.get('embed_key') + dtype = self._embed_datatypes.get(embed_key) + index_type = index_item.get('index_type').upper() + metric_type = index_item.get('metric_type').upper() + if dtype not in embed_index_map: + LOG.warning(f'[Milvus Store] {embed_key}: Unsupported data type: {dtype}.' + f'Fallback to {list(embed_index_map.keys())[0]}.') + dtype = list(embed_index_map.keys())[0] + self._embed_datatypes[embed_key] = dtype + if index_type not in embed_index_map.get(dtype): + LOG.warning(f'[Milvus Store] {DataType(dtype).name}: Unsupported index type: {index_type}.' + f'Fallback to {list(embed_index_map.get(dtype).keys())[0]}.') + index_type = list(embed_index_map.get(dtype).keys())[0] + index_item['index_type'] = index_type + if metric_type not in embed_index_map[dtype][index_type]: + LOG.warning(f'[Milvus Store] {index_type}: Unsupported metric type: {metric_type}.' + f'Fallback to {embed_index_map[dtype][index_type][0]}.') + metric_type = embed_index_map[dtype][index_type][0] + index_item['metric_type'] = metric_type + self._ensure_params_defaults(index_item) + + if isinstance(index_params, list): + for index_item in index_params: + _replace_index_type(index_item) + elif isinstance(index_params, dict): + _replace_index_type(index_params) + else: + LOG.error(f'Expected index params type: list or dict, but got {type(index_params)}') + + def _ensure_params_defaults(self, index_item: dict): + ''' + Fill in the missing fields (index_type, metric_type, params) of a single index item. + Do not override the fields explicitly provided by the user (only setdefault) + params will be filled in with common defaults based on index_type + (if params already exist, only fill in missing keys) + ''' + if not isinstance(index_item, dict): + return + + # Normalize index_type + itype = index_item.get('index_type') + if itype: + itype_up = str(itype).upper() + index_item['index_type'] = itype_up + else: + LOG.error(f'cannot find `index_type` in `index_kwargs` of `{index_item}`') + + defaults = MILVUS_INDEX_TYPE_DEFAULTS.get(index_item['index_type'], {}) + + # metric_type default fill (do not override user) + if 'metric_type' not in index_item and 'metric_type' in defaults: + index_item['metric_type'] = defaults['metric_type'] + + default_params = defaults.get('params', {}) + if 'params' not in index_item or index_item.get('params') is None: + index_item['params'] = dict(default_params) + else: + # fill in the missing keys of params + if isinstance(index_item['params'], dict): + for k, v in default_params.items(): + index_item['params'].setdefault(k, v) + else: + # if user passed a non-dict (exception), replace it with the default dict + index_item['params'] = dict(default_params) def _serialize_data(self, d: dict) -> dict: # only keep primary_key, embedding and global_meta