Skip to content

Commit 63beb95

Browse files
authored
Support for Azure (#94)
* update dependencies * setup for azure
1 parent e1339c5 commit 63beb95

File tree

2 files changed

+152
-74
lines changed

2 files changed

+152
-74
lines changed

pyproject.toml

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,13 @@ license = {text = "MIT license"}
2727
dependencies = [
2828
"Click>=7.0",
2929
"requests>=2.20",
30-
"responses>=0.23",
3130
"aiohttp>=3.8",
3231
"tqdm>=4.60",
33-
"docstring_parser>=0.10",
32+
# "docstring_parser>=0.10",
3433
"python-dotenv>=0.17.0",
3534
"loguru>=0.7",
36-
"dnspython",
3735
"fastmcp; python_version >= '3.10'",
38-
"alibabacloud_alidns20150109==3.5.10",
39-
"alibabacloud_tea_openapi>=0.3.0",
40-
"tencentcloud-sdk-python",
41-
"batch_executor",
36+
"batch_executor>=0.3.0",
4237
"colorama",
4338
"fastapi",
4439
"uvicorn",
@@ -55,7 +50,11 @@ dev = [
5550
"twine",
5651
"pytest>=3",
5752
"pytest-asyncio",
58-
"pytest-mock"
53+
"pytest-mock",
54+
"alibabacloud_alidns20150109==3.5.10",
55+
"alibabacloud_tea_openapi>=0.3.0",
56+
"tencentcloud-sdk-python",
57+
"dnspython",
5958
]
6059

6160
[project.urls]

src/chattool/core/chattype.py

Lines changed: 145 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import os
44
import logging
5+
import hashlib
56
from typing import List, Dict, Union, Optional, AsyncGenerator, Any
67
from chattool.utils import setup_logger
78
from .request import HTTPClient
@@ -45,6 +46,12 @@ def __init__(self,
4546
model = OPENAI_API_MODEL
4647
if api_base is None:
4748
api_base = OPENAI_API_BASE
49+
if headers is None:
50+
headers = {
51+
'Content-Type': 'application/json'
52+
}
53+
if not headers.get('Authorization') and api_key:
54+
headers['Authorization'] = f'Bearer {api_key}'
4855
super().__init__(
4956
logger=logger,
5057
api_base=api_base,
@@ -64,6 +71,40 @@ def __init__(self,
6471
self._chat_log: List[Dict] = messages or []
6572
self._last_response: Optional[ChatResponse] = None
6673

74+
def _prepare_data(self,
75+
messages: List[Dict[str, str]],
76+
model: Optional[str] = None,
77+
temperature: Optional[float] = None,
78+
top_p: Optional[float] = None,
79+
max_tokens: Optional[int] = None,
80+
**kwargs) -> Dict[str, Any]:
81+
"""准备请求数据"""
82+
# 合并参数
83+
model = model if model is not None else self.model
84+
data = {
85+
'model': model,
86+
"messages": messages,
87+
**kwargs
88+
}
89+
if temperature is not None:
90+
data['temperature'] = temperature
91+
if top_p is not None:
92+
data['top_p'] = top_p
93+
if max_tokens is not None:
94+
data['max_tokens'] = max_tokens
95+
for k in self.config.kwargs: # 默认的额外参数
96+
if k not in data and self.config.get(k) is not None:
97+
data[k] = self.config.get(k)
98+
return data
99+
100+
def _get_params(self) -> Optional[Dict[str, str]]:
101+
"""合并默认查询参数和自定义参数"""
102+
return None
103+
104+
def _get_uri(self)->str:
105+
"""获取请求 URI"""
106+
return '/chat/completions'
107+
67108
def chat_completion(
68109
self,
69110
# data 部分
@@ -73,8 +114,6 @@ def chat_completion(
73114
top_p: Optional[float] = None,
74115
max_tokens: Optional[int] = None,
75116
stream: bool = False,
76-
# 默认请求地址
77-
uri: str = '/chat/completions',
78117
**kwargs # 其他参数,传入给 data 部分
79118
) -> Dict[str, Any]:
80119
"""
@@ -87,35 +126,21 @@ def chat_completion(
87126
top_p: top_p 参数
88127
max_tokens: 最大token数
89128
stream: 是否使用流式响应
90-
uri: 请求 URI
91-
headers: 自定义请求头(OpenAI 特有参数)
92-
params: 自定义查询参数(Azure 特有参数)
93129
**kwargs: 其他参数
94130
"""
95131
# 合并参数
96-
model = model if model is not None else self.model
97-
data = {
98-
'model': model,
99-
"messages": messages,
132+
data = self._prepare_data(
133+
messages=messages,
134+
model=model,
135+
temperature=temperature,
136+
top_p=top_p,
137+
max_tokens=max_tokens,
100138
**kwargs
101-
}
102-
if temperature is not None:
103-
data['temperature'] = temperature
104-
if top_p is not None:
105-
data['top_p'] = top_p
106-
if max_tokens is not None:
107-
data['max_tokens'] = max_tokens
139+
)
108140
if stream:
109141
# NOTE: 流式响应请使用 chat_completion_stream_async
110142
pass
111-
for k in self.config.kwargs: # 默认的额外参数
112-
if k not in data and self.config.get(k) is not None:
113-
data[k] = self.config.get(k)
114-
headers = {
115-
'Authorization': f'Bearer {self.api_key}',
116-
'Content-Type': 'application/json',
117-
}
118-
response = self.post(uri, data=data, headers=headers)
143+
response = self.post(self._get_uri(), data=data, params=self._get_params())
119144
return response.json()
120145

121146
async def async_chat_completion(
@@ -126,34 +151,22 @@ async def async_chat_completion(
126151
top_p: Optional[float] = None,
127152
max_tokens: Optional[int] = None,
128153
stream: bool = False,
129-
uri: str = '/chat/completions',
130154
**kwargs
131155
) -> Union[Dict[str, Any], AsyncGenerator]:
132156
"""OpenAI Chat Completion API (异步版本)"""
133157
# 合并参数
134-
model = model if model is not None else self.model
135-
data = {
136-
'model': model,
137-
"messages": messages,
158+
data = self._prepare_data(
159+
messages=messages,
160+
model=model,
161+
temperature=temperature,
162+
top_p=top_p,
163+
max_tokens=max_tokens,
138164
**kwargs
139-
}
140-
if temperature is not None:
141-
data['temperature'] = temperature
142-
if top_p is not None:
143-
data['top_p'] = top_p
144-
if max_tokens is not None:
145-
data['max_tokens'] = max_tokens
165+
)
146166
if stream:
147167
# NOTE: 流式响应请使用 chat_completion_stream_async
148168
pass
149-
for k in self.config.kwargs: # 默认的额外参数
150-
if k not in data and self.config.get(k) is not None:
151-
data[k] = self.config.get(k)
152-
headers = {
153-
'Authorization': f'Bearer {self.api_key}',
154-
'Content-Type': 'application/json',
155-
}
156-
response = await self.async_post(uri, data=data, headers=headers)
169+
response = await self.async_post(self._get_uri(), data=data, params=self._get_params())
157170
return response.json()
158171

159172
async def chat_completion_stream_async(
@@ -163,39 +176,30 @@ async def chat_completion_stream_async(
163176
temperature: Optional[float] = None,
164177
top_p: Optional[float] = None,
165178
max_tokens: Optional[int] = None,
166-
uri: str = '/chat/completions',
167179
stream: bool = True,
168180
**kwargs
169181
):
170182
"""OpenAI Chat Completion API 流式响应(异步版本)"""
171-
data = {
172-
'model': model if model is not None else self.model,
173-
"messages": messages,
174-
'stream': True,
183+
# 合并参数
184+
data = self._prepare_data(
185+
messages=messages,
186+
model=model,
187+
temperature=temperature,
188+
top_p=top_p,
189+
max_tokens=max_tokens,
190+
stream=True,
175191
**kwargs
176-
}
177-
if temperature is not None:
178-
data['temperature'] = temperature
179-
if top_p is not None:
180-
data['top_p'] = top_p
181-
if max_tokens is not None:
182-
data['max_tokens'] = max_tokens
183-
for k in self.config.kwargs: # 默认的额外参数
184-
if k not in data and self.config.get(k) is not None:
185-
data[k] = self.config.get(k)
186-
headers = {
187-
'Authorization': f'Bearer {self.api_key}',
188-
'Content-Type': 'application/json',
189-
}
192+
)
190193
# 构建完整 URL
191-
url = self._build_url(uri)
194+
url = self._build_url(self._get_uri())
192195
client = self._get_async_client()
193196

194197
async with client.stream(
195198
method="POST",
196199
url=url,
197200
json=data,
198-
headers=headers
201+
headers=self.config.headers,
202+
params=self._get_params()
199203
) as stream:
200204
# 检查响应状态码
201205
if stream.status_code >= 400:
@@ -498,4 +502,79 @@ def __str__(self) -> str:
498502

499503
class AzureChat(Chat):
500504
"""Azure OpenAI Chat 实现 - 继承 Chat 复用逻辑"""
501-
pass
505+
def __init__(self,
506+
messages: Optional[Union[str, List[Dict[str, str]]]]=None,
507+
logger: logging.Logger=None,
508+
api_key: str=None,
509+
api_base: str=None,
510+
api_version: str=None,
511+
model: str=None,
512+
timeout: float = 0,
513+
max_retries: int = 3,
514+
retry_delay: float = 1.0,
515+
headers: Optional[Dict[str, str]]=None,
516+
**kwargs):
517+
"""初始化聊天模型
518+
519+
Args:
520+
logger: 日志记录器
521+
api_key: API 密钥
522+
api_base: API 基础 URL
523+
model: 模型名称
524+
timeout: 请求超时时间(秒)
525+
max_retries: 最大重试次数
526+
retry_delay: 重试延迟时间(秒)
527+
headers: 自定义 HTTP 头
528+
**kwargs: 其他配置参数,传入给 data 部分
529+
"""
530+
logger = logger or setup_logger('Chat')
531+
if api_key is None:
532+
api_key = AZURE_OPENAI_API_KEY
533+
if model is None:
534+
model = AZURE_OPENAI_API_MODEL
535+
if api_base is None:
536+
api_base = AZURE_OPENAI_ENDPOINT
537+
if api_version is None:
538+
api_version = AZURE_OPENAI_API_VERSION
539+
if headers is None:
540+
headers = {
541+
'Content-Type': 'application/json'
542+
}
543+
HTTPClient.__init__(
544+
self,
545+
logger=logger,
546+
api_base=api_base,
547+
timeout=timeout,
548+
max_retries=max_retries,
549+
retry_delay=retry_delay,
550+
headers=headers,
551+
**kwargs
552+
)
553+
554+
self.api_key = api_key
555+
self.api_version = api_version
556+
self.model = model
557+
558+
# 初始化对话历史
559+
if isinstance(messages, str):
560+
messages = [{"role": "user", "content": messages}]
561+
self._chat_log: List[Dict] = messages or []
562+
self._last_response: Optional[ChatResponse] = None
563+
564+
def _get_params(self) -> Optional[Dict[str, str]]:
565+
"""合并默认查询参数和自定义参数"""
566+
params = {}
567+
if self.api_version:
568+
params['api-version'] = self.api_version
569+
if self.api_key:
570+
params['ak'] = self.api_key
571+
return params
572+
573+
def _get_uri(self) -> str:
574+
"""获取 API URI"""
575+
return None
576+
577+
def _generate_log_id(self, messages: List[Dict[str, str]]) -> str:
578+
"""生成请求的 log ID"""
579+
content = str(messages).encode()
580+
return hashlib.sha256(content).hexdigest()

0 commit comments

Comments (0)