diff --git a/pyproject.toml b/pyproject.toml index ffb2bec..43b47ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,18 +27,13 @@ license = {text = "MIT license"} dependencies = [ "Click>=7.0", "requests>=2.20", - "responses>=0.23", "aiohttp>=3.8", "tqdm>=4.60", - "docstring_parser>=0.10", + # "docstring_parser>=0.10", "python-dotenv>=0.17.0", "loguru>=0.7", - "dnspython", "fastmcp; python_version >= '3.10'", - "alibabacloud_alidns20150109==3.5.10", - "alibabacloud_tea_openapi>=0.3.0", - "tencentcloud-sdk-python", - "batch_executor", + "batch_executor>=0.3.0", "colorama", "fastapi", "uvicorn", @@ -55,7 +50,11 @@ dev = [ "twine", "pytest>=3", "pytest-asyncio", - "pytest-mock" + "pytest-mock", + "alibabacloud_alidns20150109==3.5.10", + "alibabacloud_tea_openapi>=0.3.0", + "tencentcloud-sdk-python", + "dnspython", ] [project.urls] diff --git a/src/chattool/core/chattype.py b/src/chattool/core/chattype.py index ab20b02..2760f0e 100644 --- a/src/chattool/core/chattype.py +++ b/src/chattool/core/chattype.py @@ -2,6 +2,7 @@ import json import os import logging +import hashlib from typing import List, Dict, Union, Optional, AsyncGenerator, Any from chattool.utils import setup_logger from .request import HTTPClient @@ -45,6 +46,12 @@ def __init__(self, model = OPENAI_API_MODEL if api_base is None: api_base = OPENAI_API_BASE + if headers is None: + headers = { + 'Content-Type': 'application/json' + } + if not headers.get('Authorization') and api_key: + headers['Authorization'] = f'Bearer {api_key}' super().__init__( logger=logger, api_base=api_base, @@ -64,6 +71,40 @@ def __init__(self, self._chat_log: List[Dict] = messages or [] self._last_response: Optional[ChatResponse] = None + def _prepare_data(self, + messages: List[Dict[str, str]], + model: Optional[str] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + max_tokens: Optional[int] = None, + **kwargs) -> Dict[str, Any]: + """准备请求数据""" + # 合并参数 + model = model if model is not None else self.model + data = { + 'model': model, + "messages": messages, + **kwargs + } + if temperature is not None: + data['temperature'] = temperature + if top_p is not None: + data['top_p'] = top_p + if max_tokens is not None: + data['max_tokens'] = max_tokens + for k in self.config.kwargs: # 默认的额外参数 + if k not in data and self.config.get(k) is not None: + data[k] = self.config.get(k) + return data + + def _get_params(self) -> Optional[Dict[str, str]]: + """合并默认查询参数和自定义参数""" + return None + + def _get_uri(self)->str: + """获取请求 URI""" + return '/chat/completions' + def chat_completion( self, # data 部分 @@ -73,8 +114,6 @@ def chat_completion( top_p: Optional[float] = None, max_tokens: Optional[int] = None, stream: bool = False, - # 默认请求地址 - uri: str = '/chat/completions', **kwargs # 其他参数,传入给 data 部分 ) -> Dict[str, Any]: """ @@ -87,35 +126,21 @@ def chat_completion( top_p: top_p 参数 max_tokens: 最大token数 stream: 是否使用流式响应 - uri: 请求 URI - headers: 自定义请求头(OpenAI 特有参数) - params: 自定义查询参数(Azure 特有参数) **kwargs: 其他参数 """ # 合并参数 - model = model if model is not None else self.model - data = { - 'model': model, - "messages": messages, + data = self._prepare_data( + messages=messages, + model=model, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, **kwargs - } - if temperature is not None: - data['temperature'] = temperature - if top_p is not None: - data['top_p'] = top_p - if max_tokens is not None: - data['max_tokens'] = max_tokens + ) if stream: # NOTE: 流式响应请使用 chat_completion_stream_async pass - for k in self.config.kwargs: # 默认的额外参数 - if k not in data and self.config.get(k) is not None: - data[k] = self.config.get(k) - headers = { - 'Authorization': f'Bearer {self.api_key}', - 'Content-Type': 'application/json', - } - response = self.post(uri, data=data, headers=headers) + response = self.post(self._get_uri(), data=data, params=self._get_params()) return response.json() async def async_chat_completion( @@ -126,34 +151,22 @@ async def async_chat_completion( top_p: Optional[float] = None, max_tokens: Optional[int] = None, stream: bool = False, - uri: str = '/chat/completions', **kwargs ) -> Union[Dict[str, Any], AsyncGenerator]: """OpenAI Chat Completion API (异步版本)""" # 合并参数 - model = model if model is not None else self.model - data = { - 'model': model, - "messages": messages, + data = self._prepare_data( + messages=messages, + model=model, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, **kwargs - } - if temperature is not None: - data['temperature'] = temperature - if top_p is not None: - data['top_p'] = top_p - if max_tokens is not None: - data['max_tokens'] = max_tokens + ) if stream: # NOTE: 流式响应请使用 chat_completion_stream_async pass - for k in self.config.kwargs: # 默认的额外参数 - if k not in data and self.config.get(k) is not None: - data[k] = self.config.get(k) - headers = { - 'Authorization': f'Bearer {self.api_key}', - 'Content-Type': 'application/json', - } - response = await self.async_post(uri, data=data, headers=headers) + response = await self.async_post(self._get_uri(), data=data, params=self._get_params()) return response.json() async def chat_completion_stream_async( @@ -163,39 +176,30 @@ async def chat_completion_stream_async( temperature: Optional[float] = None, top_p: Optional[float] = None, max_tokens: Optional[int] = None, - uri: str = '/chat/completions', stream: bool = True, **kwargs ): """OpenAI Chat Completion API 流式响应(异步版本)""" - data = { - 'model': model if model is not None else self.model, - "messages": messages, - 'stream': True, + # 合并参数 + data = self._prepare_data( + messages=messages, + model=model, + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + stream=True, **kwargs - } - if temperature is not None: - data['temperature'] = temperature - if top_p is not None: - data['top_p'] = top_p - if max_tokens is not None: - data['max_tokens'] = max_tokens - for k in self.config.kwargs: # 默认的额外参数 - if k not in data and self.config.get(k) is not None: - data[k] = self.config.get(k) - headers = { - 'Authorization': f'Bearer {self.api_key}', - 'Content-Type': 'application/json', - } + ) # 构建完整 URL - url = self._build_url(uri) + url = self._build_url(self._get_uri()) client = self._get_async_client() async with client.stream( method="POST", url=url, json=data, - headers=headers + headers=self.config.headers, + params=self._get_params() ) as stream: # 检查响应状态码 if stream.status_code >= 400: @@ -498,4 +502,79 @@ def __str__(self) -> str: class AzureChat(Chat): """Azure OpenAI Chat 实现 - 继承 Chat 复用逻辑""" - pass \ No newline at end of file + def __init__(self, + messages: Optional[Union[str, List[Dict[str, str]]]]=None, + logger: logging.Logger=None, + api_key: str=None, + api_base: str=None, + api_version: str=None, + model: str=None, + timeout: float = 0, + max_retries: int = 3, + retry_delay: float = 1.0, + headers: Optional[Dict[str, str]]=None, + **kwargs): + """初始化聊天模型 + + Args: + logger: 日志记录器 + api_key: API 密钥 + api_base: API 基础 URL + model: 模型名称 + timeout: 请求超时时间(秒) + max_retries: 最大重试次数 + retry_delay: 重试延迟时间(秒) + headers: 自定义 HTTP 头 + **kwargs: 其他配置参数,传入给 data 部分 + """ + logger = logger or setup_logger('Chat') + if api_key is None: + api_key = AZURE_OPENAI_API_KEY + if model is None: + model = AZURE_OPENAI_API_MODEL + if api_base is None: + api_base = AZURE_OPENAI_ENDPOINT + if api_version is None: + api_version = AZURE_OPENAI_API_VERSION + if headers is None: + headers = { + 'Content-Type': 'application/json' + } + HTTPClient.__init__( + self, + logger=logger, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + retry_delay=retry_delay, + headers=headers, + **kwargs + ) + + self.api_key = api_key + self.api_version = api_version + self.model = model + + # 初始化对话历史 + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + self._chat_log: List[Dict] = messages or [] + self._last_response: Optional[ChatResponse] = None + + def _get_params(self) -> Optional[Dict[str, str]]: + """合并默认查询参数和自定义参数""" + params = {} + if self.api_version: + params['api-version'] = self.api_version + if self.api_key: + params['ak'] = self.api_key + return params + + def _get_uri(self) -> str: + """获取 API URI""" + return None + + def _generate_log_id(self, messages: List[Dict[str, str]]) -> str: + """生成请求的 log ID""" + content = str(messages).encode() + return hashlib.sha256(content).hexdigest()