22import json
33import os
44import logging
5+ import hashlib
56from typing import List , Dict , Union , Optional , AsyncGenerator , Any
67from chattool .utils import setup_logger
78from .request import HTTPClient
@@ -45,6 +46,12 @@ def __init__(self,
4546 model = OPENAI_API_MODEL
4647 if api_base is None :
4748 api_base = OPENAI_API_BASE
49+ if headers is None :
50+ headers = {
51+ 'Content-Type' : 'application/json'
52+ }
53+ if not headers .get ('Authorization' ) and api_key :
54+ headers ['Authorization' ] = f'Bearer { api_key } '
4855 super ().__init__ (
4956 logger = logger ,
5057 api_base = api_base ,
@@ -64,6 +71,40 @@ def __init__(self,
6471 self ._chat_log : List [Dict ] = messages or []
6572 self ._last_response : Optional [ChatResponse ] = None
6673
74+ def _prepare_data (self ,
75+ messages : List [Dict [str , str ]],
76+ model : Optional [str ] = None ,
77+ temperature : Optional [float ] = None ,
78+ top_p : Optional [float ] = None ,
79+ max_tokens : Optional [int ] = None ,
80+ ** kwargs ) -> Dict [str , Any ]:
81+ """准备请求数据"""
82+ # 合并参数
83+ model = model if model is not None else self .model
84+ data = {
85+ 'model' : model ,
86+ "messages" : messages ,
87+ ** kwargs
88+ }
89+ if temperature is not None :
90+ data ['temperature' ] = temperature
91+ if top_p is not None :
92+ data ['top_p' ] = top_p
93+ if max_tokens is not None :
94+ data ['max_tokens' ] = max_tokens
95+ for k in self .config .kwargs : # 默认的额外参数
96+ if k not in data and self .config .get (k ) is not None :
97+ data [k ] = self .config .get (k )
98+ return data
99+
100+ def _get_params (self ) -> Optional [Dict [str , str ]]:
101+ """合并默认查询参数和自定义参数"""
102+ return None
103+
104+ def _get_uri (self )-> str :
105+ """获取请求 URI"""
106+ return '/chat/completions'
107+
67108 def chat_completion (
68109 self ,
69110 # data 部分
@@ -73,8 +114,6 @@ def chat_completion(
73114 top_p : Optional [float ] = None ,
74115 max_tokens : Optional [int ] = None ,
75116 stream : bool = False ,
76- # 默认请求地址
77- uri : str = '/chat/completions' ,
78117 ** kwargs # 其他参数,传入给 data 部分
79118 ) -> Dict [str , Any ]:
80119 """
@@ -87,35 +126,21 @@ def chat_completion(
87126 top_p: top_p 参数
88127 max_tokens: 最大token数
89128 stream: 是否使用流式响应
90- uri: 请求 URI
91- headers: 自定义请求头(OpenAI 特有参数)
92- params: 自定义查询参数(Azure 特有参数)
93129 **kwargs: 其他参数
94130 """
95131 # 合并参数
96- model = model if model is not None else self .model
97- data = {
98- 'model' : model ,
99- "messages" : messages ,
132+ data = self ._prepare_data (
133+ messages = messages ,
134+ model = model ,
135+ temperature = temperature ,
136+ top_p = top_p ,
137+ max_tokens = max_tokens ,
100138 ** kwargs
101- }
102- if temperature is not None :
103- data ['temperature' ] = temperature
104- if top_p is not None :
105- data ['top_p' ] = top_p
106- if max_tokens is not None :
107- data ['max_tokens' ] = max_tokens
139+ )
108140 if stream :
109141 # NOTE: 流式响应请使用 chat_completion_stream_async
110142 pass
111- for k in self .config .kwargs : # 默认的额外参数
112- if k not in data and self .config .get (k ) is not None :
113- data [k ] = self .config .get (k )
114- headers = {
115- 'Authorization' : f'Bearer { self .api_key } ' ,
116- 'Content-Type' : 'application/json' ,
117- }
118- response = self .post (uri , data = data , headers = headers )
143+ response = self .post (self ._get_uri (), data = data , params = self ._get_params ())
119144 return response .json ()
120145
121146 async def async_chat_completion (
@@ -126,34 +151,22 @@ async def async_chat_completion(
126151 top_p : Optional [float ] = None ,
127152 max_tokens : Optional [int ] = None ,
128153 stream : bool = False ,
129- uri : str = '/chat/completions' ,
130154 ** kwargs
131155 ) -> Union [Dict [str , Any ], AsyncGenerator ]:
132156 """OpenAI Chat Completion API (异步版本)"""
133157 # 合并参数
134- model = model if model is not None else self .model
135- data = {
136- 'model' : model ,
137- "messages" : messages ,
158+ data = self ._prepare_data (
159+ messages = messages ,
160+ model = model ,
161+ temperature = temperature ,
162+ top_p = top_p ,
163+ max_tokens = max_tokens ,
138164 ** kwargs
139- }
140- if temperature is not None :
141- data ['temperature' ] = temperature
142- if top_p is not None :
143- data ['top_p' ] = top_p
144- if max_tokens is not None :
145- data ['max_tokens' ] = max_tokens
165+ )
146166 if stream :
147167 # NOTE: 流式响应请使用 chat_completion_stream_async
148168 pass
149- for k in self .config .kwargs : # 默认的额外参数
150- if k not in data and self .config .get (k ) is not None :
151- data [k ] = self .config .get (k )
152- headers = {
153- 'Authorization' : f'Bearer { self .api_key } ' ,
154- 'Content-Type' : 'application/json' ,
155- }
156- response = await self .async_post (uri , data = data , headers = headers )
169+ response = await self .async_post (self ._get_uri (), data = data , params = self ._get_params ())
157170 return response .json ()
158171
159172 async def chat_completion_stream_async (
@@ -163,39 +176,30 @@ async def chat_completion_stream_async(
163176 temperature : Optional [float ] = None ,
164177 top_p : Optional [float ] = None ,
165178 max_tokens : Optional [int ] = None ,
166- uri : str = '/chat/completions' ,
167179 stream : bool = True ,
168180 ** kwargs
169181 ):
170182 """OpenAI Chat Completion API 流式响应(异步版本)"""
171- data = {
172- 'model' : model if model is not None else self .model ,
173- "messages" : messages ,
174- 'stream' : True ,
183+ # 合并参数
184+ data = self ._prepare_data (
185+ messages = messages ,
186+ model = model ,
187+ temperature = temperature ,
188+ top_p = top_p ,
189+ max_tokens = max_tokens ,
190+ stream = True ,
175191 ** kwargs
176- }
177- if temperature is not None :
178- data ['temperature' ] = temperature
179- if top_p is not None :
180- data ['top_p' ] = top_p
181- if max_tokens is not None :
182- data ['max_tokens' ] = max_tokens
183- for k in self .config .kwargs : # 默认的额外参数
184- if k not in data and self .config .get (k ) is not None :
185- data [k ] = self .config .get (k )
186- headers = {
187- 'Authorization' : f'Bearer { self .api_key } ' ,
188- 'Content-Type' : 'application/json' ,
189- }
192+ )
190193 # 构建完整 URL
191- url = self ._build_url (uri )
194+ url = self ._build_url (self . _get_uri () )
192195 client = self ._get_async_client ()
193196
194197 async with client .stream (
195198 method = "POST" ,
196199 url = url ,
197200 json = data ,
198- headers = headers
201+ headers = self .config .headers ,
202+ params = self ._get_params ()
199203 ) as stream :
200204 # 检查响应状态码
201205 if stream .status_code >= 400 :
@@ -498,4 +502,79 @@ def __str__(self) -> str:
498502
499503class AzureChat (Chat ):
500504 """Azure OpenAI Chat 实现 - 继承 Chat 复用逻辑"""
501- pass
505+ def __init__ (self ,
506+ messages : Optional [Union [str , List [Dict [str , str ]]]]= None ,
507+ logger : logging .Logger = None ,
508+ api_key : str = None ,
509+ api_base : str = None ,
510+ api_version : str = None ,
511+ model : str = None ,
512+ timeout : float = 0 ,
513+ max_retries : int = 3 ,
514+ retry_delay : float = 1.0 ,
515+ headers : Optional [Dict [str , str ]]= None ,
516+ ** kwargs ):
517+ """初始化聊天模型
518+
519+ Args:
520+ logger: 日志记录器
521+ api_key: API 密钥
522+ api_base: API 基础 URL
523+ model: 模型名称
524+ timeout: 请求超时时间(秒)
525+ max_retries: 最大重试次数
526+ retry_delay: 重试延迟时间(秒)
527+ headers: 自定义 HTTP 头
528+ **kwargs: 其他配置参数,传入给 data 部分
529+ """
530+ logger = logger or setup_logger ('Chat' )
531+ if api_key is None :
532+ api_key = AZURE_OPENAI_API_KEY
533+ if model is None :
534+ model = AZURE_OPENAI_API_MODEL
535+ if api_base is None :
536+ api_base = AZURE_OPENAI_ENDPOINT
537+ if api_version is None :
538+ api_version = AZURE_OPENAI_API_VERSION
539+ if headers is None :
540+ headers = {
541+ 'Content-Type' : 'application/json'
542+ }
543+ HTTPClient .__init__ (
544+ self ,
545+ logger = logger ,
546+ api_base = api_base ,
547+ timeout = timeout ,
548+ max_retries = max_retries ,
549+ retry_delay = retry_delay ,
550+ headers = headers ,
551+ ** kwargs
552+ )
553+
554+ self .api_key = api_key
555+ self .api_version = api_version
556+ self .model = model
557+
558+ # 初始化对话历史
559+ if isinstance (messages , str ):
560+ messages = [{"role" : "user" , "content" : messages }]
561+ self ._chat_log : List [Dict ] = messages or []
562+ self ._last_response : Optional [ChatResponse ] = None
563+
564+ def _get_params (self ) -> Optional [Dict [str , str ]]:
565+ """合并默认查询参数和自定义参数"""
566+ params = {}
567+ if self .api_version :
568+ params ['api-version' ] = self .api_version
569+ if self .api_key :
570+ params ['ak' ] = self .api_key
571+ return params
572+
573+ def _get_uri (self ) -> str :
574+ """获取 API URI"""
575+ return None
576+
577+ def _generate_log_id (self , messages : List [Dict [str , str ]]) -> str :
578+ """生成请求的 log ID"""
579+ content = str (messages ).encode ()
580+ return hashlib .sha256 (content ).hexdigest ()
0 commit comments