Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,17 @@ def __init__(self,
embed_url: str,
api_key: str,
embed_model_name: str,
return_trace: bool = False):
return_trace: bool = False,
instruct: str = None,
instruct_format: str = 'Instruct: {instruct}\nQuery: {text}'
):
super().__init__(return_trace=return_trace)
self._model_series = model_series
self._embed_url = embed_url
self._api_key = api_key
self._embed_model_name = embed_model_name
self._instruct = instruct
self._instruct_format = instruct_format
self._set_headers()

@property
Expand All @@ -33,7 +38,22 @@ def _set_headers(self) -> Dict[str, str]:
"Authorization": f"Bearer {self._api_key}"
}

def _set_instruct_for_input(self, input: Union[List, str]) -> Union[List, str]:
if isinstance(input, str):
return self._get_detailed_instruction(input)
elif isinstance(input, list):
return [self._get_detailed_instruction(i) for i in input]
else:
raise ValueError(f'Invalid input type: {type(input)}')

def _get_detailed_instruction(self, text: str) -> str:
return self._instruct_format.format(instruct=self._instruct, text=text)

def forward(self, input: Union[List, str], **kwargs) -> List[float]:
is_query = kwargs.pop('is_query', False)
if is_query and self._instruct:
input = self._set_instruct_for_input(input)

data = self._encapsulated_data(input, **kwargs)
proxies = {'http': None, 'https': None} if self.NO_PROXY else None
with requests.post(self._embed_url, json=data, headers=self._headers, proxies=proxies) as r:
Expand Down
2 changes: 1 addition & 1 deletion lazyllm/tools/rag/store/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def query(self, query: str, group_name: str, similarity_name: Optional[str] = No
' are not supported when no vector store is provided')
# vector search
for embed_key in embed_keys:
query_embedding = self._embed.get(embed_key)(query)
query_embedding = self._embed.get(embed_key)(query, is_query=True)
search_res = self.impl.search(collection_name=self._gen_collection_name(group_name),
query=query, query_embedding=query_embedding,
topk=topk, filters=filters, embed_key=embed_key, **kwargs)
Expand Down
Loading