From befd58c0e28380fe517d7bb6ee1963dcd4aa8b1f Mon Sep 17 00:00:00 2001 From: Chenjiahao Date: Wed, 3 Sep 2025 19:09:48 +0800 Subject: [PATCH] add instruct for online embed module --- .../base/onlineEmbeddingModuleBase.py | 22 ++++++++++++++++++- lazyllm/tools/rag/store/document_store.py | 2 +- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/lazyllm/module/llms/onlinemodule/base/onlineEmbeddingModuleBase.py b/lazyllm/module/llms/onlinemodule/base/onlineEmbeddingModuleBase.py index 9d56e86ec..851ea7209 100644 --- a/lazyllm/module/llms/onlinemodule/base/onlineEmbeddingModuleBase.py +++ b/lazyllm/module/llms/onlinemodule/base/onlineEmbeddingModuleBase.py @@ -11,12 +11,17 @@ def __init__(self, embed_url: str, api_key: str, embed_model_name: str, - return_trace: bool = False): + return_trace: bool = False, + instruct: str = None, + instruct_format: str = 'Instruct: {instruct}\nQuery: {text}' + ): super().__init__(return_trace=return_trace) self._model_series = model_series self._embed_url = embed_url self._api_key = api_key self._embed_model_name = embed_model_name + self._instruct = instruct + self._instruct_format = instruct_format self._set_headers() @property @@ -33,7 +38,22 @@ def _set_headers(self) -> Dict[str, str]: "Authorization": f"Bearer {self._api_key}" } + def _set_instruct_for_input(self, input: Union[List, str]) -> Union[List, str]: + if isinstance(input, str): + return self._get_detailed_instruction(input) + elif isinstance(input, list): + return [self._get_detailed_instruction(i) for i in input] + else: + raise ValueError(f'Invalid input type: {type(input)}') + + def _get_detailed_instruction(self, text: str) -> str: + return self._instruct_format.format(instruct=self._instruct, text=text) + def forward(self, input: Union[List, str], **kwargs) -> List[float]: + is_query = kwargs.pop('is_query', False) + if is_query and self._instruct: + input = self._set_instruct_for_input(input) + data = self._encapsulated_data(input, **kwargs) proxies = {'http': None, 'https': None} if self.NO_PROXY else None with requests.post(self._embed_url, json=data, headers=self._headers, proxies=proxies) as r: diff --git a/lazyllm/tools/rag/store/document_store.py b/lazyllm/tools/rag/store/document_store.py index b8b55225c..86a45401d 100644 --- a/lazyllm/tools/rag/store/document_store.py +++ b/lazyllm/tools/rag/store/document_store.py @@ -275,7 +275,7 @@ def query(self, query: str, group_name: str, similarity_name: Optional[str] = No ' are not supported when no vector store is provided') # vector search for embed_key in embed_keys: - query_embedding = self._embed.get(embed_key)(query) + query_embedding = self._embed.get(embed_key)(query, is_query=True) search_res = self.impl.search(collection_name=self._gen_collection_name(group_name), query=query, query_embedding=query_embedding, topk=topk, filters=filters, embed_key=embed_key, **kwargs)