diff --git a/README.md b/README.md
index 88b8316..c10171a 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,29 @@
 
 [![Github Actions Status](https://github.com/jupyter-ai-contrib/jupyter-ai-litellm/workflows/Build/badge.svg)](https://github.com/jupyter-ai-contrib/jupyter-ai-litellm/actions/workflows/build.yml)
 
-A JupyterLab extension that provides LiteLLM model abstraction
+A JupyterLab extension that provides LiteLLM model abstraction for Jupyter AI
 
-This extension is composed of a Python package named `jupyter_ai_litellm`.
+This extension is composed of a Python package named `jupyter_ai_litellm` that exposes LiteLLM's extensive catalog of language models through a standardized API.
+
+## Features
+
+- **Comprehensive Model Support**: Access to hundreds of chat and embedding models from many providers (OpenAI, Anthropic, Google, Cohere, Azure, AWS, and more) through LiteLLM's unified interface
+- **Standardized API**: Consistent REST API endpoints for model discovery and interaction
+- **Easy Integration**: Integrates with Jupyter AI to expand the set of available models
+
+## API Endpoints
+
+### Chat Models
+
+- `GET /api/ai/models/chat` - Returns a list of all available chat models
+
+The response lists model IDs in LiteLLM's `provider/model` format (e.g., `openai/gpt-4` or `anthropic/claude-3-sonnet`).
+
+### Model Lists
+
+The extension automatically discovers and categorizes models from LiteLLM's supported providers:
+- Chat models for conversational AI
+- Embedding models for vector representations
 
 ## Requirements
 
diff --git a/jupyter_ai_litellm/_version.py b/jupyter_ai_litellm/_version.py
new file mode 100644
index 0000000..6853c36
--- /dev/null
+++ b/jupyter_ai_litellm/_version.py
@@ -0,0 +1 @@
+__version__ = '0.0.0'
\ No newline at end of file
diff --git a/jupyter_ai_litellm/chat_models_rest_api.py b/jupyter_ai_litellm/chat_models_rest_api.py
new file mode 100644
index 0000000..d4012ee
--- /dev/null
+++ b/jupyter_ai_litellm/chat_models_rest_api.py
@@ -0,0 +1,29 @@
+from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
+from pydantic import BaseModel
+from tornado import web
+
+from .model_list import CHAT_MODELS
+
+
+class ChatModelsRestAPI(BaseAPIHandler):
+    """
+    A Tornado handler that defines the REST API served on the
+    `/api/ai/models/chat` endpoint.
+
+    - `GET /api/ai/models/chat`: returns the list of all chat models.
+
+    - `GET /api/ai/models/chat?id=`: returns info on that model (TODO)
+    """
+
+    @web.authenticated
+    def get(self):
+        response = ListChatModelsResponse(chat_models=CHAT_MODELS)
+        self.finish(response.model_dump_json())
+
+
+class ListChatModelsResponse(BaseModel):
+    chat_models: list[str]
+
+
+class ListEmbeddingModelsResponse(BaseModel):
+    embedding_models: list[str]
diff --git a/jupyter_ai_litellm/handlers.py b/jupyter_ai_litellm/handlers.py
index 399630f..d103382 100644
--- a/jupyter_ai_litellm/handlers.py
+++ b/jupyter_ai_litellm/handlers.py
@@ -4,6 +4,8 @@
 from jupyter_server.utils import url_path_join
 import tornado
 
+from .chat_models_rest_api import ChatModelsRestAPI
+
 class RouteHandler(APIHandler):
     # The following decorator should be present on all verb methods (head, get, post,
     # patch, put, delete, options) to ensure only authorized user can request the
@@ -19,6 +21,9 @@
 def setup_handlers(web_app):
     host_pattern = ".*$"
     base_url = web_app.settings["base_url"]
     route_pattern = url_path_join(base_url, "jupyter-ai-litellm", "get-example")
-    handlers = [(route_pattern, RouteHandler)]
+    handlers = [
+        (route_pattern, RouteHandler),
+        (url_path_join(base_url, "api/ai/models/chat"), ChatModelsRestAPI),
+    ]
     web_app.add_handlers(host_pattern, handlers)
diff --git a/jupyter_ai_litellm/model_list.py b/jupyter_ai_litellm/model_list.py
new file mode 100644
index 0000000..acd899b
--- /dev/null
+++ b/jupyter_ai_litellm/model_list.py
@@ -0,0 +1,36 @@
+from litellm import all_embedding_models, models_by_provider
+
+chat_model_ids = []
+embedding_model_ids = []
+embedding_model_set = set(all_embedding_models)
+
+for provider_name in models_by_provider:
+    for model_name in models_by_provider[provider_name]:
+        model_name: str = model_name  # annotate the loop variable for type checkers
+
+        if model_name.startswith(f"{provider_name}/"):
+            model_id = model_name
+        else:
+            model_id = f"{provider_name}/{model_name}"
+
+        is_embedding = (
+            model_name in embedding_model_set
+            or model_id in embedding_model_set
+            or "embed" in model_id  # heuristic: the name suggests an embedding model
+        )
+
+        if is_embedding:
+            embedding_model_ids.append(model_id)
+        else:
+            chat_model_ids.append(model_id)
+
+
+CHAT_MODELS = sorted(chat_model_ids)
+"""
+List of chat model IDs, following the `litellm` syntax.
+"""
+
+EMBEDDING_MODELS = sorted(embedding_model_ids)
+"""
+List of embedding model IDs, following the `litellm` syntax.
+"""
diff --git a/jupyter_ai_litellm/tests/test_handlers.py b/jupyter_ai_litellm/tests/test_handlers.py
index 2fa9ece..8f17be9 100644
--- a/jupyter_ai_litellm/tests/test_handlers.py
+++ b/jupyter_ai_litellm/tests/test_handlers.py
@@ -10,4 +10,16 @@
     payload = json.loads(response.body)
     assert payload == {
         "data": "This is /jupyter-ai-litellm/get-example endpoint!"
-    }
\ No newline at end of file
+    }
+
+async def test_get_chat_models(jp_fetch):
+    # When
+    response = await jp_fetch("api", "ai", "models", "chat")
+
+    # Then
+    assert response.code == 200
+    payload = json.loads(response.body)
+    chat_models = payload.get("chat_models")
+
+    assert chat_models
+    assert len(chat_models) > 0
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index f5a48d1..c17bbb3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,7 +27,8 @@
 authors = [
     { name = "Project Jupyter", email = "jupyter@googlegroups.com" },
 ]
 dependencies = [
-    "jupyter_server>=2.4.0,<3"
+    "jupyter_server>=2.4.0,<3",
+    "litellm>=1.73,<2",
 ]
 dynamic = ["version"]
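A quick way to sanity-check the new endpoint once the extension is installed is to query it over HTTP. The sketch below is not part of the diff: the base URL and token are placeholders for your own server, and only the `chat_models` response key is taken from `ListChatModelsResponse` above.

```python
# Minimal sketch: query the new chat-models endpoint on a running Jupyter
# server. BASE_URL and TOKEN are placeholders for your deployment.
import requests

BASE_URL = "http://localhost:8888"  # assumed local Jupyter server
TOKEN = "<your-server-token>"       # placeholder, e.g. from `jupyter server list`

# The handler is decorated with @web.authenticated, so the request
# must carry the server token.
resp = requests.get(
    f"{BASE_URL}/api/ai/models/chat",
    headers={"Authorization": f"token {TOKEN}"},
)
resp.raise_for_status()

# The body matches ListChatModelsResponse: {"chat_models": [...]}
print(resp.json()["chat_models"][:5])
```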
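The ID normalization in `model_list.py` is worth spelling out, since some providers already prefix their model names with the provider. Below is a standalone restatement of that rule; the helper name `to_model_id` is introduced here for illustration only, and the inputs are illustrative.

```python
# Restates the normalization logic from model_list.py as a function.
# `to_model_id` does not exist in the package; it is illustrative only.
def to_model_id(provider_name: str, model_name: str) -> str:
    """Return a LiteLLM `provider/model` ID without double-prefixing."""
    if model_name.startswith(f"{provider_name}/"):
        return model_name
    return f"{provider_name}/{model_name}"


assert to_model_id("openai", "gpt-4") == "openai/gpt-4"
# Already-prefixed names are kept as-is rather than becoming "openai/openai/gpt-4".
assert to_model_id("openai", "openai/gpt-4") == "openai/gpt-4"
```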
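Note that the diff defines `ListEmbeddingModelsResponse` and builds `EMBEDDING_MODELS` but registers no embedding endpoint. If one follows, it would presumably mirror the chat handler; the sketch below assumes a hypothetical `/api/ai/models/embedding` route that this diff does not add.

```python
# Hypothetical sketch only: neither this handler nor its route exists in the diff.
from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
from tornado import web

from .chat_models_rest_api import ListEmbeddingModelsResponse
from .model_list import EMBEDDING_MODELS


class EmbeddingModelsRestAPI(BaseAPIHandler):
    """Would serve `GET /api/ai/models/embedding` (assumed route)."""

    @web.authenticated
    def get(self):
        # Serialize the embedding model list the same way the chat handler does.
        response = ListEmbeddingModelsResponse(embedding_models=EMBEDDING_MODELS)
        self.finish(response.model_dump_json())
```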