diff --git a/xinference/model/rerank/core.py b/xinference/model/rerank/core.py
index 1743757cbd..178f33025c 100644
--- a/xinference/model/rerank/core.py
+++ b/xinference/model/rerank/core.py
@@ -177,6 +177,7 @@ def _auto_detect_type(model_path):
             "LlamaTokenizerFast": "LLM-based layerwise",
             "GemmaTokenizerFast": "LLM-based",
             "XLMRobertaTokenizerFast": "normal",
+            "CLIPTokenizerFast": "LLM-based multimodal",
         }
         tokenizer = RerankModel._get_tokenizer(model_path)
@@ -229,12 +230,40 @@ def load(self):
             )
             if self._use_fp16:
                 self._model.model.half()
+        elif self._model_spec.type == "LLM-based multimodal":
+            # Loading logic for multimodal reranker models
+            try:
+                from transformers import AutoModel
+
+                attn_implementation = (
+                    "flash_attention_2" if flash_attn_installed else None
+                )
+                self._model = AutoModel.from_pretrained(
+                    self._model_path,
+                    torch_dtype="auto" if self._use_fp16 else None,
+                    trust_remote_code=True,
+                    attn_implementation=attn_implementation,
+                )
+
+                if self._device:
+                    self._model.to(self._device)
+                self._model.eval()
+                return
+
+            except ImportError:
+                error_message = "Failed to import module 'transformers'"
+                installation_guide = [
+                    "Please make sure 'transformers>=4.47.3' is installed. ",
+                    "You can install it by `pip install 'transformers>=4.47.3'`\n",
+                ]
+                raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
         else:
             try:
                 if self._model_spec.type == "LLM-based":
                     from FlagEmbedding import FlagLLMReranker as FlagReranker
                 elif self._model_spec.type == "LLM-based layerwise":
                     from FlagEmbedding import LayerWiseFlagLLMReranker as FlagReranker
+
                 else:
                     raise RuntimeError(
                         f"Unsupported Rank model type: {self._model_spec.type}"
                     )
@@ -265,6 +294,7 @@ def rerank(
         if max_chunks_per_doc is not None:
             raise ValueError("rerank hasn't support `max_chunks_per_doc` parameter.")
         logger.info("Rerank with kwargs: %s, model: %s", kwargs, self._model)
+
         sentence_combinations = [[query, doc] for doc in documents]
         # reset n tokens
         self._model.model.n_tokens = 0
@@ -277,6 +307,37 @@ def rerank(
             ).cpu()
             if similarity_scores.dtype == torch.bfloat16:
                 similarity_scores = similarity_scores.float()
+        elif self._model_spec.type == "LLM-based multimodal":
+            # Get the document type, defaulting to "text"
+            doc_type = kwargs.pop("doc_type", "text")
+
+            # Check whether the model supports this document type
+            if (
+                hasattr(self._model_spec, "supported_doc_types")
+                and doc_type not in self._model_spec.supported_doc_types
+            ):
+                raise ValueError(
+                    f"Model {self._model_spec.model_name} does not support document type: {doc_type}"
+                )
+
+            # Scoring logic for the multimodal model
+            max_length = kwargs.pop("max_length", 1024)
+            similarity_scores = self._model.compute_score(
+                sentence_combinations,
+                max_length=max_length,
+                doc_type=doc_type,
+                **kwargs,
+            )
+
+            if not isinstance(similarity_scores, Sequence):
+                similarity_scores = [similarity_scores]
+            elif (
+                isinstance(similarity_scores, list)
+                and len(similarity_scores) > 0
+                and isinstance(similarity_scores[0], Sequence)
+            ):
+                similarity_scores = similarity_scores[0]
         else:
             # Related issue: https://github.com/xorbitsai/inference/issues/1775
             similarity_scores = self._model.compute_score(
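Note on the core.py changes above: the new "LLM-based multimodal" branch loads the model through transformers with trust_remote_code and then delegates scoring to the model's own compute_score(). Below is a minimal standalone sketch of that call pattern outside xinference; the model id and the doc_type/max_length arguments follow the jina-reranker-m0 model card, while the device choice and example texts are placeholders.

from transformers import AutoModel

# Load jina-reranker-m0 the same way the new load() branch does
# (trust_remote_code pulls in the model's compute_score implementation).
model = AutoModel.from_pretrained(
    "jinaai/jina-reranker-m0",
    torch_dtype="auto",
    trust_remote_code=True,
)
model.to("cuda")  # or "cpu"
model.eval()

query = "multimodal document reranking"
documents = [
    "jina-reranker-m0 scores query-document pairs across modalities.",
    "An unrelated paragraph about gardening.",
]
pairs = [[query, doc] for doc in documents]  # mirrors sentence_combinations in rerank()
scores = model.compute_score(pairs, max_length=1024, doc_type="text")
print(scores)  # one relevance score per pair
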
diff --git a/xinference/model/rerank/model_spec.json b/xinference/model/rerank/model_spec.json
index 7356198916..37c37d37da 100644
--- a/xinference/model/rerank/model_spec.json
+++ b/xinference/model/rerank/model_spec.json
@@ -62,5 +62,13 @@
     "max_tokens": 1024,
     "model_id": "openbmb/MiniCPM-Reranker",
     "model_revision": "5d2fd7345b6444c89d4c0fa59c92272888f3f2d0"
+  },
+  {
+    "model_name": "jina-reranker-m0",
+    "type": "LLM-based multimodal",
+    "language": ["en", "zh"],
+    "max_tokens": 10240,
+    "model_id": "jinaai/jina-reranker-m0",
+    "model_revision": "main"
   }
 ]
diff --git a/xinference/model/rerank/model_spec_modelscope.json b/xinference/model/rerank/model_spec_modelscope.json
index 2da7f099db..7a31369eb9 100644
--- a/xinference/model/rerank/model_spec_modelscope.json
+++ b/xinference/model/rerank/model_spec_modelscope.json
@@ -57,5 +57,14 @@
     "max_tokens": 1024,
     "model_id": "OpenBMB/MiniCPM-Reranker",
     "model_hub": "modelscope"
+  },
+  {
+    "model_name": "jina-reranker-m0",
+    "type": "LLM-based multimodal",
+    "language": ["en", "zh"],
+    "max_tokens": 10240,
+    "model_id": "jinaai/jina-reranker-m0",
+    "model_revision": "master",
+    "model_hub": "modelscope"
   }
 ]
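With the jina-reranker-m0 entries registered in both spec files, the model can be launched like any other rerank model. The sketch below is a hedged end-to-end check against a running xinference server: launch_model, get_model, and rerank are existing client calls, but whether extra keyword arguments such as doc_type are forwarded from the REST client through to the rerank() changes in this patch is an assumption to verify, and the endpoint URL and texts are placeholders.

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local endpoint

# Launch the newly registered multimodal reranker from model_spec.json
uid = client.launch_model(model_name="jina-reranker-m0", model_type="rerank")
model = client.get_model(uid)

result = model.rerank(
    documents=[
        "jina-reranker-m0 reranks text and image documents.",
        "A paragraph about gardening.",
    ],
    query="multimodal reranking",
    top_n=2,
    doc_type="text",  # assumed to reach the new "LLM-based multimodal" branch via kwargs
)
print(result["results"])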