99
1010from jsonpath_ng import parse
1111from langchain import PromptTemplate
12- from langchain .chat_models import ChatOpenAI
12+ from langchain .chat_models import AzureChatOpenAI , ChatOpenAI
1313from langchain .llms import (
1414 AI21 ,
1515 Anthropic ,
@@ -107,6 +107,9 @@ class Config:
107107 model_id_key : ClassVar [str ] = ...
108108 """Kwarg expected by the upstream LangChain provider."""
109109
110+ model_id_label : ClassVar [str ] = ""
111+ """Human-readable label of the model ID."""
112+
110113 pypi_package_deps : ClassVar [List [str ]] = []
111114 """List of PyPi package dependencies."""
112115
@@ -464,6 +467,40 @@ class ChatOpenAINewProvider(BaseProvider, ChatOpenAI):
464467 pypi_package_deps = ["openai" ]
465468 auth_strategy = EnvAuthStrategy (name = "OPENAI_API_KEY" )
466469
470+ fields = [
471+ TextField (
472+ key = "openai_api_base" , label = "Base API URL (optional)" , format = "text"
473+ ),
474+ TextField (
475+ key = "openai_organization" , label = "Organization (optional)" , format = "text"
476+ ),
477+ TextField (key = "openai_proxy" , label = "Proxy (optional)" , format = "text" ),
478+ ]
479+
480+
481+ class AzureChatOpenAIProvider (BaseProvider , AzureChatOpenAI ):
482+ id = "azure-chat-openai"
483+ name = "Azure OpenAI"
484+ models = ["*" ]
485+ model_id_key = "deployment_name"
486+ model_id_label = "Deployment name"
487+ pypi_package_deps = ["openai" ]
488+ auth_strategy = EnvAuthStrategy (name = "OPENAI_API_KEY" )
489+ registry = True
490+
491+ fields = [
492+ TextField (
493+ key = "openai_api_base" , label = "Base API URL (required)" , format = "text"
494+ ),
495+ TextField (
496+ key = "openai_api_version" , label = "API version (required)" , format = "text"
497+ ),
498+ TextField (
499+ key = "openai_organization" , label = "Organization (optional)" , format = "text"
500+ ),
501+ TextField (key = "openai_proxy" , label = "Proxy (optional)" , format = "text" ),
502+ ]
503+
467504
468505class JsonContentHandler (LLMContentHandler ):
469506 content_type = "application/json"
@@ -501,6 +538,7 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
501538 name = "SageMaker endpoint"
502539 models = ["*" ]
503540 model_id_key = "endpoint_name"
541+ model_id_label = "Endpoint name"
504542 # This all needs to be on one line of markdown, for use in a table
505543 help = (
506544 "Specify an endpoint name as the model ID. "
@@ -513,9 +551,13 @@ class SmEndpointProvider(BaseProvider, SagemakerEndpoint):
513551 auth_strategy = AwsAuthStrategy ()
514552 registry = True
515553 fields = [
516- TextField (key = "region_name" , label = "Region name" , format = "text" ),
517- MultilineTextField (key = "request_schema" , label = "Request schema" , format = "json" ),
518- TextField (key = "response_path" , label = "Response path" , format = "jsonpath" ),
554+ TextField (key = "region_name" , label = "Region name (required)" , format = "text" ),
555+ MultilineTextField (
556+ key = "request_schema" , label = "Request schema (required)" , format = "json"
557+ ),
558+ TextField (
559+ key = "response_path" , label = "Response path (required)" , format = "jsonpath"
560+ ),
519561 ]
520562
521563 def __init__ (self , * args , ** kwargs ):
0 commit comments