 from typing import Optional

 from example_types import MessageList, SamplerBase
-from zai import ZhipuAiClient
+from zai import ZhipuAiClient, ZaiClient


 class ZaiSampler(SamplerBase):
@@ -27,57 +27,49 @@ def __init__(
         self.temperature = temperature
         self.max_tokens = max_tokens
         self.model = model
-        self.client = ZhipuAiClient(api_key=api_key)
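+        # NOTE (assumption): ZaiClient is the zai SDK client for the international Z.ai endpoint,
+        # while ZhipuAiClient targets the mainland Zhipu endpoint.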
+        self.client = ZaiClient(api_key=api_key)
         self.stream = stream

     def get_resp(self, message_list):
-        for _ in range(3):
-            try:
-                chat_completion = self.client.chat.completions.create(
-                    messages=message_list,
-                    model=self.model,
-                    temperature=self.temperature,
-                    top_p=self.top_p,
-                    max_tokens=self.max_tokens
-                )
-                output = chat_completion.choices[0].message.content
-                return output
-            except Exception as e:
-                print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
-                time.sleep(1)
-                continue
-        print(f"failed, last exception: {e if 'e' in locals() else ''}")
-        return ''
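+        # Single attempt: log the exception with its traceback, then re-raise so the caller decides how to retry.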
+        try:
+            chat_completion = self.client.chat.completions.create(
+                messages=message_list,
+                model=self.model,
+                temperature=self.temperature,
+                top_p=self.top_p,
+                max_tokens=self.max_tokens
+            )
+            output = chat_completion.choices[0].message.content
+            return output
+        except Exception as e:
+            print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
+            raise


     def get_resp_stream(self, message_list, top_p=-1, temperature=-1):
         temperature = temperature if temperature > 0 else self.temperature
         top_p = top_p if top_p > 0 else 0.95
         final = ''
-        for _ in range(200):
-            try:
-                chat_completion_res = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=message_list,
-                    thinking={
-                        "type": "enabled",
-                    },
-                    stream=True,
-                    max_tokens=self.max_tokens,
-                    temperature=temperature
-                )
-                for chunk in chat_completion_res:
-                    if chunk.choices[0].delta.content:
-                        final += chunk.choices[0].delta.content
-                break
-            except Exception as e:
-                final = ""
-                print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
-                time.sleep(5)
-                continue
-
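+        # Stream the completion with thinking enabled and concatenate the text deltas from each chunk.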
+        try:
+            chat_completion_res = self.client.chat.completions.create(
+                model=self.model,
+                messages=message_list,
+                thinking={
+                    "type": "enabled",
+                },
+                stream=True,
+                max_tokens=self.max_tokens,
+                temperature=temperature
+            )
+            for chunk in chat_completion_res:
+                if chunk.choices[0].delta.content:
+                    final += chunk.choices[0].delta.content
+        except Exception as e:
+            print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
+            raise
+
         if final == '':
-            print(f"failed in get_resp for 50 times, last exception: {e if 'e' in locals() else ''}")
+            print("failed in get_resp_stream, no content received")
             return ''

         content = ''
@@ -105,9 +97,13 @@ def __call__(self, message_list: MessageList, top_p=0.95, temperature=0.6) -> st


 if __name__ == "__main__":
-    client = ZaiSampler(model="glm-4.5", api_key=os.getenv("ZAI_API_KEY"), stream=True)
-    messages = [
-        {"role": "user", "content": "Hi?"},
-    ]
-    response = client(messages)
-    print(response)
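+    # Minimal smoke test; expects ZAI_API_KEY in the environment
+    # (assumes os, sys, and traceback are imported at the top of the file, outside this hunk).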
+    try:
+        client = ZaiSampler(model="glm-4.5", api_key=os.getenv("ZAI_API_KEY"), stream=True)
+        messages = [
+            {"role": "user", "content": "Hi? Tell me a joke."},
+        ]
+        response = client(messages)
+        print(response)
+    except Exception as e:
+        print(f"Fatal error: {e}\nTraceback: {traceback.format_exc()}")
+        sys.exit(1)