Skip to content

Commit c689c6c

Browse files
authored
Azure OpenAI and OpenAI proxy support (#322)
* implement Azure OpenAI support and model ID labels * pre-commit
1 parent 3a2a467 commit c689c6c

File tree

8 files changed

+64
-9
lines changed

8 files changed

+64
-9
lines changed

docs/source/users/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ see the following interface:
238238
Each of the additional fields under "Language model" is required. These fields
239239
should contain the following data:
240240

241-
- **Local model ID**: The name of your endpoint. This can be retrieved from the
241+
- **Endpoint name**: The name of your endpoint. This can be retrieved from the
242242
AWS Console at the URL
243243
`https://<region>.console.aws.amazon.com/sagemaker/home?region=<region>#/endpoints`.
244244

packages/jupyter-ai-magics/jupyter_ai_magics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .providers import (
1414
AI21Provider,
1515
AnthropicProvider,
16+
AzureChatOpenAIProvider,
1617
BaseProvider,
1718
BedrockProvider,
1819
ChatOpenAINewProvider,

packages/jupyter-ai-magics/jupyter_ai_magics/providers.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from jsonpath_ng import parse
1111
from langchain import PromptTemplate
12-
from langchain.chat_models import ChatOpenAI
12+
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
1313
from langchain.llms import (
1414
AI21,
1515
Anthropic,
@@ -107,6 +107,9 @@ class Config:
107107
model_id_key: ClassVar[str] = ...
108108
"""Kwarg expected by the upstream LangChain provider."""
109109

110+
model_id_label: ClassVar[str] = ""
111+
"""Human-readable label of the model ID."""
112+
110113
pypi_package_deps: ClassVar[List[str]] = []
111114
"""List of PyPi package dependencies."""
112115

@@ -464,6 +467,40 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
464467
pypi_package_deps = ["openai"]
465468
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
466469

470+
fields = [
471+
TextField(
472+
key="openai_api_base", label="Base API URL (optional)", format="text"
473+
),
474+
TextField(
475+
key="openai_organization", label="Organization (optional)", format="text"
476+
),
477+
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
478+
]
479+
480+
481+
class AzureChatOpenAIProvider(BaseProvider, AzureChatOpenAI):
482+
id = "azure-chat-openai"
483+
name = "Azure OpenAI"
484+
models = ["*"]
485+
model_id_key = "deployment_name"
486+
model_id_label = "Deployment name"
487+
pypi_package_deps = ["openai"]
488+
auth_strategy = EnvAuthStrategy(name="OPENAI_API_KEY")
489+
registry = True
490+
491+
fields = [
492+
TextField(
493+
key="openai_api_base", label="Base API URL (required)", format="text"
494+
),
495+
TextField(
496+
key="openai_api_version", label="API version (required)", format="text"
497+
),
498+
TextField(
499+
key="openai_organization", label="Organization (optional)", format="text"
500+
),
501+
TextField(key="openai_proxy", label="Proxy (optional)", format="text"),
502+
]
503+
467504

468505
class JsonContentHandler(LLMContentHandler):
469506
content_type = "application/json"
@@ -501,6 +538,7 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
501538
name = "SageMaker endpoint"
502539
models = ["*"]
503540
model_id_key = "endpoint_name"
541+
model_id_label = "Endpoint name"
504542
# This all needs to be on one line of markdown, for use in a table
505543
help = (
506544
"Specify an endpoint name as the model ID. "
@@ -513,9 +551,13 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
513551
auth_strategy = AwsAuthStrategy()
514552
registry = True
515553
fields = [
516-
TextField(key="region_name", label="Region name", format="text"),
517-
MultilineTextField(key="request_schema", label="Request schema", format="json"),
518-
TextField(key="response_path", label="Response path", format="jsonpath"),
554+
TextField(key="region_name", label="Region name (required)", format="text"),
555+
MultilineTextField(
556+
key="request_schema", label="Request schema (required)", format="json"
557+
),
558+
TextField(
559+
key="response_path", label="Response path (required)", format="jsonpath"
560+
),
519561
]
520562

521563
def __init__(self, *args, **kwargs):

packages/jupyter-ai-magics/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ huggingface_hub = "jupyter_ai_magics:HfHubProvider"
6363
openai = "jupyter_ai_magics:OpenAIProvider"
6464
openai-chat = "jupyter_ai_magics:ChatOpenAIProvider"
6565
openai-chat-new = "jupyter_ai_magics:ChatOpenAINewProvider"
66+
azure-chat-openai = "jupyter_ai_magics:AzureChatOpenAIProvider"
6667
sagemaker-endpoint = "jupyter_ai_magics:SmEndpointProvider"
6768
amazon-bedrock = "jupyter_ai_magics:BedrockProvider"
6869

packages/jupyter-ai/jupyter_ai/handlers.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from asyncio import AbstractEventLoop
66
from dataclasses import asdict
7-
from typing import Dict, List
7+
from typing import TYPE_CHECKING, Dict, List
88

99
import tornado
1010
from jupyter_ai.chat_handlers import BaseChatHandler
@@ -29,6 +29,10 @@
2929
Message,
3030
)
3131

32+
if TYPE_CHECKING:
33+
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
34+
from jupyter_ai_magics.providers import BaseProvider
35+
3236

3337
class ChatHistoryHandler(BaseAPIHandler):
3438
"""Handler to return message history"""
@@ -237,7 +241,7 @@ def on_close(self):
237241

238242
class ModelProviderHandler(BaseAPIHandler):
239243
@property
240-
def lm_providers(self):
244+
def lm_providers(self) -> Dict[str, "BaseProvider"]:
241245
return self.settings["lm_providers"]
242246

243247
@web.authenticated
@@ -248,6 +252,10 @@ def get(self):
248252
if provider.id == "openai-chat":
249253
continue
250254

255+
optionals = {}
256+
if provider.model_id_label:
257+
optionals["model_id_label"] = provider.model_id_label
258+
251259
providers.append(
252260
ListProvidersEntry(
253261
id=provider.id,
@@ -256,6 +264,7 @@ def get(self):
256264
auth_strategy=provider.auth_strategy,
257265
registry=provider.registry,
258266
fields=provider.fields,
267+
**optionals,
259268
)
260269
)
261270

@@ -267,7 +276,7 @@ def get(self):
267276

268277
class EmbeddingsModelProviderHandler(BaseAPIHandler):
269278
@property
270-
def em_providers(self):
279+
def em_providers(self) -> Dict[str, "BaseEmbeddingsProvider"]:
271280
return self.settings["em_providers"]
272281

273282
@web.authenticated

packages/jupyter-ai/jupyter_ai/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class ListProvidersEntry(BaseModel):
7777

7878
id: str
7979
name: str
80+
model_id_label: Optional[str]
8081
models: List[str]
8182
auth_strategy: AuthStrategy
8283
registry: bool

packages/jupyter-ai/src/components/chat-settings.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ export function ChatSettings(): JSX.Element {
338338
</Select>
339339
{showLmLocalId && (
340340
<TextField
341-
label="Local model ID"
341+
label={lmProvider?.model_id_label || 'Local model ID'}
342342
value={lmLocalId}
343343
onChange={e => setLmLocalId(e.target.value)}
344344
fullWidth

packages/jupyter-ai/src/handler.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ export namespace AiService {
156156
export type ListProvidersEntry = {
157157
id: string;
158158
name: string;
159+
model_id_label?: string;
159160
models: string[];
160161
auth_strategy: AuthStrategy;
161162
registry: boolean;

0 commit comments

Comments
 (0)