diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index b344b503dc..454f40780b 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -492,6 +492,16 @@ def serve(self, logging_conf: Optional[dict] = None): else None ), ) + self._router.add_api_route( + "/v1/model_registrations/{model_type}/{model_hub}/{model_format}/{user}/{repo}", + self.get_model_info_from_hub, + methods=["GET"], + dependencies=( + [Security(self._auth_service, scopes=["models:register"])] + if self.is_authenticated() + else None + ), + ) # Clear the global Registry for the MetricsMiddleware, or # the MetricsMiddleware will register duplicated metrics if the port @@ -1507,6 +1517,32 @@ async def get_cluster_version(self) -> JSONResponse: logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) + async def get_model_info_from_hub( + self, model_type: str, model_hub: str, model_format: str, user: str, repo: str + ) -> JSONResponse: + try: + if model_type == "LLM": + llm_family_info = await ( + await self._get_supervisor_ref() + ).get_llm_family_from_hub(f"{user}/{repo}", model_format, model_hub) + return JSONResponse(content=llm_family_info) + if model_type == "embedding": + embed_spec = await ( + await self._get_supervisor_ref() + ).get_embedding_spec_from_hub(f"{user}/{repo}", model_hub) + return JSONResponse(content=embed_spec) + except ValueError as re: + logger.error(re, exc_info=True) + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + raise HTTPException( + status_code=400, + detail="only LLM and embedding model type supported currently", + ) + def run( supervisor_address: str, diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 62e2bacce2..e7d3bbc5ac 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -14,6 +14,7 @@ import asyncio import itertools +import json import time import typing from dataclasses import dataclass @@ -21,6 +22,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union import xoscar as xo +from typing_extensions import Literal, cast from ..constants import ( XINFERENCE_DISABLE_HEALTH_CHECK, @@ -30,11 +32,22 @@ ) from ..core import ModelActor from ..core.status_guard import InstanceInfo, LaunchStatus +from ..model.embedding import CustomEmbeddingModelSpec +from ..model.embedding.utils import get_language_from_model_id +from ..model.llm import GgmlLLMSpecV1 +from ..model.llm.llm_family import ( + DEFAULT_CONTEXT_LENGTH, + HubImportLLMFamilyV1, + PytorchLLMSpecV1, +) +from ..model.llm.utils import MODEL_HUB, ModelHubUtil from .metrics import record_metrics from .resource import GPUStatus, ResourceStatus from .utils import ( build_replica_model_uid, gen_random_string, + get_llama_cpp_quantization_info, + get_model_size_from_model_id, is_valid_model_uid, iter_replica_model_uid, log_async, @@ -54,7 +67,6 @@ logger = getLogger(__name__) - ASYNC_LAUNCH_TASKS = {} # type: ignore @@ -79,6 +91,7 @@ class ReplicaInfo: class SupervisorActor(xo.StatelessActor): def __init__(self): super().__init__() + self._model_hub_util = ModelHubUtil() self._worker_address_to_worker: Dict[str, xo.ActorRefType["WorkerActor"]] = {} self._worker_status: Dict[str, WorkerStatus] = {} self._replica_model_uid_to_worker: Dict[ @@ -665,8 +678,8 @@ async def launch_speculative_llm( model_uid = self._gen_model_uid(model_name) logger.debug( ( - f"Enter 
launch_speculative_llm, model_uid: %s, model_name: %s, model_size: %s, " - f"draft_model_name: %s, draft_model_size: %s" + "Enter launch_speculative_llm, model_uid: %s, model_name: %s, model_size: %s, " + "draft_model_name: %s, draft_model_size: %s" ), model_uid, model_name, @@ -1005,3 +1018,117 @@ async def report_worker_status( @staticmethod def record_metrics(name, op, kwargs): record_metrics(name, op, kwargs) + + @log_async(logger=logger) + async def get_llm_family_from_hub( + self, + model_id: str, + model_format: str, + model_hub: str, + ) -> HubImportLLMFamilyV1: + if model_hub not in ["huggingface", "modelscope"]: + raise ValueError(f"Unsupported model hub: {model_hub}") + + model_hub = cast(MODEL_HUB, model_hub) + + context_length = DEFAULT_CONTEXT_LENGTH + + repo_exists = await self._model_hub_util.a_repo_exists( + model_id, + model_hub, + ) + if not repo_exists: + raise ValueError(f"Model {model_id} does not exist") + + if config_path := await self._model_hub_util.a_get_config_path( + model_id, model_hub + ): + with open(config_path) as f: + config = json.load(f) + if "max_position_embeddings" in config: + context_length = config["max_position_embeddings"] + + if model_format in ["ggmlv3", "ggufv2"]: + filenames = await self._model_hub_util.a_list_repo_files( + model_id, model_hub + ) + + ( + model_file_name_template, + model_file_name_split_template, + quantizations, + quantization_parts, + ) = get_llama_cpp_quantization_info( + filenames, typing.cast(Literal["ggmlv3", "ggufv2"], model_format) + ) + + llm_spec = GgmlLLMSpecV1( + model_id=model_id, + model_format=model_format, + model_hub=model_hub, + quantizations=quantizations, + quantization_parts=quantization_parts, + model_size_in_billions=get_model_size_from_model_id(model_id), + model_file_name_template=model_file_name_template, + model_file_name_split_template=model_file_name_split_template, + ) + + return HubImportLLMFamilyV1( + version=1, context_length=context_length, model_specs=[llm_spec] + ) + elif model_format in ["pytorch", "awq"]: + llm_spec = PytorchLLMSpecV1( + model_id=model_id, + model_format=model_format, + model_hub=model_hub, + model_size_in_billions=get_model_size_from_model_id(model_id), + quantizations=( + ["4-bit", "8-bit", "none"] + if model_format == "pytorch" + else ["Int4"] + ), + ) + return HubImportLLMFamilyV1( + version=1, context_length=context_length, model_specs=[llm_spec] + ) + elif model_format == "gptq": + raise NotImplementedError("gptq is not implemented yet") + else: + raise ValueError(f"Unsupported model format: {model_format}") + + @log_async(logger=logger) + async def get_embedding_spec_from_hub( + self, model_id: str, model_hub: str + ) -> CustomEmbeddingModelSpec: + if model_hub not in ["huggingface", "modelscope"]: + raise ValueError(f"Unsupported model hub: {model_hub}") + + model_hub = cast(MODEL_HUB, model_hub) + + repo_exists = await self._model_hub_util.a_repo_exists( + model_id, + model_hub, + ) + + if not repo_exists: + raise ValueError(f"Model {model_id} does not exist") + + max_tokens = 512 + dimensions = 768 + if config_path := await self._model_hub_util.a_get_config_path( + model_id, model_hub + ): + with open(config_path) as f: + config = json.load(f) + if "max_position_embeddings" in config: + max_tokens = config["max_position_embeddings"] + if "hidden_size" in config: + dimensions = config["hidden_size"] + return CustomEmbeddingModelSpec( + model_name=model_id.split("/")[-1], + model_id=model_id, + max_tokens=max_tokens, + dimensions=dimensions, + 
model_hub=model_hub, + language=[get_language_from_model_id(model_id)], + ) diff --git a/xinference/core/tests/test_supervisor.py b/xinference/core/tests/test_supervisor.py new file mode 100644 index 0000000000..09001fe07a --- /dev/null +++ b/xinference/core/tests/test_supervisor.py @@ -0,0 +1,240 @@ +import pytest + +from ..supervisor import SupervisorActor + + +@pytest.mark.asyncio +async def test_get_llm_spec_hf(): + supervisor = SupervisorActor() + llm_family = await supervisor.get_llm_family_from_hub( + "TheBloke/Llama-2-7B-Chat-GGML", "ggmlv3", "huggingface" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + assert llm_family.model_specs[0].model_id == "TheBloke/Llama-2-7B-Chat-GGML" + assert llm_family.model_specs[0].model_size_in_billions == 7 + assert llm_family.model_specs[0].model_hub == "huggingface" + assert len(llm_family.model_specs[0].quantizations) == 14 + assert ( + llm_family.model_specs[0].model_file_name_template + == "llama-2-7b-chat.ggmlv3.{quantization}.bin" + ) + assert llm_family.model_specs[0].model_file_name_split_template is None + assert llm_family.model_specs[0].quantization_parts is None + + assert { + "q2_K", + "q3_K_L", + "q3_K_M", + "q3_K_S", + "q4_0", + "q4_1", + "q4_K_M", + "q4_K_S", + "q5_0", + "q5_1", + "q5_K_M", + "q5_K_S", + "q6_K", + "q8_0", + }.intersection(set(llm_family.model_specs[0].quantizations)) == set( + llm_family.model_specs[0].quantizations + ) + + llm_family = await supervisor.get_llm_family_from_hub( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "ggufv2", "huggingface" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + assert llm_family.model_specs[0].model_id == "TheBloke/KafkaLM-70B-German-V0.1-GGUF" + assert llm_family.model_specs[0].model_size_in_billions == 70 + assert llm_family.model_specs[0].model_hub == "huggingface" + qs = llm_family.model_specs[0].quantizations + assert len(qs) == 12 + assert ( + llm_family.model_specs[0].model_file_name_template + == "kafkalm-70b-german-v0.1.{quantization}.gguf" + ) + assert ( + llm_family.model_specs[0].model_file_name_split_template + == "kafkalm-70b-german-v0.1.{quantization}.gguf-split-{part}" + ) + parts = llm_family.model_specs[0].quantization_parts + assert parts is not None + assert len(parts) == 2 + assert len(parts["Q8_0"]) == 2 + + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0", + }.intersection(set(qs)) == set(qs) + + with pytest.raises(ValueError, match="Model Nobody/No_This_Repo does not exist"): + await supervisor.get_llm_family_from_hub( + "Nobody/No_This_Repo", "ggufv2", "huggingface" + ) + + +@pytest.mark.asyncio +async def test_get_llm_spec_ms(): + supervisor = SupervisorActor() + llm_family = await supervisor.get_llm_family_from_hub( + "Xorbits/Llama-2-7B-Chat-GGML", "ggmlv3", "modelscope" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + assert llm_family.model_specs[0].model_id == "Xorbits/Llama-2-7B-Chat-GGML" + assert llm_family.model_specs[0].model_size_in_billions == 7 + assert llm_family.model_specs[0].model_hub == "modelscope" + assert len(llm_family.model_specs[0].quantizations) == 14 + assert ( + llm_family.model_specs[0].model_file_name_template + == "llama-2-7b-chat.ggmlv3.{quantization}.bin" + ) + assert llm_family.model_specs[0].model_file_name_split_template is None + assert llm_family.model_specs[0].quantization_parts is None + + assert { + "q2_K", + "q3_K_L", + "q3_K_M", + 
"q3_K_S", + "q4_0", + "q4_1", + "q4_K_M", + "q4_K_S", + "q5_0", + "q5_1", + "q5_K_M", + "q5_K_S", + "q6_K", + "q8_0", + }.intersection(set(llm_family.model_specs[0].quantizations)) == set( + llm_family.model_specs[0].quantizations + ) + + llm_family = await supervisor.get_llm_family_from_hub( + "qwen/Qwen1.5-72B-Chat-GGUF", "ggufv2", "modelscope" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + assert llm_family.model_specs[0].model_id == "qwen/Qwen1.5-72B-Chat-GGUF" + assert llm_family.model_specs[0].model_size_in_billions == 72 + assert llm_family.model_specs[0].model_hub == "modelscope" + qs = llm_family.model_specs[0].quantizations + assert len(qs) == 8 + assert ( + llm_family.model_specs[0].model_file_name_template + == "qwen1_5-72b-chat-{quantization}.gguf" + ) + assert ( + llm_family.model_specs[0].model_file_name_split_template + == "qwen1_5-72b-chat-{quantization}.gguf.{part}" + ) + parts = llm_family.model_specs[0].quantization_parts + assert parts is not None + assert len(parts) == 6 + assert len(parts["q8_0"]) == 3 + + assert { + "q2_k", + "q3_k_m", + "q4_0", + "q4_k_m", + "q5_0", + "q5_k_m", + "q6_k", + "q8_0", + }.intersection(set(qs)) == set(qs) + + with pytest.raises(ValueError, match="Model Nobody/No_This_Repo does not exist"): + await supervisor.get_llm_family_from_hub( + "Nobody/No_This_Repo", "ggufv2", "modelscope" + ) + + +@pytest.mark.asyncio +async def test_get_llm_spec_2(): + supervisor = SupervisorActor() + llm_family = await supervisor.get_llm_family_from_hub( + "Qwen/Qwen1.5-1.8B", "pytorch", "huggingface" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + pytorch_qs = {"4-bit", "8-bit", "none"} + assert ( + pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs + ) + + assert llm_family.model_specs[0].model_size_in_billions == "1_8" + assert llm_family.context_length == 32768 + + llm_family = await supervisor.get_llm_family_from_hub( + "qwen/Qwen-14B-Chat", "pytorch", "modelscope" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + pytorch_qs = {"4-bit", "8-bit", "none"} + assert ( + pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs + ) + + assert llm_family.model_specs[0].model_size_in_billions == 14 + assert llm_family.context_length == 8192 + + llm_family = await supervisor.get_llm_family_from_hub( + "qwen/Qwen1.5-7B-Chat-AWQ", "awq", "modelscope" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + pytorch_qs = {"Int4"} + assert ( + pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs + ) + + assert llm_family.model_specs[0].model_size_in_billions == 7 + assert llm_family.context_length == 32768 + + llm_family = await supervisor.get_llm_family_from_hub( + "casperhansen/mixtral-instruct-awq", "awq", "huggingface" + ) + assert llm_family is not None + assert len(llm_family.model_specs) == 1 + pytorch_qs = {"Int4"} + assert ( + pytorch_qs.intersection(llm_family.model_specs[0].quantizations) == pytorch_qs + ) + + assert llm_family.model_specs[0].model_size_in_billions == 0 + assert llm_family.context_length == 32768 + + +@pytest.mark.asyncio +async def test_get_embedding_spec_from_hub(): + supervisor = SupervisorActor() + embedding_spec = await supervisor.get_embedding_spec_from_hub( + "BAAI/bge-large-zh-v1.5", "huggingface" + ) + assert embedding_spec is not None + assert embedding_spec.model_name == "bge-large-zh-v1.5" + assert embedding_spec.model_id == 
"BAAI/bge-large-zh-v1.5" + + embedding_spec = await supervisor.get_embedding_spec_from_hub( + "bensonpeng/bge-large-en-v1.5", "modelscope" + ) + + assert embedding_spec is not None + assert embedding_spec.model_name == "bge-large-en-v1.5" + assert embedding_spec.model_id == "bensonpeng/bge-large-en-v1.5" + assert embedding_spec.max_tokens == 512 + assert embedding_spec.dimensions == 1024 diff --git a/xinference/core/tests/test_utils.py b/xinference/core/tests/test_utils.py index ce94e50a35..c225aa12f5 100644 --- a/xinference/core/tests/test_utils.py +++ b/xinference/core/tests/test_utils.py @@ -11,9 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest from ..utils import ( + SUPPORTED_QUANTIZATIONS, build_replica_model_uid, + get_llama_cpp_quantization_info, + get_match_quantization_filenames, + get_model_size_from_model_id, + get_prefix_suffix, iter_replica_model_uid, parse_replica_model_uid, ) @@ -29,3 +35,304 @@ def test_replica_model_uid(): all_gen_ids.append(replica_model_uid) assert len(all_gen_ids) == 5 assert len(set(all_gen_ids)) == 5 + + +def test_get_model_size_from_model_id(): + model_id = "froggeric/WestLake-10.7B-v2-GGUF" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "10_7" + + model_id = "m-a-p/OpenCodeInterpreter-DS-33B" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 33 + + model_id = "MBZUAI/MobiLlama-05B" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "0_5" + + model_id = "ibivibiv/alpaca-dragon-72b-v1" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 72 + + model_id = "ISTA-DASLab/Mixtral-8x7B-Instruct-v0_1-AQLM-2Bit-1x16-hf" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 7 + + model_id = "internlm/internlm-xcomposer2-vl-7b-4bit" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 7 + + model_id = "ahxt/LiteLlama-460M-1T" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "0_46" + + model_id = "Dracones/Midnight-Miqu-70B-v1.0_exl2_2.24bpw" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 70 + + model_id = "MaziyarPanahi/MixTAO-7Bx2-MoE-v8.1-GGUF" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 7 + + model_id = "ISTA-DASLab/Mixtral-8x7b-AQLM-2Bit-1x16-hf" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 7 + + model_id = "stabilityai/stablelm-2-zephyr-1_6b" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "1_6" + + model_id = "Qwen/Qwen1.5-Chat-4bit-GPTQ-72B" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 72 + + model_id = "m-a-p/OpenCodeInterpreter-3Bee-DS-33B" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 33 + + model_id = "qwen/Qwen1.5-0.5B-Chat" + model_size = get_model_size_from_model_id(model_id) + assert model_size == "0_5" + + model_id = "mlx-community/c4ai-command-r-v01-4bit" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 0 + + model_id = "lemonilia/ShoriRP-v0.75d" + model_size = get_model_size_from_model_id(model_id) + assert model_size == 0 + + model_id = "abc" + with pytest.raises(ValueError, match=r"Cannot parse model_id: .+"): + get_model_size_from_model_id(model_id) + + +def test_get_match_quantization_filenames(): 
+ filenames = [ + "kafkalm-70b-german-v0.1.Q2_K.gguf", + "kafkalm-70b-german-v0.1.Q3_K_L.gguf", + "kafkalm-70b-german-v0.1.Q3_K_M.gguf", + "kafkalm-70b-german-v0.1.Q3_K_S.gguf", + "kafkalm-70b-german-v0.1.Q4_0.gguf", + "kafkalm-70b-german-v0.1.Q4_K_M.gguf", + "kafkalm-70b-german-v0.1.Q4_K_S.gguf", + "kafkalm-70b-german-v0.1.Q5_K_M.gguf", + "kafkalm-70b-german-v0.1.Q5_K_S.gguf", + "kafkalm-70b-german-v0.1.Q6_K.gguf-split-a", + "kafkalm-70b-german-v0.1.Q6_K.gguf-split-b", + "kafkalm-70b-german-v0.1.Q8_0.gguf-split-a", + "kafkalm-70b-german-v0.1.Q8_0.gguf-split-b", + ] + + results = get_match_quantization_filenames(filenames) + assert len(results) == 13 + assert all(x[0][: x[2]] == "kafkalm-70b-german-v0.1." for x in results) + assert all(x[1].upper() in SUPPORTED_QUANTIZATIONS for x in results) + assert results[0][0][results[0][2] + len(results[0][1]) :] == ".gguf" + assert results[-1][0][results[-1][2] + len(results[-1][1]) :] == ".gguf-split-b" + + +def test_get_prefix_suffix(): + names = [ + ".gguf-split-a", + ".gguf-split-b", + ".gguf-split-a", + ".gguf-split-b", + ".gguf-split-c", + ] + prefix, suffix = get_prefix_suffix(names) + assert prefix == ".gguf-split-" + assert suffix == "" + + names = ["-part-a.gguf", "-part-b.gguf", "-part-c.gguf", "-part-a.gguf"] + + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-part-" + assert suffix == ".gguf" + + names = ["-part-1.gguf", "-part-2.gguf", "-part-12.gguf", "-part-2.gguf"] + + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-part-" + assert suffix == ".gguf" + + names = [".gguf", "-part-1.gguf", "-part-2.gguf", "-part-12.gguf", "-part-2.gguf"] + prefix, suffix = get_prefix_suffix(names) + assert prefix == "" + assert suffix == ".gguf" + + names = [ + "-test.gguf", + "-test-part-1.gguf", + "-test-part-2.gguf", + "-test-part-12.gguf", + "-test-part-2.gguf", + ] + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-test" + assert suffix == ".gguf" + + names = ["-part-1.gguf", "-part-1.gguf", "-part-1.gguf"] + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-part-1.gguf" + assert suffix == "" + + prefix, suffix = get_prefix_suffix([]) + assert prefix == "" + assert suffix == "" + + names = ["-only-1.gguf"] + prefix, suffix = get_prefix_suffix(names) + assert prefix == "-only-1.gguf" + assert suffix == "" + + +def test_get_llama_cpp_quantization_info(): + filenames = [ + "kafkalm-70b-german-v0.1.Q2_K.gguf", + "kafkalm-70b-german-v0.1.Q3_K_L.gguf", + "kafkalm-70b-german-v0.1.Q3_K_M.gguf", + "kafkalm-70b-german-v0.1.Q3_K_S.gguf", + "kafkalm-70b-german-v0.1.Q4_0.gguf", + "kafkalm-70b-german-v0.1.Q4_K_M.gguf", + "kafkalm-70b-german-v0.1.Q4_K_S.gguf", + "kafkalm-70b-german-v0.1.Q5_K_M.gguf", + "kafkalm-70b-german-v0.1.Q5_K_S.gguf", + "kafkalm-70b-german-v0.1.Q6_K.gguf-split-a", + "kafkalm-70b-german-v0.1.Q6_K.gguf-split-b", + "kafkalm-70b-german-v0.1.Q8_0.gguf-split-a", + "kafkalm-70b-german-v0.1.Q8_0.gguf-split-b", + ] + + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames[:-4], "ggufv2") + assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.gguf" + assert tpl2 is None + assert len(qs) == 9 + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_K_M", + "Q5_K_S", + }.intersection(set(qs)) == set(qs) + assert parts is None + + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames, "ggufv2") + assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.gguf" + assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.gguf-split-{part}" 
+ assert len(qs) == 11 + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0", + }.intersection(set(qs)) == set(qs) + assert len(parts) == 2 + assert len(parts["Q6_K"]) == 2 + assert len(parts["Q8_0"]) == 2 + assert parts["Q6_K"][0] == "a" + assert parts["Q8_0"][1] == "b" + + filenames = [ + "kafkalm-70b-german-v0.1.Q2_K.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_L.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q4_0.test.gguf", + "kafkalm-70b-german-v0.1.Q4_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q4_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q5_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q5_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q6_K.test-split-a.gguf", + "kafkalm-70b-german-v0.1.Q6_K.test-split-b.gguf", + "kafkalm-70b-german-v0.1.Q8_0.test-split-a.gguf", + "kafkalm-70b-german-v0.1.Q8_0.test-split-b.gguf", + ] + + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames, "ggufv2") + assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf" + assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.test-split-{part}.gguf" + assert len(qs) == 11 + assert len(parts) == 2 + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0", + }.intersection(set(qs)) == set(qs) + assert parts["Q8_0"][1] == "b" + + filenames = [ + "kafkalm-70b-german-v0.1.Q2_K.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_L.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q3_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q4_0.test.gguf", + "kafkalm-70b-german-v0.1.Q4_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q4_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q5_K_M.test.gguf", + "kafkalm-70b-german-v0.1.Q5_K_S.test.gguf", + "kafkalm-70b-german-v0.1.Q6_K.test.gguf-part1of2", + "kafkalm-70b-german-v0.1.Q6_K.test.gguf-part2of2", + "kafkalm-70b-german-v0.1.Q8_0.test.gguf-part1of3", + "kafkalm-70b-german-v0.1.Q8_0.test.gguf-part2of3", + "kafkalm-70b-german-v0.1.Q8_0.test.gguf-part3of3", + ] + + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames, "ggufv2") + assert tpl1 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf" + assert tpl2 == "kafkalm-70b-german-v0.1.{quantization}.test.gguf-part{part}" + assert len(qs) == 11 + assert { + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0", + }.intersection(set(qs)) == set(qs) + assert len(parts) == 2 + assert len(parts["Q8_0"]) == 3 + assert parts["Q8_0"][2] == "3of3" + + filenames = [ + "llama-2-7b-chat.ggmlv3.q2_K.bin", + "llama-2-7b-chat.ggmlv3.q3_K_L.bin", + "llama-2-7b-chat.ggmlv3.q3_K_M.bin", + "llama-2-7b-chat.ggmlv3.q3_K_S.bin", + "llama-2-7b-chat.ggmlv3.q4_0.bin", + "llama-2-7b-chat.ggmlv3.q4_K_M.bin", + "llama-2-7b-chat.ggmlv3.q4_K_S.bin", + "llama-2-7b-chat.ggmlv3.q5_K_M.bin", + "llama-2-7b-chat.ggmlv3.q5_K_S.bin", + ] + + tpl1, tpl2, qs, parts = get_llama_cpp_quantization_info(filenames, "ggmlv3") + + assert tpl1 == "llama-2-7b-chat.ggmlv3.{quantization}.bin" + assert tpl2 is None + assert len(qs) == 9 + assert parts is None diff --git a/xinference/core/utils.py b/xinference/core/utils.py index 0a121f4769..e146931d86 100644 --- a/xinference/core/utils.py +++ b/xinference/core/utils.py @@ -15,11 +15,13 @@ import logging import os import random +import re import string -from typing import Dict, Generator, 
List, Tuple, Union +from typing import Dict, Generator, Iterable, List, Optional, Tuple, Union, cast import orjson from pynvml import nvmlDeviceGetCount, nvmlInit, nvmlShutdown +from typing_extensions import Literal from .._compat import BaseModel @@ -191,3 +193,238 @@ def get_nvidia_gpu_info() -> Dict: nvmlShutdown() except: pass + + +def get_model_size_from_model_id(model_id: str) -> Union[str, float, int]: + """ + Get model size from model_id. + + Args: + model_id: model_id in format of `user/repo` + + Returns: + model size in format of `100B`, if size is in M, divide into 1000 and return as B. + For example, `100M` will be returned as `0.1B`. + + If there is no model size in the repo name, return `UNKNOWN`. + """ + + def resize_to_billion(size: str) -> Union[str, int, float]: + if size == "UNKNOWN": + return 0 + + if size.lower().endswith("m"): + return str(round(int(size[:-1]) / 1000, 2)).replace(".", "_") + + size = size[:-1] + if "_" not in size: + if size[0] == "0" and "." not in size: + size = size[0] + "." + str(size[1:]) + + if "." in size: + return size.replace(".", "_") + else: + return int(size) + + return size + + split = model_id.split("/") + if len(split) != 2: + raise ValueError(f"Cannot parse model_id: {model_id}") + user, repo = split + segs = repo.split("-") + param_pattern = re.compile(r"\d+(?:[._]\d+)?[bm]", re.I) + partial_matched = "UNKNOWN" + for seg in segs: + if m := param_pattern.search(seg): + if m.start() == 0 and m.end() == len(seg): + return resize_to_billion(seg) + else: + # only match the first partial matched, and do not match `bit` for quantization mode + if ( + partial_matched == "UNKNOWN" + and seg[m.end(0) : m.end(0) + 2].lower() != "it" + ): + partial_matched = m.group(0) + return resize_to_billion(partial_matched) + + +SUPPORTED_QUANTIZATIONS = [ + "Q3_K_S", + "Q3_K_M", + "Q3_K_L", + "Q4_K_S", + "Q4_K_M", + "Q5_K_S", + "Q5_K_M", + "Q6_K", + "F32", + "F16", + "Q4_0", + "Q4_1", + "Q8_0", + "Q5_0", + "Q5_1", + "Q2_K", +] + + +def get_match_quantization_filenames( + filenames: List[str], +) -> List[Tuple[str, str, int]]: + """ + Get the quantization info from filenames. + + Return: + A list of tuples: (filename, quantization, index of the quantization in filename) + """ + results: List[Tuple[str, str, int]] = [] + for filename in filenames: + for quantization in SUPPORTED_QUANTIZATIONS: + if (index := filename.upper().find(quantization)) != -1: + results.append((filename, quantization, index)) + return results + + +def get_prefix_suffix(names: Iterable[str]) -> Tuple[str, str]: + """ + Get the common prefix and suffix from a list of names. + """ + if len(list(names)) == 0: + return "", "" + + # if all names are the same, or only one name, return the first name as prefix and suffix is empty + if len(set(names)) == 1: + return list(names)[0], "" + + min_len = min(map(len, names)) + name = [n for n in names if len(n) == min_len][0] + + for i in range(min_len): + if len(set(map(lambda x: x[: i + 1], names))) > 1: + prefix = name[:i] + break + else: + prefix = name + + for i in range(min_len): + if len(set(map(lambda x: x[-i - 1 :], names))) > 1: + suffix = name[len(name) - i :] + break + else: + suffix = name + + return prefix, suffix + + +def get_llama_cpp_quantization_info( + filenames: List[str], model_type: Literal["ggmlv3", "ggufv2"] +) -> Tuple[Optional[str], Optional[str], List[str], Optional[Dict[str, List[str]]]]: + """ + Get the model file name template and split template from a list of filenames. 
+
+    NOTE: quantization files split into multi-part zip archives are not supported.
+    For example: a-16b.ggmlv3.zip, a-16b.ggmlv3.z01, a-16b.ggmlv3.z02 are not supported.
+
+    Return:
+        model_file_name_template: the model file name template with the quantization placeholder
+        model_file_name_split_template: the model file name template with quantization and part placeholders
+        quantizations: the list of detected quantizations
+        parts: mapping from each split quantization to its part identifiers
+    """
+    model_file_name_template = None
+    model_file_name_split_template: Optional[str] = None
+    quantizations: List[str] = []
+    parts: Optional[Dict[str, List[str]]] = None
+
+    if model_type == "ggmlv3":
+        filenames = [
+            filename
+            for filename in filenames
+            if filename.lower().endswith(".bin") or "ggml" in filename.lower()
+        ]
+    elif model_type == "ggufv2":
+        filenames = [filename for filename in filenames if ".gguf" in filename.lower()]
+    else:
+        raise ValueError(f"Unsupported model type: {model_type}")
+
+    matched = get_match_quantization_filenames(filenames)
+
+    if len(matched) == 0:
+        raise ValueError("Cannot find any quantization files in this repository")
+
+    prefixes = set()
+    suffixes = set()
+
+    for filename, quantization, index in matched:
+        prefixes.add(filename[:index])
+        suffixes.add(filename[index + len(quantization) :])
+        q = filename[index : index + len(quantization)]
+        if q not in quantizations:
+            quantizations.append(q)
+
+    if len(prefixes) == 1 and len(suffixes) == 1:
+        model_file_name_template = prefixes.pop() + "{quantization}" + suffixes.pop()
+        return (
+            model_file_name_template,
+            model_file_name_split_template,
+            quantizations,
+            parts,
+        )
+
+    if len(prefixes) == 1 and len(suffixes) > 1:
+        parts = {}
+        shortest_suffix = min(suffixes, key=len)
+        part_prefix, part_suffix = get_prefix_suffix(suffixes)
+        if shortest_suffix == part_prefix + part_suffix:
+            model_file_name_template = (
+                list(prefixes)[0] + "{quantization}" + shortest_suffix
+            )
+            part_prefix, part_suffix = get_prefix_suffix(
+                [suffix for suffix in suffixes if suffix != shortest_suffix]
+            )
+
+        model_file_name_split_template = (
+            prefixes.pop() + "{quantization}" + part_prefix + "{part}" + part_suffix
+        )
+
+    elif len(prefixes) > 1 and len(suffixes) == 1:
+        parts = {}
+        shortest_prefix = min(prefixes, key=len)
+        part_prefix, part_suffix = get_prefix_suffix(prefixes)
+        if shortest_prefix == part_prefix + part_suffix:
+            model_file_name_template = (
+                shortest_prefix + "{quantization}" + list(suffixes)[0]
+            )
+            part_prefix, part_suffix = get_prefix_suffix(
+                [prefix for prefix in prefixes if prefix != shortest_prefix]
+            )
+
+        model_file_name_split_template = (
+            part_prefix + "{part}" + part_suffix + "{quantization}" + suffixes.pop()
+        )
+    else:
+        logger.info("Cannot find a valid template for model file names")
+
+    if model_file_name_split_template is not None:
+        part_pattern_str = model_file_name_split_template.replace(
+            "{part}", r"(?P<part>\w+)"
+        )
+        quan_pattern_str = "(?P<quantization>" + f"{'|'.join(quantizations)})"
+        part_pattern_str = part_pattern_str.replace("{quantization}", quan_pattern_str)
+
+        part_pattern = re.compile(part_pattern_str)
+        for filename in filenames:
+            if m := part_pattern.match(filename):
+                matched_quan = m.group("quantization")
+                parts = cast(Dict[str, List[str]], parts)
+                if matched_quan not in parts:
+                    parts[matched_quan] = []
+                parts[matched_quan].append(m.group("part"))
+
+    return (
+        model_file_name_template,
+        model_file_name_split_template,
+        quantizations,
+        parts,
+    )
diff --git a/xinference/model/embedding/custom.py b/xinference/model/embedding/custom.py
index 8e311bbd7d..83ce51bfdb
100644
--- a/xinference/model/embedding/custom.py
+++ b/xinference/model/embedding/custom.py
@@ -63,7 +63,8 @@ def register_embedding(model_spec: CustomEmbeddingModelSpec, persist: bool):
     if persist:
         # We only validate model URL when persist is True.
         model_uri = model_spec.model_uri
-        if model_uri and not is_valid_model_uri(model_uri):
+        model_id = model_spec.model_id
+        if model_id is None and model_uri and not is_valid_model_uri(model_uri):
             raise ValueError(f"Invalid model URI {model_uri}.")
 
     persist_path = os.path.join(
diff --git a/xinference/model/embedding/tests/test_utils.py b/xinference/model/embedding/tests/test_utils.py
new file mode 100644
index 0000000000..c5d2bead19
--- /dev/null
+++ b/xinference/model/embedding/tests/test_utils.py
@@ -0,0 +1,30 @@
+from ..utils import get_language_from_model_id
+
+
+def test_get_language_from_model_id():
+    model_id = "BAAI/bge-large-zh-v1.5"
+    assert get_language_from_model_id(model_id) == "zh"
+
+    model_id = "BAAI/bge-large-base-v1.5"
+    assert get_language_from_model_id(model_id) == "en"
+
+    model_id = "google-bert/bert-base-multilingual-cased"
+    assert get_language_from_model_id(model_id) == "zh"
+
+    model_id = "jinaai/jina-embeddings-v2-base-en"
+    assert get_language_from_model_id(model_id) == "en"
+
+    model_id = "jinaai/jina-embeddings-v2-base-es"
+    # only zh and en are supported for now; anything that is not Chinese falls back to en, even when the id specifies es
+    assert get_language_from_model_id(model_id) == "en"
+
+    model_id = "bge-large-zh-v1.5"
+    # a malformed model id falls back to en
+    assert get_language_from_model_id(model_id) == "en"
+
+    model_id = "BAAI/newtype/bge-large-zh-v1.5"
+    # a model id with an unexpected format also falls back to en
+    assert get_language_from_model_id(model_id) == "en"
+
+    model_id = ""
+    assert get_language_from_model_id(model_id) == "en"
diff --git a/xinference/model/embedding/utils.py b/xinference/model/embedding/utils.py
index 8b63e6eb5f..547ec33171 100644
--- a/xinference/model/embedding/utils.py
+++ b/xinference/model/embedding/utils.py
@@ -11,8 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from logging import getLogger + from .core import EmbeddingModelSpec def get_model_version(embedding_model: EmbeddingModelSpec) -> str: return f"{embedding_model.model_name}--{embedding_model.max_tokens}--{embedding_model.dimensions}" + + +def get_language_from_model_id(model_id: str) -> str: + split = model_id.split("/") + if len(split) != 2: + logger = getLogger(__name__) + logger.error(f"Invalid model_id: {model_id}, return the default en language") + return "en" + model_id = split[-1] + segments = model_id.split("-") + for seg in segments: + if seg.lower() in ["zh", "cn", "chinese", "multilingual"]: + return "zh" + return "en" diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index 15ff0db84c..da34566ed7 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -122,6 +122,13 @@ class LLMFamilyV1(BaseModel): prompt_style: Optional["PromptStyleV1"] +class HubImportLLMFamilyV1(BaseModel): + version: Literal[1] + context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH + model_specs: List["LLMSpecV1"] = [] + prompt_style: Optional["PromptStyleV1"] + + class CustomLLMFamilyV1(LLMFamilyV1): prompt_style: Optional[Union["PromptStyleV1", str]] # type: ignore @@ -208,6 +215,7 @@ def parse_raw( ] LLMFamilyV1.update_forward_refs() +HubImportLLMFamilyV1.update_forward_refs() CustomLLMFamilyV1.update_forward_refs() @@ -534,7 +542,10 @@ def _generate_model_file_names( ) need_merge = False - if llm_spec.quantization_parts is None: + if ( + llm_spec.quantization_parts is None + or quantization not in llm_spec.quantization_parts + ): file_names.append(final_file_name) elif quantization is not None and quantization in llm_spec.quantization_parts: parts = llm_spec.quantization_parts[quantization] diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py index d5e40d7561..d8daef217a 100644 --- a/xinference/model/llm/tests/test_utils.py +++ b/xinference/model/llm/tests/test_utils.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os + +import pytest from ....types import ChatCompletionMessage from ..llm_family import PromptStyleV1 -from ..utils import ChatModelMixin +from ..utils import ChatModelMixin, ModelHubUtil def test_prompt_style_add_colon_single(): @@ -421,3 +424,133 @@ def test_is_valid_model_name(): assert not is_valid_model_name("foo/bar") assert not is_valid_model_name(" ") assert not is_valid_model_name("") + + +@pytest.fixture +def model_hub_util(): + return ModelHubUtil() + + +def test__hf_api(model_hub_util): + assert model_hub_util._hf_api is not None + + +def test__ms_api(model_hub_util): + assert model_hub_util._ms_api is not None + + +def test_repo_exists(model_hub_util): + assert model_hub_util.repo_exists( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert not model_hub_util.repo_exists("Nobody/No_This_Repo", "huggingface") + with pytest.raises(ValueError, match="Unsupported model hub"): + model_hub_util.repo_exists("Nobody/No_This_Repo", "unknown_hub") + + assert model_hub_util.repo_exists("qwen/Qwen1.5-72B-Chat-GGUF", "modelscope") + assert not model_hub_util.repo_exists("Nobody/No_This_Repo", "modelscope") + with pytest.raises(ValueError, match="Unsupported model hub"): + model_hub_util.repo_exists("Nobody/No_This_Repo", "unknown_hub") + + +@pytest.mark.asyncio +async def test_a_repo_exists(model_hub_util): + assert await model_hub_util.a_repo_exists( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert not await model_hub_util.a_repo_exists("Nobody/No_This_Repo", "huggingface") + with pytest.raises(ValueError, match="Unsupported model hub"): + model_hub_util.repo_exists("Nobody/No_This_Repo", "unknown_hub") + + assert await model_hub_util.a_repo_exists( + "qwen/Qwen1.5-72B-Chat-GGUF", "modelscope" + ) + assert not await model_hub_util.a_repo_exists("Nobody/No_This_Repo", "modelscope") + with pytest.raises(ValueError, match="Unsupported model hub"): + await model_hub_util.a_repo_exists("Nobody/No_This_Repo", "unknown_hub") + + +def test_get_config_path(model_hub_util): + p = model_hub_util.get_config_path( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert p is not None + assert os.path.isfile(p) + + assert model_hub_util.get_config_path("Nobody/No_This_Repo", "huggingface") is None + + p = model_hub_util.get_config_path("qwen/Qwen1.5-72B-Chat-GGUF", "modelscope") + assert p is None + + p = model_hub_util.get_config_path("deepseek-ai/deepseek-vl-7b-chat", "modelscope") + assert p is not None + assert os.path.isfile(p) + + assert model_hub_util.get_config_path("Nobody/No_This_Repo", "modelscope") is None + + +@pytest.mark.asyncio +async def test_a_get_config_path_async(model_hub_util): + p = await model_hub_util.a_get_config_path( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert p is not None + assert os.path.isfile(p) + + assert ( + await model_hub_util.a_get_config_path("Nobody/No_This_Repo", "huggingface") + is None + ) + + p = await model_hub_util.a_get_config_path( + "qwen/Qwen1.5-72B-Chat-GGUF", "modelscope" + ) + assert p is None + + p = await model_hub_util.a_get_config_path( + "deepseek-ai/deepseek-vl-7b-chat", "modelscope" + ) + assert p is not None + assert os.path.isfile(p) + + assert ( + await model_hub_util.a_get_config_path("Nobody/No_This_Repo", "modelscope") + is None + ) + + +def test_list_repo_files(model_hub_util): + files = model_hub_util.list_repo_files( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert len(files) == 20 + + files = model_hub_util.list_repo_files( + 
"deepseek-ai/deepseek-vl-7b-chat", "modelscope" + ) + assert len(files) == 12 # the `.gitattributes` file is not included + + with pytest.raises(ValueError, match="Repository Nobody/No_This_Repo not found."): + model_hub_util.list_repo_files("Nobody/No_This_Repo", "huggingface") + + with pytest.raises(ValueError, match="Repository Nobody/No_This_Repo not found."): + model_hub_util.list_repo_files("Nobody/No_This_Repo", "modelscope") + + +@pytest.mark.asyncio +async def test_a_list_repo_files(model_hub_util): + files = await model_hub_util.a_list_repo_files( + "TheBloke/KafkaLM-70B-German-V0.1-GGUF", "huggingface" + ) + assert len(files) == 20 + + files = await model_hub_util.a_list_repo_files( + "deepseek-ai/deepseek-vl-7b-chat", "modelscope" + ) + assert len(files) == 12 # the `.gitattributes` file is not included + + with pytest.raises(ValueError, match="Repository Nobody/No_This_Repo not found."): + await model_hub_util.a_list_repo_files("Nobody/No_This_Repo", "huggingface") + + with pytest.raises(ValueError, match="Repository Nobody/No_This_Repo not found."): + await model_hub_util.a_list_repo_files("Nobody/No_This_Repo", "modelscope") diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index f404aba5e5..de2138c273 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -19,6 +19,14 @@ import uuid from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast +from huggingface_hub import HfApi +from huggingface_hub.utils import RepositoryNotFoundError +from modelscope import HubApi +from modelscope.hub.errors import NotExistError +from modelscope.hub.file_download import model_file_download +from requests import HTTPError +from typing_extensions import Literal + from ...types import ( SPECIAL_TOOL_PROMPT, ChatCompletion, @@ -27,6 +35,7 @@ Completion, CompletionChunk, ) +from ...utils import AsyncRunner from .llm_family import ( GgmlLLMSpecV1, LLMFamilyV1, @@ -676,3 +685,82 @@ def get_model_version( llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str ) -> str: return f"{llm_family.model_name}--{llm_spec.model_size_in_billions}B--{llm_spec.model_format}--{quantization}" + + +MODEL_HUB = Literal["huggingface", "modelscope"] + + +class ModelHubUtil(object): + def __init__(self): + self.__hf_api: Optional[HfApi] = None + self.__ms_api: Optional[HubApi] = None + self.__async_runner = AsyncRunner() + + @property + def _hf_api(self) -> HfApi: + if self.__hf_api is None: + self.__hf_api = HfApi() + return self.__hf_api + + @property + def _ms_api(self) -> HubApi: + if self.__ms_api is None: + self.__ms_api = HubApi() + return self.__ms_api + + def repo_exists(self, model_id: str, hub: MODEL_HUB) -> bool: + if hub == "huggingface": + return self._hf_api.repo_exists(model_id) + elif hub == "modelscope": + try: + self._ms_api.get_model(model_id) + return True + except (NotExistError, HTTPError): + return False + else: + raise ValueError("Unsupported model hub") + + async def a_repo_exists(self, model_id: str, hub: MODEL_HUB) -> bool: + return await self.__async_runner.async_run(self.repo_exists, model_id, hub) + + def get_config_path(self, model_id: str, hub: MODEL_HUB) -> Optional[str]: + if hub == "huggingface": + try: + return self._hf_api.hf_hub_download(model_id, "config.json") + except (ValueError, HTTPError) as e: + logging.error(e) + return None + elif hub == "modelscope": + try: + return model_file_download(model_id, "config.json") + except (NotExistError, HTTPError) as e: + logging.error(e) + return None + + 
async def a_get_config_path(self, model_id: str, hub: MODEL_HUB) -> Optional[str]: + return await self.__async_runner.async_run(self.get_config_path, model_id, hub) + + def list_repo_files(self, model_id: str, hub: MODEL_HUB) -> List[str]: + """ + List all files in the model repo. + + Notice: ModelScope does not return the hidden files which start with dot, + however, HuggingFace does. + """ + if hub == "huggingface": + try: + return self._hf_api.list_repo_files(model_id) + except RepositoryNotFoundError: + raise ValueError(f"Repository {model_id} not found.") + elif hub == "modelscope": + try: + return [ + entry["Path"] for entry in self._ms_api.get_model_files(model_id) + ] + except HTTPError: + raise ValueError(f"Repository {model_id} not found.") + else: + raise ValueError("Unsupported model hub") + + async def a_list_repo_files(self, model_id: str, hub: MODEL_HUB) -> List[str]: + return await self.__async_runner.async_run(self.list_repo_files, model_id, hub) diff --git a/xinference/tests/__init__.py b/xinference/tests/__init__.py new file mode 100644 index 0000000000..37f6558d95 --- /dev/null +++ b/xinference/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/xinference/tests/test_utils.py b/xinference/tests/test_utils.py new file mode 100644 index 0000000000..54b6057f96 --- /dev/null +++ b/xinference/tests/test_utils.py @@ -0,0 +1,43 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from concurrent.futures import Future + +import pytest +from typing_extensions import Coroutine + +from ..utils import AsyncRunner + + +@pytest.fixture +def async_runner(): + return AsyncRunner() + + +def test__thread_pool(async_runner): + assert async_runner._thread_pool is not None + + +def test_run_as_future(async_runner): + future = async_runner.run_as_future(lambda: 1) + assert isinstance(future, Future) + assert future.result() == 1 + + +def test_async_run(async_runner): + assert isinstance(async_runner.async_run(lambda: 1), Coroutine) + + +@pytest.mark.asyncio +async def test_async_run_a(async_runner): + assert await async_runner.async_run(lambda: 1) == 1 diff --git a/xinference/utils.py b/xinference/utils.py index 5b3741c222..bc36f01a5d 100644 --- a/xinference/utils.py +++ b/xinference/utils.py @@ -11,10 +11,36 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + +import asyncio +from concurrent.futures import Future +from concurrent.futures.thread import ThreadPoolExecutor import torch +from typing_extensions import Callable, Optional, TypeVar def cuda_count(): # even if install torch cpu, this interface would return 0. return torch.cuda.device_count() + + +R = TypeVar("R") # Return type + + +class AsyncRunner(object): + def __init__(self): + self.__thread_pool: Optional[ThreadPoolExecutor] = None + + @property + def _thread_pool(self): + if self.__thread_pool is None: + self.__thread_pool = ThreadPoolExecutor(max_workers=1) + return self.__thread_pool + + def run_as_future(self, fn: Callable[..., R], *args, **kwargs) -> Future[R]: + return self._thread_pool.submit(fn, *args, **kwargs) + + async def async_run(self, fn: Callable[..., R], *args, **kwargs) -> R: + return await asyncio.wrap_future(self.run_as_future(fn, *args, **kwargs)) diff --git a/xinference/web/ui/package-lock.json b/xinference/web/ui/package-lock.json index 4ae5037245..ab3397754e 100644 --- a/xinference/web/ui/package-lock.json +++ b/xinference/web/ui/package-lock.json @@ -6924,9 +6924,9 @@ } }, "node_modules/caniuse-lite": { - "version": "1.0.30001515", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001515.tgz", - "integrity": "sha512-eEFDwUOZbE24sb+Ecsx3+OvNETqjWIdabMy52oOkIgcUtAsQifjUG9q4U9dgTHJM2mfk4uEPxc0+xuFdJ629QA==", + "version": "1.0.30001599", + "resolved": "https://mirrors.cloud.tencent.com/npm/caniuse-lite/-/caniuse-lite-1.0.30001599.tgz", + "integrity": "sha512-LRAQHZ4yT1+f9LemSMeqdMpMxZcc4RMWdj4tiFe3G8tNkWK+E58g+/tzotb5cU6TbcVJLr4fySiAW7XmxQvZQA==", "funding": [ { "type": "opencollective", diff --git a/xinference/web/ui/src/scenes/register_model/index.js b/xinference/web/ui/src/scenes/register_model/index.js index de3c04dc43..d211c223c9 100644 --- a/xinference/web/ui/src/scenes/register_model/index.js +++ b/xinference/web/ui/src/scenes/register_model/index.js @@ -1,322 +1,34 @@ import { TabContext, TabList, TabPanel } from '@mui/lab' import { Box, - Checkbox, - FormControl, - FormControlLabel, - FormHelperText, - Radio, - RadioGroup, Tab, } from '@mui/material' -import Alert from '@mui/material/Alert' -import AlertTitle from '@mui/material/AlertTitle' -import Button from '@mui/material/Button' -import TextField from '@mui/material/TextField' -import React, { useContext, useEffect, useState } from 'react' +import React, { useEffect } from 'react' import { useCookies } from 'react-cookie' import { useNavigate } from 'react-router-dom' -import { ApiContext } from '../../components/apiContext' import ErrorMessageSnackBar from '../../components/errorMessageSnackBar' -import fetcher from '../../components/fetcher' import Title from '../../components/Title' -import { useMode } from '../../theme' import RegisterEmbeddingModel from './register_embedding' +import RegisterLanguageModel from './register_language' import RegisterRerankModel from './register_rerank' -const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } -const SUPPORTED_FEATURES = ['Generate', 'Chat'] - -// Convert dictionary of supported languages into list -const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) const RegisterModel = () => { - const ERROR_COLOR = useMode() - const endPoint = useContext(ApiContext).endPoint - const { setErrorMsg } = useContext(ApiContext) - const [successMsg, setSuccessMsg] = useState('') - const [modelFormat, setModelFormat] = useState('pytorch') - const [modelSize, setModelSize] = 
useState(7) - const [modelUri, setModelUri] = useState('/path/to/llama-2') - const [quantization, setQuantization] = useState('') - const [formData, setFormData] = useState({ - version: 1, - context_length: 2048, - model_name: 'custom-llama-2', - model_lang: ['en'], - model_ability: ['generate'], - model_description: 'This is a custom model description.', - model_family: '', - model_specs: [], - prompt_style: undefined, - }) - const [promptStyles, setPromptStyles] = useState([]) - const [family, setFamily] = useState({ - chat: [], - generate: [], - }) - const [familyLabel, setFamilyLabel] = useState('') const [tabValue, setTabValue] = React.useState('1') const [cookie] = useCookies(['token']) const navigate = useNavigate() - const errorModelName = formData.model_name.trim().length <= 0 - const errorModelDescription = formData.model_description.length < 0 - const errorContextLength = formData.context_length === 0 - const errorLanguage = - formData.model_lang === undefined || formData.model_lang.length === 0 - const errorAbility = - formData.model_ability === undefined || formData.model_ability.length === 0 - const errorModelSize = - formData.model_specs && - formData.model_specs.some((spec) => { - return ( - spec.model_size_in_billions === undefined || - spec.model_size_in_billions === 0 - ) - }) - const errorFamily = familyLabel === '' - const errorAny = - errorModelName || - errorModelDescription || - errorContextLength || - errorLanguage || - errorAbility || - errorModelSize || - errorFamily - useEffect(() => { if (cookie.token === '' || cookie.token === undefined) { return } if (cookie.token === 'need_auth') { navigate('/login', { replace: true }) - return } - const getBuiltinFamilies = async () => { - const response = await fetch(endPoint + '/v1/models/families', { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - }, - }) - if (!response.ok) { - const errorData = await response.json() // Assuming the server returns error details in JSON format - setErrorMsg( - `Server error: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - } else { - const data = await response.json() - data.chat.push('other') - data.generate.push('other') - setFamily(data) - } - } - - const getBuiltInPromptStyles = async () => { - const response = await fetch(endPoint + '/v1/models/prompts', { - method: 'GET', - headers: { - 'Content-Type': 'application/json', - }, - }) - if (!response.ok) { - const errorData = await response.json() // Assuming the server returns error details in JSON format - setErrorMsg( - `Server error: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - } else { - const data = await response.json() - let res = [] - for (const key in data) { - let v = data[key] - v['name'] = key - res.push(v) - } - setPromptStyles(res) - } - } - // avoid keep requesting backend to get prompts - if (promptStyles.length === 0) { - getBuiltInPromptStyles().catch((error) => { - setErrorMsg( - error.message || - 'An unexpected error occurred when getting builtin prompt styles.' - ) - console.error('Error: ', error) - }) - } - if (family.chat.length === 0) { - getBuiltinFamilies().catch((error) => { - setErrorMsg( - error.message || - 'An unexpected error occurred when getting builtin prompt styles.' 
- ) - console.error('Error: ', error) - }) - } }, [cookie.token]) - const getFamilyByAbility = () => { - if (formData.model_ability.includes('chat')) { - return family.chat - } else { - return family.generate - } - } - - const isModelFormatPytorch = () => { - return modelFormat === 'pytorch' - } - - const isModelFormatGPTQ = () => { - return modelFormat === 'gptq' - } - - const isModelFormatAWQ = () => { - return modelFormat === 'awq' - } - - const getPathComponents = (path) => { - const normalizedPath = path.replace(/\\/g, '/') - const baseDir = normalizedPath.substring(0, normalizedPath.lastIndexOf('/')) - const filename = normalizedPath.substring( - normalizedPath.lastIndexOf('/') + 1 - ) - return { baseDir, filename } - } - - const handleClick = async () => { - if (isModelFormatGPTQ()) { - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: [quantization], - model_id: '', - model_uri: modelUri, - }, - ] - } else if (isModelFormatAWQ()) { - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: [quantization], - model_id: '', - model_uri: modelUri, - }, - ] - } else if (!isModelFormatPytorch()) { - const { baseDir, filename } = getPathComponents(modelUri) - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: [quantization], - model_id: '', - model_file_name_template: filename, - model_uri: baseDir, - }, - ] - } else { - formData.model_specs = [ - { - model_format: modelFormat, - model_size_in_billions: modelSize, - quantizations: ['4-bit', '8-bit', 'none'], - model_id: '', - model_uri: modelUri, - }, - ] - } - - formData.model_family = familyLabel - - if (formData.model_ability.includes('chat')) { - const ps = promptStyles.find((item) => item.name === familyLabel) - if (ps) { - formData.prompt_style = { - style_name: ps.style_name, - system_prompt: ps.system_prompt, - roles: ps.roles, - intra_message_sep: ps.intra_message_sep, - inter_message_sep: ps.inter_message_sep, - stop: ps.stop ?? null, - stop_token_ids: ps.stop_token_ids ?? null, - } - } - } - - if (errorAny) { - setErrorMsg('Please fill in valid value for all fields') - return - } - - try { - const response = await fetcher(endPoint + '/v1/model_registrations/LLM', { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - model: JSON.stringify(formData), - persist: true, - }), - }) - if (!response.ok) { - const errorData = await response.json() // Assuming the server returns error details in JSON format - setErrorMsg( - `Server error: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - } else { - setSuccessMsg( - 'Model has been registered successfully! Navigate to launch model page to proceed.' 
- ) - } - } catch (error) { - console.error('There was a problem with the fetch operation:', error) - setErrorMsg(error.message || 'An unexpected error occurred.') - } - } - - const toggleLanguage = (lang) => { - if (formData.model_lang.includes(lang)) { - setFormData({ - ...formData, - model_lang: formData.model_lang.filter((l) => l !== lang), - }) - } else { - setFormData({ - ...formData, - model_lang: [...formData.model_lang, lang], - }) - } - } - - const toggleAbility = (ability) => { - setFamilyLabel('') - if (formData.model_ability.includes(ability)) { - setFormData({ - ...formData, - model_ability: formData.model_ability.filter((a) => a !== ability), - }) - } else { - setFormData({ - ...formData, - model_ability: [...formData.model_ability, ability], - }) - } - } - return ( @@ -336,284 +48,7 @@ const RegisterModel = () => { </TabList> </Box> <TabPanel value="1" sx={{ padding: 0 }}> - <Box padding="20px"></Box> - {/* Base Information */} - <FormControl sx={styles.baseFormControl}> - <TextField - label="Model Name" - error={errorModelName} - defaultValue={formData.model_name} - size="small" - helperText="Alphanumeric characters with properly placed hyphens and underscores. Must not match any built-in model names." - onChange={(event) => - setFormData({ ...formData, model_name: event.target.value }) - } - /> - <Box padding="15px"></Box> - - <label - style={{ - paddingLeft: 5, - }} - > - Model Format - </label> - - <RadioGroup - value={modelFormat} - onChange={(e) => { - setModelFormat(e.target.value) - }} - > - <Box sx={styles.checkboxWrapper}> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="pytorch" - control={<Radio />} - label="PyTorch" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="ggmlv3" - control={<Radio />} - label="GGML" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="ggufv2" - control={<Radio />} - label="GGUF" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="gptq" - control={<Radio />} - label="GPTQ" - /> - </Box> - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel - value="awq" - control={<Radio />} - label="AWQ" - /> - </Box> - </Box> - </RadioGroup> - <Box padding="15px"></Box> - - <TextField - error={errorContextLength} - label="Context Length" - value={formData.context_length} - size="small" - onChange={(event) => { - let value = event.target.value - // Remove leading zeros - if (/^0+/.test(value)) { - value = value.replace(/^0+/, '') || '0' - } - // Ensure it's a positive integer, if not set it to the minimum - if (!/^\d+$/.test(value) || parseInt(value) < 0) { - value = '0' - } - // Update with the processed value - setFormData({ - ...formData, - context_length: Number(value), - }) - }} - /> - <Box padding="15px"></Box> - - <TextField - label="Model Size in Billions" - size="small" - error={errorModelSize} - value={modelSize} - onChange={(e) => { - let value = e.target.value - // Remove leading zeros - if (/^0+/.test(value)) { - value = value.replace(/^0+/, '') || '0' - } - // Ensure it's a positive integer, if not set it to the minimum - if (!/^\d+$/.test(value) || parseInt(value) < 0) { - value = '0' - } - setModelSize(Number(value)) - }} - /> - <Box padding="15px"></Box> - - <TextField - label="Model Path" - size="small" - value={modelUri} - onChange={(e) => { - setModelUri(e.target.value) - }} - helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path." 
- /> - <Box padding="15px"></Box> - - <TextField - label="Quantization (Optional)" - size="small" - value={quantization} - onChange={(e) => { - setQuantization(e.target.value) - }} - helperText="For GPTQ/AWQ models, please be careful to fill in the quantization corresponding to the model you want to register." - /> - <Box padding="15px"></Box> - - <TextField - label="Model Description (Optional)" - error={errorModelDescription} - defaultValue={formData.model_description} - size="small" - onChange={(event) => - setFormData({ - ...formData, - model_description: event.target.value, - }) - } - /> - <Box padding="15px"></Box> - - <label - style={{ - paddingLeft: 5, - color: errorLanguage ? ERROR_COLOR : 'inherit', - }} - > - Model Languages - </label> - <Box sx={styles.checkboxWrapper}> - {SUPPORTED_LANGUAGES.map((lang) => ( - <Box key={lang} sx={{ marginRight: '10px' }}> - <FormControlLabel - control={ - <Checkbox - checked={formData.model_lang.includes(lang)} - onChange={() => toggleLanguage(lang)} - name={lang} - sx={ - errorLanguage - ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } - : {} - } - /> - } - label={SUPPORTED_LANGUAGES_DICT[lang]} - style={{ - paddingLeft: 10, - color: errorLanguage ? ERROR_COLOR : 'inherit', - }} - /> - </Box> - ))} - </Box> - <Box padding="15px"></Box> - - <label - style={{ - paddingLeft: 5, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - > - Model Abilities - </label> - <Box sx={styles.checkboxWrapper}> - {SUPPORTED_FEATURES.map((ability) => ( - <Box key={ability} sx={{ marginRight: '10px' }}> - <FormControlLabel - control={ - <Checkbox - checked={formData.model_ability.includes( - ability.toLowerCase() - )} - onChange={() => toggleAbility(ability.toLowerCase())} - name={ability} - sx={ - errorAbility - ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } - : {} - } - /> - } - label={ability} - style={{ - paddingLeft: 10, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - /> - </Box> - ))} - </Box> - <Box padding="15px"></Box> - </FormControl> - - <FormControl sx={styles.baseFormControl}> - <label - style={{ - paddingLeft: 5, - color: errorAbility ? ERROR_COLOR : 'inherit', - }} - > - Model Family - </label> - <FormHelperText> - Please be careful to select the family name corresponding to the - model you want to register. If not found, please choose `other`. 
- </FormHelperText> - <RadioGroup - value={familyLabel} - onChange={(e) => { - setFamilyLabel(e.target.value) - }} - > - <Box sx={styles.checkboxWrapper}> - {getFamilyByAbility().map((v) => ( - <Box sx={{ marginLeft: '10px' }}> - <FormControlLabel value={v} control={<Radio />} label={v} /> - </Box> - ))} - </Box> - </RadioGroup> - <Box padding="15px"></Box> - </FormControl> - - <Box width={'100%'}> - {successMsg !== '' && ( - <Alert severity="success"> - <AlertTitle>Success</AlertTitle> - {successMsg} - </Alert> - )} - <Button - variant="contained" - color="primary" - type="submit" - onClick={handleClick} - > - Register Model - </Button> - </Box> + <RegisterLanguageModel /> </TabPanel> <TabPanel value="2" sx={{ padding: 0 }}> <RegisterEmbeddingModel /> @@ -627,32 +62,3 @@ const RegisterModel = () => { } export default RegisterModel - -const styles = { - baseFormControl: { - width: '100%', - margin: 'normal', - size: 'small', - }, - checkboxWrapper: { - display: 'flex', - flexWrap: 'wrap', - maxWidth: '80%', - }, - labelPaddingLeft: { - paddingLeft: 5, - }, - formControlLabelPaddingLeft: { - paddingLeft: 10, - }, - buttonBox: { - width: '100%', - margin: '20px', - }, - error: { - fontWeight: 'bold', - margin: '5px 0', - padding: '1px', - borderRadius: '5px', - }, -} diff --git a/xinference/web/ui/src/scenes/register_model/register_embedding.js b/xinference/web/ui/src/scenes/register_model/register_embedding.js index ac7ab8d4ae..87531813ed 100644 --- a/xinference/web/ui/src/scenes/register_model/register_embedding.js +++ b/xinference/web/ui/src/scenes/register_model/register_embedding.js @@ -1,4 +1,14 @@ -import { Box, Checkbox, FormControl, FormControlLabel } from '@mui/material' +import { + Box, + Checkbox, + FormControl, + FormControlLabel, + InputLabel, + MenuItem, + Radio, + RadioGroup, + Select, +} from '@mui/material' import Alert from '@mui/material/Alert' import AlertTitle from '@mui/material/AlertTitle' import Button from '@mui/material/Button' @@ -13,20 +23,33 @@ const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } // Convert dictionary of supported languages into list const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) + +const SUPPORTED_HUBS_DICT = { huggingface: 'HuggingFace', modelscope: 'ModelScope' } +const SUPPORTED_HUBS = Object.keys(SUPPORTED_HUBS_DICT) + +const SOURCES_DICT = { self_hosted: 'Self Hosted', hub: 'Hub' } +const SOURCES = Object.keys(SOURCES_DICT) + const RegisterEmbeddingModel = () => { const ERROR_COLOR = useMode() const endPoint = useContext(ApiContext).endPoint const { setErrorMsg } = useContext(ApiContext) const [successMsg, setSuccessMsg] = useState('') + const [modelSource, setModelSource] = useState(SOURCES[0]) + const [hub, setHub] = useState(SUPPORTED_HUBS[0]) + const [modelId, setModelId] = useState('') const [formData, setFormData] = useState({ model_name: 'custom-embedding', dimensions: 768, max_tokens: 512, language: ['en'], model_uri: '/path/to/embedding-model', + model_id: null, + model_hub: null, }) const errorModelName = formData.model_name.trim().length <= 0 + const errorModelId = modelSource === 'hub' && modelId.search('\\w+/\\w+') === -1 const errorDimensions = formData.dimensions < 0 const errorMaxTokens = formData.max_tokens < 0 const errorLanguage = @@ -34,13 +57,27 @@ const RegisterEmbeddingModel = () => { const handleClick = async () => { const errorAny = - errorModelName || errorDimensions || errorMaxTokens || errorLanguage + errorModelName || errorDimensions || errorMaxTokens || errorLanguage || 
errorModelId if (errorAny) { setErrorMsg('Please fill in valid value for all fields') return } + let myFormData + if (modelSource === 'self_hosted') { + myFormData = { + ...formData, + model_hub: null, + model_id: null, + } + } else { + myFormData = { + ...formData, + model_uri: null, + } + } + console.log(myFormData) try { const response = await fetcher( endPoint + '/v1/model_registrations/embedding', @@ -50,21 +87,21 @@ const RegisterEmbeddingModel = () => { 'Content-Type': 'application/json', }, body: JSON.stringify({ - model: JSON.stringify(formData), + model: JSON.stringify(myFormData), persist: true, }), - } + }, ) if (!response.ok) { const errorData = await response.json() // Assuming the server returns error details in JSON format setErrorMsg( `Server error: ${response.status} - ${ errorData.detail || 'Unknown error' - }` + }`, ) } else { setSuccessMsg( - 'Model has been registered successfully! Navigate to launch model page to proceed.' + 'Model has been registered successfully! Navigate to launch model page to proceed.', ) } } catch (error) { @@ -87,6 +124,40 @@ const RegisterEmbeddingModel = () => { } } + const handleImportModel = async () => { + if (errorModelId) { + setErrorMsg('Please fill in valid value for Model Id') + return + } + const response = await fetcher(endPoint + + `/v1/model_registrations/embedding/${hub}/_/${modelId}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }) + + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${ + errorData.detail || 'Unknown error' + }`, + ) + } else { + const data = await response.json() + setFormData({ + ...formData, + dimensions: data.dimensions, + max_tokens: data.max_tokens, + language: data.language, + model_hub: hub, + model_id: modelId, + }) + + } + } + return ( <React.Fragment> <Box padding="20px"></Box> @@ -104,6 +175,91 @@ const RegisterEmbeddingModel = () => { /> <Box padding="15px"></Box> + <label + style={{ + paddingLeft: 5, + }} + > + Model Source + </label> + + <RadioGroup + value={modelSource} + onChange={(e) => { + setModelSource(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + {SOURCES.map((item) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value={item} + control={<Radio />} + label={SOURCES_DICT[item]} + /> + </Box> + ))} + </Box> + </RadioGroup> + <Box padding="15px"></Box> + + {modelSource === 'self_hosted' && + <TextField + label="Model Path" + size="small" + value={formData.model_uri} + onChange={(e) => { + setFormData({ + ...formData, + model_uri: e.target.value, + }) + }} + helperText="Provide the model directory path." 
+ />} + + {modelSource === 'hub' && + <Box sx={styles.checkboxWrapper}> + + <TextField + sx={{ width: '400px' }} + label="Model Id" + size="small" + error={errorModelId} + value={modelId} + onChange={(e) => { + setModelId(e.target.value) + }} + placeholder="user/repo" + /> + + <FormControl variant="standard" + sx={{ marginLeft: '10px' }}> + <InputLabel id="hub-label">Hub</InputLabel> + <Select + labelId="hub-label" + value={hub} + label="Hub" + onChange={(e) => { + setHub(e.target.value) + }} + > + {SUPPORTED_HUBS.map((item) => ( + <MenuItem value={item}>{SUPPORTED_HUBS_DICT[item]}</MenuItem> + ))} + </Select> + </FormControl> + <Button + sx={{ marginLeft: '10px' }} + variant="contained" + color="primary" + onClick={handleImportModel} + > + Import Model + </Button> + </Box> + } + <Box padding="15px"></Box> + <TextField error={errorDimensions} label="Dimensions" @@ -132,20 +288,6 @@ const RegisterEmbeddingModel = () => { /> <Box padding="15px"></Box> - <TextField - label="Model Path" - size="small" - value={formData.model_uri} - onChange={(e) => { - setFormData({ - ...formData, - model_uri: e.target.value, - }) - }} - helperText="Provide the model directory path." - /> - <Box padding="15px"></Box> - <label style={{ paddingLeft: 5, @@ -166,11 +308,11 @@ const RegisterEmbeddingModel = () => { sx={ errorLanguage ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } : {} } /> diff --git a/xinference/web/ui/src/scenes/register_model/register_language.js b/xinference/web/ui/src/scenes/register_model/register_language.js new file mode 100644 index 0000000000..2ac956af79 --- /dev/null +++ b/xinference/web/ui/src/scenes/register_model/register_language.js @@ -0,0 +1,861 @@ +import { + Box, + Checkbox, + FormControl, + FormControlLabel, + FormHelperText, + InputLabel, + MenuItem, + Radio, + RadioGroup, + Select, +} from '@mui/material' +import Alert from '@mui/material/Alert' +import AlertTitle from '@mui/material/AlertTitle' +import Button from '@mui/material/Button' +import TextField from '@mui/material/TextField' +import React, { useContext, useEffect, useState } from 'react' +import { useCookies } from 'react-cookie' + +import { ApiContext } from '../../components/apiContext' +import fetcher from '../../components/fetcher' +import { useMode } from '../../theme' + +const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } +const SUPPORTED_FEATURES = ['Generate', 'Chat'] + +const SUPPORTED_HUBS_DICT = { huggingface: 'HuggingFace', modelscope: 'ModelScope' } +const SUPPORTED_HUBS = Object.keys(SUPPORTED_HUBS_DICT) + +const SOURCES_DICT = { self_hosted: 'Self Hosted', hub: 'Hub' } +const SOURCES = Object.keys(SOURCES_DICT) + +// Convert dictionary of supported languages into list +const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) + +const RegisterLanguageModel = () => { + const ERROR_COLOR = useMode() + const endPoint = useContext(ApiContext).endPoint + const { setErrorMsg } = useContext(ApiContext) + const [successMsg, setSuccessMsg] = useState('') + const [modelFormat, setModelFormat] = useState('pytorch') + const [modelFileNameTemplate, setModelFileNameTemplate] = useState('') + const [modelFileNameSplitTemplate, setModelFileNameSplitTemplate] = useState('') + const [modelSize, setModelSize] = useState(7) + const [modelUri, setModelUri] = useState('/path/to/llama-2') + const [modelId, setModelId] = useState('') + const [quantization, setQuantization] = useState('') + const 
[quantizationParts, setQuantizationParts] = useState('') + const [modelSource, setModelSource] = useState(SOURCES[0]) + const [hub, setHub] = useState(SUPPORTED_HUBS[0]) + const [formData, setFormData] = useState({ + version: 1, + context_length: 2048, + model_name: 'custom-llama-2', + model_lang: ['en'], + model_ability: ['generate'], + model_description: 'This is a custom model description.', + model_family: '', + model_specs: [], + prompt_style: undefined, + }) + const [promptStyles, setPromptStyles] = useState([]) + const [family, setFamily] = useState({ + chat: [], + generate: [], + }) + const [familyLabel, setFamilyLabel] = useState('') + + const [cookie] = useCookies(['token']) + const errorModelName = formData.model_name.trim().length <= 0 + const errorModelDescription = formData.model_description.length < 0 + const errorContextLength = formData.context_length === 0 + const errorLanguage = + formData.model_lang === undefined || formData.model_lang.length === 0 + const errorAbility = + formData.model_ability === undefined || formData.model_ability.length === 0 + const errorModelSize = + formData.model_specs && + formData.model_specs.some((spec) => { + return ( + spec.model_size_in_billions === undefined || + spec.model_size_in_billions === 0 + ) + }) + const errorFamily = familyLabel === '' + const errorModelId = modelSource === 'hub' && modelId.search('\\w+/\\w+') === -1 + const errorModelFileNameTemplate = modelSource === 'hub' && ['ggufv2', 'ggmlv3'].includes(modelFormat) && + modelFileNameTemplate.trim().length <= 0 + const errorQuantizationParts = modelSource === 'hub' && ['ggufv2', 'ggmlv3'].includes(modelFormat) && + modelFileNameSplitTemplate.trim().length > 0 && quantizationParts.trim().length <= 0 + const errorAny = + errorModelName || + errorModelDescription || + errorContextLength || + errorLanguage || + errorAbility || + errorModelSize || + errorFamily || + errorModelId || + errorModelFileNameTemplate || + errorQuantizationParts + + useEffect(() => { + if (cookie.token === '' || cookie.token === undefined) { + return + } + + const getBuiltinFamilies = async () => { + const response = await fetch(endPoint + '/v1/models/families', { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }) + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${ + errorData.detail || 'Unknown error' + }`, + ) + } else { + const data = await response.json() + data.chat.push('other') + data.generate.push('other') + setFamily(data) + } + } + + const getBuiltInPromptStyles = async () => { + const response = await fetch(endPoint + '/v1/models/prompts', { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }) + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${ + errorData.detail || 'Unknown error' + }`, + ) + } else { + const data = await response.json() + let res = [] + for (const key in data) { + let v = data[key] + v['name'] = key + res.push(v) + } + setPromptStyles(res) + } + } + // avoid keep requesting backend to get prompts + if (promptStyles.length === 0) { + getBuiltInPromptStyles().catch((error) => { + setErrorMsg( + error.message || + 'An unexpected error occurred when getting builtin prompt styles.', + ) + console.error('Error: ', error) + }) + } + if (family.chat.length === 0) { + 
getBuiltinFamilies().catch((error) => { + setErrorMsg( + error.message || + 'An unexpected error occurred when getting builtin prompt styles.', + ) + console.error('Error: ', error) + }) + } + }, [cookie.token]) + + const getFamilyByAbility = () => { + if (formData.model_ability.includes('chat')) { + return family.chat + } else { + return family.generate + } + } + + const isModelFormatPytorch = () => { + return modelFormat === 'pytorch' + } + + const isModelFormatGPTQ = () => { + return modelFormat === 'gptq' + } + + const isModelFormatAWQ = () => { + return modelFormat === 'awq' + } + + const getPathComponents = (path) => { + const normalizedPath = path.replace(/\\/g, '/') + const baseDir = normalizedPath.substring(0, normalizedPath.lastIndexOf('/')) + const filename = normalizedPath.substring( + normalizedPath.lastIndexOf('/') + 1, + ) + return { baseDir, filename } + } + + const handleClick = async () => { + if (modelSource === 'self_hosted') { + if (isModelFormatGPTQ()) { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: [quantization], + model_id: '', + model_uri: modelUri, + }, + ] + } else if (isModelFormatAWQ()) { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: [quantization], + model_id: '', + model_uri: modelUri, + }, + ] + } else if (!isModelFormatPytorch()) { + const { baseDir, filename } = getPathComponents(modelUri) + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: [quantization], + model_id: '', + model_file_name_template: filename, + model_uri: baseDir, + }, + ] + } else { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: ['4-bit', '8-bit', 'none'], + model_id: '', + model_uri: modelUri, + }, + ] + } + } else if (modelSource === 'hub') { + const quantization_array = quantization.split(',') + if (isModelFormatGPTQ() || isModelFormatAWQ()) { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: quantization_array, + model_hub: hub, + model_id: modelId, + model_uri: null, + }, + ] + } else if (!isModelFormatPytorch()) { + const qParts = quantizationParts.length > 0 ? JSON.parse(quantizationParts) : null + let splitTemplate = modelFileNameSplitTemplate.trim() + splitTemplate = splitTemplate.length > 0 ? splitTemplate : null + + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + model_file_name_template: modelFileNameTemplate, + model_file_name_split_template: splitTemplate, + quantizations: quantization_array, + quantization_parts: qParts, + model_hub: hub, + model_id: modelId, + model_uri: null, + }, + ] + } else { + formData.model_specs = [ + { + model_format: modelFormat, + model_size_in_billions: modelSize, + quantizations: ['4-bit', '8-bit', 'none'], + model_hub: hub, + model_id: modelId, + model_uri: null, + }, + ] + } + } + + formData.model_family = familyLabel + + if (formData.model_ability.includes('chat')) { + const ps = promptStyles.find((item) => item.name === familyLabel) + if (ps) { + formData.prompt_style = { + style_name: ps.style_name, + system_prompt: ps.system_prompt, + roles: ps.roles, + intra_message_sep: ps.intra_message_sep, + inter_message_sep: ps.inter_message_sep, + stop: ps.stop ?? null, + stop_token_ids: ps.stop_token_ids ?? 
null, + } + } + } + + if (errorAny) { + setErrorMsg('Please fill in valid value for all fields') + return + } + + try { + const response = await fetcher(endPoint + '/v1/model_registrations/LLM', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: JSON.stringify(formData), + persist: true, + }), + }) + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${ + errorData.detail || 'Unknown error' + }`, + ) + } else { + setSuccessMsg( + 'Model has been registered successfully! Navigate to launch model page to proceed.', + ) + } + } catch (error) { + console.error('There was a problem with the fetch operation:', error) + setErrorMsg(error.message || 'An unexpected error occurred.') + } + } + + const handleImportModel = async () => { + if (errorModelId) { + setErrorMsg('Please fill in valid value for Model Id') + return + } + const response = await fetcher(endPoint + + `/v1/model_registrations/LLM/${hub}/${modelFormat}/${modelId}`, { + method: 'GET', + headers: { + 'Content-Type': 'application/json', + }, + }) + if (!response.ok) { + const errorData = await response.json() // Assuming the server returns error details in JSON format + setErrorMsg( + `Server error: ${response.status} - ${errorData.detail || 'Unknown error'}`, + ) + } else { + const body = await response.json() + console.log('response', body) + if ('context_length' in body && body['context_length'] > 0) { + setFormData({ + ...formData, + context_length: Number(body['context_length']), + }) + } + + /** + * @type {object[]} + */ + const modelSpecs = body['model_specs'] + if (modelSpecs.length === 0) { + return + } + const modelSpec = modelSpecs[0] + + const modelSize = modelSpec['model_size_in_billions'] + setModelSize(modelSize) + + if (['ggufv2', 'ggmlv3'].includes(modelFormat)) { + + const modelFileNameTemplate = modelSpec['model_file_name_template'] + setModelFileNameTemplate(modelFileNameTemplate) + + const quantizations = modelSpec['quantizations'] + setQuantization(quantizations.join(',')) + + /** + * @type {string | null} + */ + const modelFileNameSplitTemplate = modelSpec['model_file_name_split_template'] + if (modelFileNameSplitTemplate !== null && modelFileNameSplitTemplate.trim() !== '') { + setModelFileNameSplitTemplate(modelFileNameSplitTemplate) + const parts = JSON.stringify(modelSpec['quantization_parts']) + setQuantizationParts(parts) + } + } + } + } + + const toggleLanguage = (lang) => { + if (formData.model_lang.includes(lang)) { + setFormData({ + ...formData, + model_lang: formData.model_lang.filter((l) => l !== lang), + }) + } else { + setFormData({ + ...formData, + model_lang: [...formData.model_lang, lang], + }) + } + } + + const toggleAbility = (ability) => { + setFamilyLabel('') + if (formData.model_ability.includes(ability)) { + setFormData({ + ...formData, + model_ability: formData.model_ability.filter((a) => a !== ability), + }) + } else { + setFormData({ + ...formData, + model_ability: [...formData.model_ability, ability], + }) + } + } + + return ( + <React.Fragment> + <Box padding="20px"></Box> + {/* Base Information */} + <FormControl sx={styles.baseFormControl}> + <TextField + label="Model Name" + error={errorModelName} + defaultValue={formData.model_name} + size="small" + helperText="Alphanumeric characters with properly placed hyphens and underscores. Must not match any built-in model names." 
+ onChange={(event) => + setFormData({ ...formData, model_name: event.target.value }) + } + /> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + }} + > + Model Format + </label> + + <RadioGroup + value={modelFormat} + onChange={(e) => { + setModelFormat(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="pytorch" + control={<Radio />} + label="PyTorch" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="ggmlv3" + control={<Radio />} + label="GGML" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="ggufv2" + control={<Radio />} + label="GGUF" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="gptq" + control={<Radio />} + label="GPTQ" + /> + </Box> + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value="awq" + control={<Radio />} + label="AWQ" + /> + </Box> + </Box> + </RadioGroup> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + }} + > + Model Source + </label> + + <RadioGroup + value={modelSource} + onChange={(e) => { + setModelSource(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + {SOURCES.map((item) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value={item} + control={<Radio />} + label={SOURCES_DICT[item]} + /> + </Box> + ))} + </Box> + </RadioGroup> + <Box padding="15px"></Box> + + {modelSource === 'self_hosted' && + <TextField + label="Model Path" + size="small" + value={modelUri} + onChange={(e) => { + setModelUri(e.target.value) + }} + helperText="For PyTorch, provide the model directory. For GGML/GGUF, provide the model file path." + />} + {modelSource === 'hub' && + <Box sx={styles.checkboxWrapper}> + + + <TextField + sx={{ width: '400px' }} + label="Model Id" + size="small" + error={errorModelId} + value={modelId} + onChange={(e) => { + setModelId(e.target.value) + }} + placeholder="user/repo" + /> + + <FormControl variant="standard" + sx={{ marginLeft: '10px' }}> + <InputLabel id="hub-label">Hub</InputLabel> + <Select + labelId="hub-label" + value={hub} + label="Hub" + onChange={(e) => { + setHub(e.target.value) + }} + > + {SUPPORTED_HUBS.map((item) => ( + <MenuItem value={item}>{SUPPORTED_HUBS_DICT[item]}</MenuItem> + ))} + </Select> + </FormControl> + <Button + sx={{ marginLeft: '10px' }} + variant="contained" + color="primary" + onClick={handleImportModel} + > + Import Model + </Button> + </Box> + } + <Box padding="15px"></Box> + + + <TextField + error={errorContextLength} + label="Context Length" + value={formData.context_length} + size="small" + onChange={(event) => { + let value = event.target.value + // Remove leading zeros + if (/^0+/.test(value)) { + value = value.replace(/^0+/, '') || '0' + } + // Ensure it's a positive integer, if not set it to the minimum + if (!/^\d+$/.test(value) || parseInt(value) < 0) { + value = '0' + } + // Update with the processed value + setFormData({ + ...formData, + context_length: Number(value), + }) + }} + /> + <Box padding="15px"></Box> + + <TextField + label="Model Size in Billions" + size="small" + error={errorModelSize} + value={modelSize} + onChange={(e) => { + let value = e.target.value + // Remove leading zeros + if (/^0+/.test(value)) { + value = value.replace(/^0+/, '') || '0' + } + // Ensure it's a positive integer, if not set it to the minimum + if (!/^\d+$/.test(value) || parseInt(value) < 0) { + value = '0' + } + setModelSize(Number(value)) + }} + /> + <Box padding="15px"></Box> + + 
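For readers following the hub-sourced path, here is a minimal sketch of the model spec that the handleClick branch above assembles from these fields for a GGUF model imported from a hub. The concrete values (model id, quantizations, file name template) are illustrative placeholders, not values taken from this patch.

const exampleHubGgufSpec = {
  model_format: 'ggufv2',
  model_size_in_billions: 7,
  model_file_name_template: 'model.{quantization}.gguf', // assumed example template name
  model_file_name_split_template: null, // only set when the model file is split into parts
  quantizations: ['Q4_K_M', 'Q8_0'], // parsed from the comma-separated Quantization field
  quantization_parts: null, // JSON-parsed only when a split template is provided
  model_hub: 'huggingface',
  model_id: 'user/repo',
  model_uri: null, // hub source, so no local path
}

// handleClick then submits the whole form exactly as in the POST above:
// body: JSON.stringify({ model: JSON.stringify(formData), persist: true })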
{modelSource === 'hub' && ['ggufv2', 'ggmlv3'].includes(modelFormat) && + <> + <TextField + label="Model File Name Template" + size="small" + value={modelFileNameTemplate} + onChange={(e) => { + setModelFileNameTemplate(e.target.value) + }} + error={errorModelFileNameTemplate} + /> + <Box padding="15px"></Box> + <TextField + label="Model File Name Split Template (Optional)" + size="small" + value={modelFileNameSplitTemplate} + onChange={(e) => { + setModelFileNameSplitTemplate(e.target.value) + }} + /> + <Box padding="15px"></Box> + </> + } + + <TextField + label="Quantization (Optional)" + size="small" + value={quantization} + onChange={(e) => { + setQuantization(e.target.value) + }} + helperText="For GPTQ/AWQ models, please be careful to fill in the quantization corresponding to the model you want to register." + /> + <Box padding="15px"></Box> + + {modelSource === 'hub' && ['ggufv2', 'ggmlv3'].includes(modelFormat) && + modelFileNameSplitTemplate.trim().length > 0 && + <> + <TextField + label="Quantization Parts (Optional)" + size="small" + value={quantizationParts} + error={errorQuantizationParts} + onChange={(e) => { + setQuantizationParts(e.target.value.trim()) + }} + helperText="If there is more than 1 quantization parts, separated by commas" + /> + <Box padding="15px"></Box> + </> + } + + <TextField + label="Model Description (Optional)" + error={errorModelDescription} + defaultValue={formData.model_description} + size="small" + onChange={(event) => + setFormData({ + ...formData, + model_description: event.target.value, + }) + } + /> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} + > + Model Languages + </label> + <Box sx={styles.checkboxWrapper}> + {SUPPORTED_LANGUAGES.map((lang) => ( + <Box key={lang} sx={{ marginRight: '10px' }}> + <FormControlLabel + control={ + <Checkbox + checked={formData.model_lang.includes(lang)} + onChange={() => toggleLanguage(lang)} + name={lang} + sx={ + errorLanguage + ? { + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } + : {} + } + /> + } + label={SUPPORTED_LANGUAGES_DICT[lang]} + style={{ + paddingLeft: 10, + color: errorLanguage ? ERROR_COLOR : 'inherit', + }} + /> + </Box> + ))} + </Box> + <Box padding="15px"></Box> + + <label + style={{ + paddingLeft: 5, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + > + Model Abilities + </label> + <Box sx={styles.checkboxWrapper}> + {SUPPORTED_FEATURES.map((ability) => ( + <Box key={ability} sx={{ marginRight: '10px' }}> + <FormControlLabel + control={ + <Checkbox + checked={formData.model_ability.includes( + ability.toLowerCase(), + )} + onChange={() => toggleAbility(ability.toLowerCase())} + name={ability} + sx={ + errorAbility + ? { + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } + : {} + } + /> + } + label={ability} + style={{ + paddingLeft: 10, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + /> + </Box> + ))} + </Box> + <Box padding="15px"></Box> + </FormControl> + + <FormControl sx={styles.baseFormControl}> + <label + style={{ + paddingLeft: 5, + color: errorAbility ? ERROR_COLOR : 'inherit', + }} + > + Model Family + </label> + <FormHelperText> + Please be careful to select the family name corresponding to the + model you want to register. If not found, please choose `other`. 
+ </FormHelperText> + <RadioGroup + value={familyLabel} + onChange={(e) => { + setFamilyLabel(e.target.value) + }} + > + <Box sx={styles.checkboxWrapper}> + {getFamilyByAbility().map((v) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel value={v} control={<Radio />} label={v} /> + </Box> + ))} + </Box> + </RadioGroup> + <Box padding="15px"></Box> + </FormControl> + + <Box width={'100%'}> + {successMsg !== '' && ( + <Alert severity="success"> + <AlertTitle>Success</AlertTitle> + {successMsg} + </Alert> + )} + <Button + variant="contained" + color="primary" + type="submit" + onClick={handleClick} + > + Register Model + </Button> + </Box> + </React.Fragment> + ) +} + +export default RegisterLanguageModel + +const styles = { + baseFormControl: { + width: '100%', + margin: 'normal', + size: 'small', + }, + checkboxWrapper: { + display: 'flex', + flexWrap: 'wrap', + maxWidth: '80%', + }, + labelPaddingLeft: { + paddingLeft: 5, + }, + formControlLabelPaddingLeft: { + paddingLeft: 10, + }, + buttonBox: { + width: '100%', + margin: '20px', + }, + error: { + fontWeight: 'bold', + margin: '5px 0', + padding: '1px', + borderRadius: '5px', + }, +} diff --git a/xinference/web/ui/src/scenes/register_model/register_rerank.js b/xinference/web/ui/src/scenes/register_model/register_rerank.js index 075b35ff9d..a2d9b2ad00 100644 --- a/xinference/web/ui/src/scenes/register_model/register_rerank.js +++ b/xinference/web/ui/src/scenes/register_model/register_rerank.js @@ -1,4 +1,14 @@ -import { Box, Checkbox, FormControl, FormControlLabel } from '@mui/material' +import { + Box, + Checkbox, + FormControl, + FormControlLabel, + InputLabel, + MenuItem, + Radio, + RadioGroup, + Select, +} from '@mui/material' import Alert from '@mui/material/Alert' import AlertTitle from '@mui/material/AlertTitle' import Button from '@mui/material/Button' @@ -12,24 +22,36 @@ import { useMode } from '../../theme' const SUPPORTED_LANGUAGES_DICT = { en: 'English', zh: 'Chinese' } // Convert dictionary of supported languages into list const SUPPORTED_LANGUAGES = Object.keys(SUPPORTED_LANGUAGES_DICT) +const SUPPORTED_HUBS_DICT = { huggingface: 'HuggingFace', modelscope: 'ModelScope' } +const SUPPORTED_HUBS = Object.keys(SUPPORTED_HUBS_DICT) + +const SOURCES_DICT = { self_hosted: 'Self Hosted', hub: 'Hub' } +const SOURCES = Object.keys(SOURCES_DICT) const RegisterRerankModel = () => { const ERROR_COLOR = useMode() const endPoint = useContext(ApiContext).endPoint const { setErrorMsg } = useContext(ApiContext) const [successMsg, setSuccessMsg] = useState('') + const [modelSource, setModelSource] = useState(SOURCES[0]) + const [hub, setHub] = useState(SUPPORTED_HUBS[0]) + const [modelId, setModelId] = useState('') + const [formData, setFormData] = useState({ model_name: 'custom-rerank', language: ['en'], model_uri: '/path/to/rerank-model', + model_id: null, + model_hub: null, }) const errorModelName = formData.model_name.trim().length <= 0 + const errorModelId = modelSource === 'hub' && modelId.search('\\w+/\\w+') === -1 const errorLanguage = formData.language === undefined || formData.language.length === 0 const handleClick = async () => { - const errorAny = errorModelName || errorLanguage + const errorAny = errorModelName || errorLanguage || errorModelId if (errorAny) { setErrorMsg('Please fill in valid value for all fields') @@ -37,6 +59,21 @@ const RegisterRerankModel = () => { } try { + let myFormData + if (modelSource === 'hub') { + myFormData = { + ...formData, + model_id: modelId, + model_hub: hub, + model_uri: null, + } + } 
else { + myFormData = { + ...formData, + model_id: null, + model_hub: null, + } + } const response = await fetcher( endPoint + '/v1/model_registrations/rerank', { @@ -45,21 +82,21 @@ const RegisterRerankModel = () => { 'Content-Type': 'application/json', }, body: JSON.stringify({ - model: JSON.stringify(formData), + model: JSON.stringify(myFormData), persist: true, }), - } + }, ) if (!response.ok) { const errorData = await response.json() // Assuming the server returns error details in JSON format setErrorMsg( `Server error: ${response.status} - ${ errorData.detail || 'Unknown error' - }` + }`, ) } else { setSuccessMsg( - 'Model has been registered successfully! Navigate to launch model page to proceed.' + 'Model has been registered successfully! Navigate to launch model page to proceed.', ) } } catch (error) { @@ -82,6 +119,33 @@ const RegisterRerankModel = () => { } } + const handleImportModel = async () => { + if (errorModelId) { + setErrorMsg('Please fill in valid value for Model Id') + return + } + + const split = modelId.split('/') + if (split.length !== 2) { + setErrorMsg('Please fill in valid value for Model Id') + return + } + + const repo_name = split[1] + const repo_split = repo_name.split(/[-_]/) + let lang = 'en' + for (const seg of repo_split) { + if (['zh', 'cn', 'chinese'].includes(seg.toLowerCase())) { + lang = 'zh' + break + } + } + setFormData({ + ...formData, + language: [lang], + }) + } + return ( <React.Fragment> <Box padding="20px"></Box> @@ -99,18 +163,90 @@ const RegisterRerankModel = () => { /> <Box padding="15px"></Box> - <TextField - label="Model Path" - size="small" - value={formData.model_uri} + + <label + style={{ + paddingLeft: 5, + }} + > + Model Source + </label> + + <RadioGroup + value={modelSource} onChange={(e) => { - setFormData({ - ...formData, - model_uri: e.target.value, - }) + setModelSource(e.target.value) }} - helperText="Provide the model directory path." - /> + > + <Box sx={styles.checkboxWrapper}> + {SOURCES.map((item) => ( + <Box sx={{ marginLeft: '10px' }}> + <FormControlLabel + value={item} + control={<Radio />} + label={SOURCES_DICT[item]} + /> + </Box> + ))} + </Box> + </RadioGroup> + <Box padding="15px"></Box> + + + {modelSource === 'self_hosted' && + <TextField + label="Model Path" + size="small" + value={formData.model_uri} + onChange={(e) => { + setFormData({ + ...formData, + model_uri: e.target.value, + }) + }} + helperText="Provide the model directory path." + />} + {modelSource === 'hub' && + <Box sx={styles.checkboxWrapper}> + + <TextField + sx={{ width: '400px' }} + label="Model Id" + size="small" + error={errorModelId} + value={modelId} + onChange={(e) => { + setModelId(e.target.value) + }} + placeholder="user/repo" + /> + + <FormControl variant="standard" + sx={{ marginLeft: '10px' }}> + <InputLabel id="hub-label">Hub</InputLabel> + <Select + labelId="hub-label" + value={hub} + label="Hub" + onChange={(e) => { + setHub(e.target.value) + }} + > + {SUPPORTED_HUBS.map((item) => ( + <MenuItem value={item}>{SUPPORTED_HUBS_DICT[item]}</MenuItem> + ))} + </Select> + </FormControl> + <Button + sx={{ marginLeft: '10px' }} + variant="contained" + color="primary" + onClick={handleImportModel} + > + Import Model + </Button> + </Box> + } <Box padding="15px"></Box> <label @@ -133,11 +269,11 @@ const RegisterRerankModel = () => { sx={ errorLanguage ? { - 'color': ERROR_COLOR, - '&.Mui-checked': { - color: ERROR_COLOR, - }, - } + 'color': ERROR_COLOR, + '&.Mui-checked': { + color: ERROR_COLOR, + }, + } : {} } />
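For reference, the repo-name language heuristic used by the rerank handleImportModel above can be read as the small standalone function sketched below; the helper name and the sample model ids are illustrative and not part of this patch, and the Model Id validation shown above is omitted.

const inferRerankLanguage = (modelId) => {
  // Take the repo part of 'user/repo' and split it on hyphens/underscores.
  const [, repoName = ''] = modelId.split('/')
  const segments = repoName.split(/[-_]/)
  // Default to English unless a Chinese marker appears in the repo name.
  for (const seg of segments) {
    if (['zh', 'cn', 'chinese'].includes(seg.toLowerCase())) {
      return 'zh'
    }
  }
  return 'en'
}

// e.g. inferRerankLanguage('BAAI/bge-reranker-base') -> 'en'
//      inferRerankLanguage('some-org/my-rerank-zh-v1') -> 'zh'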