Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,13 @@ license = {text = "MIT license"}
dependencies = [
"Click>=7.0",
"requests>=2.20",
"responses>=0.23",
"aiohttp>=3.8",
"tqdm>=4.60",
"docstring_parser>=0.10",
# "docstring_parser>=0.10",
"python-dotenv>=0.17.0",
"loguru>=0.7",
"dnspython",
"fastmcp; python_version >= '3.10'",
"alibabacloud_alidns20150109==3.5.10",
"alibabacloud_tea_openapi>=0.3.0",
"tencentcloud-sdk-python",
"batch_executor",
"batch_executor>=0.3.0",
"colorama",
"fastapi",
"uvicorn",
Expand All @@ -55,7 +50,11 @@ dev = [
"twine",
"pytest>=3",
"pytest-asyncio",
"pytest-mock"
"pytest-mock",
"alibabacloud_alidns20150109==3.5.10",
"alibabacloud_tea_openapi>=0.3.0",
"tencentcloud-sdk-python",
"dnspython",
]

[project.urls]
Expand Down
211 changes: 145 additions & 66 deletions src/chattool/core/chattype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import logging
import hashlib
from typing import List, Dict, Union, Optional, AsyncGenerator, Any
from chattool.utils import setup_logger
from .request import HTTPClient
Expand Down Expand Up @@ -45,6 +46,12 @@ def __init__(self,
model = OPENAI_API_MODEL
if api_base is None:
api_base = OPENAI_API_BASE
if headers is None:
headers = {
'Content-Type': 'application/json'
}
if not headers.get('Authorization') and api_key:
headers['Authorization'] = f'Bearer {api_key}'
super().__init__(
logger=logger,
api_base=api_base,
Expand All @@ -64,6 +71,40 @@ def __init__(self,
self._chat_log: List[Dict] = messages or []
self._last_response: Optional[ChatResponse] = None

def _prepare_data(self,
messages: List[Dict[str, str]],
model: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
**kwargs) -> Dict[str, Any]:
"""准备请求数据"""
# 合并参数
model = model if model is not None else self.model
data = {
'model': model,
"messages": messages,
**kwargs
}
if temperature is not None:
data['temperature'] = temperature
if top_p is not None:
data['top_p'] = top_p
if max_tokens is not None:
data['max_tokens'] = max_tokens
for k in self.config.kwargs: # 默认的额外参数
if k not in data and self.config.get(k) is not None:
data[k] = self.config.get(k)
return data

def _get_params(self) -> Optional[Dict[str, str]]:
"""合并默认查询参数和自定义参数"""
return None

def _get_uri(self)->str:
"""获取请求 URI"""
return '/chat/completions'

def chat_completion(
self,
# data 部分
Expand All @@ -73,8 +114,6 @@ def chat_completion(
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
stream: bool = False,
# 默认请求地址
uri: str = '/chat/completions',
**kwargs # 其他参数,传入给 data 部分
) -> Dict[str, Any]:
"""
Expand All @@ -87,35 +126,21 @@ def chat_completion(
top_p: top_p 参数
max_tokens: 最大token数
stream: 是否使用流式响应
uri: 请求 URI
headers: 自定义请求头(OpenAI 特有参数)
params: 自定义查询参数(Azure 特有参数)
**kwargs: 其他参数
"""
# 合并参数
model = model if model is not None else self.model
data = {
'model': model,
"messages": messages,
data = self._prepare_data(
messages=messages,
model=model,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
**kwargs
}
if temperature is not None:
data['temperature'] = temperature
if top_p is not None:
data['top_p'] = top_p
if max_tokens is not None:
data['max_tokens'] = max_tokens
)
if stream:
# NOTE: 流式响应请使用 chat_completion_stream_async
pass
for k in self.config.kwargs: # 默认的额外参数
if k not in data and self.config.get(k) is not None:
data[k] = self.config.get(k)
headers = {
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
}
response = self.post(uri, data=data, headers=headers)
response = self.post(self._get_uri(), data=data, params=self._get_params())
return response.json()

async def async_chat_completion(
Expand All @@ -126,34 +151,22 @@ async def async_chat_completion(
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
stream: bool = False,
uri: str = '/chat/completions',
**kwargs
) -> Union[Dict[str, Any], AsyncGenerator]:
"""OpenAI Chat Completion API (异步版本)"""
# 合并参数
model = model if model is not None else self.model
data = {
'model': model,
"messages": messages,
data = self._prepare_data(
messages=messages,
model=model,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
**kwargs
}
if temperature is not None:
data['temperature'] = temperature
if top_p is not None:
data['top_p'] = top_p
if max_tokens is not None:
data['max_tokens'] = max_tokens
)
if stream:
# NOTE: 流式响应请使用 chat_completion_stream_async
pass
for k in self.config.kwargs: # 默认的额外参数
if k not in data and self.config.get(k) is not None:
data[k] = self.config.get(k)
headers = {
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
}
response = await self.async_post(uri, data=data, headers=headers)
response = await self.async_post(self._get_uri(), data=data, params=self._get_params())
return response.json()

async def chat_completion_stream_async(
Expand All @@ -163,39 +176,30 @@ async def chat_completion_stream_async(
temperature: Optional[float] = None,
top_p: Optional[float] = None,
max_tokens: Optional[int] = None,
uri: str = '/chat/completions',
stream: bool = True,
**kwargs
):
"""OpenAI Chat Completion API 流式响应(异步版本)"""
data = {
'model': model if model is not None else self.model,
"messages": messages,
'stream': True,
# 合并参数
data = self._prepare_data(
messages=messages,
model=model,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
stream=True,
**kwargs
}
if temperature is not None:
data['temperature'] = temperature
if top_p is not None:
data['top_p'] = top_p
if max_tokens is not None:
data['max_tokens'] = max_tokens
for k in self.config.kwargs: # 默认的额外参数
if k not in data and self.config.get(k) is not None:
data[k] = self.config.get(k)
headers = {
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
}
)
# 构建完整 URL
url = self._build_url(uri)
url = self._build_url(self._get_uri())
client = self._get_async_client()

async with client.stream(
method="POST",
url=url,
json=data,
headers=headers
headers=self.config.headers,
params=self._get_params()
) as stream:
# 检查响应状态码
if stream.status_code >= 400:
Expand Down Expand Up @@ -498,4 +502,79 @@ def __str__(self) -> str:

class AzureChat(Chat):
"""Azure OpenAI Chat 实现 - 继承 Chat 复用逻辑"""
pass
def __init__(self,
messages: Optional[Union[str, List[Dict[str, str]]]]=None,
logger: logging.Logger=None,
api_key: str=None,
api_base: str=None,
api_version: str=None,
model: str=None,
timeout: float = 0,
max_retries: int = 3,
retry_delay: float = 1.0,
headers: Optional[Dict[str, str]]=None,
**kwargs):
"""初始化聊天模型

Args:
logger: 日志记录器
api_key: API 密钥
api_base: API 基础 URL
model: 模型名称
timeout: 请求超时时间(秒)
max_retries: 最大重试次数
retry_delay: 重试延迟时间(秒)
headers: 自定义 HTTP 头
**kwargs: 其他配置参数,传入给 data 部分
"""
logger = logger or setup_logger('Chat')
if api_key is None:
api_key = AZURE_OPENAI_API_KEY
if model is None:
model = AZURE_OPENAI_API_MODEL
if api_base is None:
api_base = AZURE_OPENAI_ENDPOINT
if api_version is None:
api_version = AZURE_OPENAI_API_VERSION
if headers is None:
headers = {
'Content-Type': 'application/json'
}
HTTPClient.__init__(
self,
logger=logger,
api_base=api_base,
timeout=timeout,
max_retries=max_retries,
retry_delay=retry_delay,
headers=headers,
**kwargs
)

self.api_key = api_key
self.api_version = api_version
self.model = model

# 初始化对话历史
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
self._chat_log: List[Dict] = messages or []
self._last_response: Optional[ChatResponse] = None

def _get_params(self) -> Optional[Dict[str, str]]:
"""合并默认查询参数和自定义参数"""
params = {}
if self.api_version:
params['api-version'] = self.api_version
if self.api_key:
params['ak'] = self.api_key
return params

def _get_uri(self) -> str:
"""获取 API URI"""
return None

def _generate_log_id(self, messages: List[Dict[str, str]]) -> str:
"""生成请求的 log ID"""
content = str(messages).encode()
return hashlib.sha256(content).hexdigest()