24 changes: 22 additions & 2 deletions README.md
@@ -2,9 +2,29 @@

[![Github Actions Status](https://github.com/jupyter-ai-contrib/jupyter-ai-litellm/workflows/Build/badge.svg)](https://github.com/jupyter-ai-contrib/jupyter-ai-litellm/actions/workflows/build.yml)

-A JupyterLab extension that provides LiteLLM model abstraction
+A JupyterLab extension that provides LiteLLM model abstraction for Jupyter AI

-This extension is composed of a Python package named `jupyter_ai_litellm`.
+This extension is composed of a Python package named `jupyter_ai_litellm` that exposes LiteLLM's extensive catalog of language models through a standardized API.

## Features

- **Comprehensive Model Support**: Access to hundreds of chat and embedding models from various providers (OpenAI, Anthropic, Google, Cohere, Azure, AWS, and more) through LiteLLM's unified interface
- **Standardized API**: Consistent REST API endpoints for model discovery and interaction
- **Easy Integration**: Seamlessly integrates with Jupyter AI to expand available model options

## API Endpoints

### Chat Models

- `GET /api/ai/models/chat` - Returns a list of all available chat models

The response includes model IDs in LiteLLM's `provider/model` format (e.g., `openai/gpt-4`, `anthropic/claude-3-sonnet`).
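
For illustration, a minimal sketch of calling this endpoint from Python. The server URL and token are placeholders, and the `requests` package is an assumption, not a dependency of this extension:

```python
import requests

# Hypothetical values; substitute your own Jupyter server URL and token.
BASE_URL = "http://localhost:8888"
TOKEN = "<your-jupyter-token>"

response = requests.get(
    f"{BASE_URL}/api/ai/models/chat",
    headers={"Authorization": f"token {TOKEN}"},
)
response.raise_for_status()

# The handler returns a JSON object with a single `chat_models` key, e.g.
# {"chat_models": ["anthropic/claude-3-sonnet", "openai/gpt-4", ...]}
print(response.json()["chat_models"][:5])
```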

### Model Lists

The extension automatically discovers and categorizes models from LiteLLM's supported providers, as sketched after this list:
- Chat models for conversational AI
- Embedding models for vector representations
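
A sketch of consuming these lists directly from Python, assuming `jupyter_ai_litellm` is installed in the current environment:

```python
from jupyter_ai_litellm.model_list import CHAT_MODELS, EMBEDDING_MODELS

# IDs follow LiteLLM's `provider/model` convention, so filtering by
# provider reduces to a prefix check.
anthropic_chat = [m for m in CHAT_MODELS if m.startswith("anthropic/")]

print(f"{len(CHAT_MODELS)} chat models, {len(EMBEDDING_MODELS)} embedding models")
print(anthropic_chat[:3])
```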

## Requirements

1 change: 1 addition & 0 deletions jupyter_ai_litellm/_version.py
@@ -0,0 +1 @@
__version__ = '0.0.0'
29 changes: 29 additions & 0 deletions jupyter_ai_litellm/chat_models_rest_api.py
@@ -0,0 +1,29 @@
from jupyter_server.base.handlers import APIHandler as BaseAPIHandler
from pydantic import BaseModel
from tornado import web

from .model_list import CHAT_MODELS


class ChatModelsRestAPI(BaseAPIHandler):
    """
    A Tornado handler that defines the REST API served on the
    `/api/ai/models/chat` endpoint.

    - `GET /api/ai/models/chat`: returns a list of all chat models.

    - `GET /api/ai/models/chat?id=<model_id>`: returns info on that model (TODO)
    """

    @web.authenticated
    def get(self):
        response = ListChatModelsResponse(chat_models=CHAT_MODELS)
        self.finish(response.model_dump_json())


class ListChatModelsResponse(BaseModel):
    chat_models: list[str]


class ListEmbeddingModelsResponse(BaseModel):
    embedding_models: list[str]
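
One possible shape for the `?id=<model_id>` lookup flagged as TODO above. This is a sketch, not part of the PR: the subclass name is hypothetical, and it assumes LiteLLM's `get_model_info` (which raises for unmapped models) is an acceptable metadata source:

```python
import json

import litellm
from tornado import web

from .chat_models_rest_api import ChatModelsRestAPI


class ChatModelsWithInfoAPI(ChatModelsRestAPI):
    """Hypothetical handler extending the list endpoint with `?id=` lookup."""

    @web.authenticated
    def get(self):
        model_id = self.get_query_argument("id", default=None)
        if model_id is None:
            # No `id` given: fall back to listing all chat models.
            return super().get()
        try:
            info = litellm.get_model_info(model=model_id)
        except Exception:
            raise web.HTTPError(404, f"Unknown model: {model_id}")
        # `get_model_info` returns a plain dict of JSON-serializable fields.
        self.finish(json.dumps(info))
```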
8 changes: 7 additions & 1 deletion jupyter_ai_litellm/handlers.py
@@ -4,6 +4,8 @@
from jupyter_server.utils import url_path_join
import tornado

from .chat_models_rest_api import ChatModelsRestAPI

class RouteHandler(APIHandler):
    # The following decorator should be present on all verb methods (head, get, post,
    # patch, put, delete, options) to ensure only authorized user can request the
@@ -19,6 +21,10 @@ def setup_handlers(web_app):
host_pattern = ".*$"

base_url = web_app.settings["base_url"]
print(f"Base url is {base_url}")
route_pattern = url_path_join(base_url, "jupyter-ai-litellm", "get-example")
handlers = [(route_pattern, RouteHandler)]
handlers = [
(route_pattern, RouteHandler),
(url_path_join(base_url, "api/ai/models/chat") + r"(?:\?.*)?", ChatModelsRestAPI)
]
web_app.add_handlers(host_pattern, handlers)
36 changes: 36 additions & 0 deletions jupyter_ai_litellm/model_list.py
@@ -0,0 +1,36 @@
from litellm import all_embedding_models, models_by_provider

chat_model_ids = []
embedding_model_ids = []
embedding_model_set = set(all_embedding_models)

for provider_name in models_by_provider:
    for model_name in models_by_provider[provider_name]:
        # Normalize to LiteLLM's `provider/model` ID syntax; some catalog
        # entries already carry the provider prefix.
        if model_name.startswith(f"{provider_name}/"):
            model_id = model_name
        else:
            model_id = f"{provider_name}/{model_name}"

        # Treat the model as an embedding model if either form appears in
        # LiteLLM's embedding list, with a substring check as a fallback
        # heuristic.
        is_embedding = (
            model_name in embedding_model_set
            or model_id in embedding_model_set
            or "embed" in model_id
        )

        if is_embedding:
            embedding_model_ids.append(model_id)
        else:
            chat_model_ids.append(model_id)


CHAT_MODELS = sorted(chat_model_ids)
"""
List of chat model IDs, following the `litellm` syntax.
"""

EMBEDDING_MODELS = sorted(embedding_model_ids)
"""
List of embedding model IDs, following the `litellm` syntax.
"""
14 changes: 13 additions & 1 deletion jupyter_ai_litellm/tests/test_handlers.py
@@ -10,4 +10,16 @@ async def test_get_example(jp_fetch):
    payload = json.loads(response.body)
    assert payload == {
        "data": "This is /jupyter-ai-litellm/get-example endpoint!"
    }

async def test_get_chat_models(jp_fetch):
    # When
    response = await jp_fetch("api", "ai", "models", "chat")

    # Then
    assert response.code == 200
    payload = json.loads(response.body)
    chat_models = payload.get("chat_models")

    assert chat_models
    assert len(chat_models) > 0
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -27,7 +27,8 @@ authors = [
{ name = "Project Jupyter", email = "[email protected]" },
]
dependencies = [
"jupyter_server>=2.4.0,<3"
"jupyter_server>=2.4.0,<3",
"litellm>=1.73,<2",
]
dynamic = ["version"]
