Skip to content

Commit 49ca5af

Browse files
wellenzhengzhengweijun
andauthored
chore: update version to 0.0.3 and refactor ZaiSampler error handling (#24)
Co-authored-by: zhengweijun <[email protected]>
1 parent 47dc7ca commit 49ca5af

File tree

4 files changed

+53
-50
lines changed

4 files changed

+53
-50
lines changed

examples/glm4_5_thinking.py

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Optional
77

88
from example_types import MessageList, SamplerBase
9-
from zai import ZhipuAiClient
9+
from zai import ZhipuAiClient, ZaiClient
1010

1111

1212
class ZaiSampler(SamplerBase):
@@ -27,57 +27,49 @@ def __init__(
2727
self.temperature = temperature
2828
self.max_tokens = max_tokens
2929
self.model = model
30-
self.client = ZhipuAiClient(api_key=api_key)
30+
self.client = ZaiClient(api_key=api_key)
3131
self.stream = stream
3232

3333
def get_resp(self, message_list):
34-
for _ in range(3):
35-
try:
36-
chat_completion = self.client.chat.completions.create(
37-
messages=message_list,
38-
model=self.model,
39-
temperature=self.temperature,
40-
top_p=self.top_p,
41-
max_tokens=self.max_tokens
42-
)
43-
output = chat_completion.choices[0].message.content
44-
return output
45-
except Exception as e:
46-
print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
47-
time.sleep(1)
48-
continue
49-
print(f"failed, last exception: {e if 'e' in locals() else ''}")
50-
return ''
34+
try:
35+
chat_completion = self.client.chat.completions.create(
36+
messages=message_list,
37+
model=self.model,
38+
temperature=self.temperature,
39+
top_p=self.top_p,
40+
max_tokens=self.max_tokens
41+
)
42+
output = chat_completion.choices[0].message.content
43+
return output
44+
except Exception as e:
45+
print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
46+
raise
5147

5248

5349
def get_resp_stream(self, message_list, top_p=-1, temperature=-1):
5450
temperature = temperature if temperature > 0 else self.temperature
5551
top_p = top_p if top_p > 0 else 0.95
5652
final = ''
57-
for _ in range(200):
58-
try:
59-
chat_completion_res = self.client.chat.completions.create(
60-
model=self.model,
61-
messages=message_list,
62-
thinking={
63-
"type": "enabled",
64-
},
65-
stream=True,
66-
max_tokens=self.max_tokens,
67-
temperature=temperature
68-
)
69-
for chunk in chat_completion_res:
70-
if chunk.choices[0].delta.content:
71-
final += chunk.choices[0].delta.content
72-
break
73-
except Exception as e:
74-
final = ""
75-
print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
76-
time.sleep(5)
77-
continue
78-
53+
try:
54+
chat_completion_res = self.client.chat.completions.create(
55+
model=self.model,
56+
messages=message_list,
57+
thinking={
58+
"type": "enabled",
59+
},
60+
stream=True,
61+
max_tokens=self.max_tokens,
62+
temperature=temperature
63+
)
64+
for chunk in chat_completion_res:
65+
if chunk.choices[0].delta.content:
66+
final += chunk.choices[0].delta.content
67+
except Exception as e:
68+
print(f"Exception: {e}\nTraceback: {traceback.format_exc()}")
69+
raise
70+
7971
if final == '':
80-
print(f"failed in get_resp for 50 times, last exception: {e if 'e' in locals() else ''}")
72+
print(f"failed in get_resp, no content received")
8173
return ''
8274

8375
content = ''
@@ -105,9 +97,13 @@ def __call__(self, message_list: MessageList, top_p=0.95, temperature=0.6) -> st
10597

10698

10799
if __name__ == "__main__":
108-
client = ZaiSampler(model="glm-4.5", api_key=os.getenv("ZAI_API_KEY"), stream=True)
109-
messages = [
110-
{"role": "user", "content": "Hi?"},
111-
]
112-
response = client(messages)
113-
print(response)
100+
try:
101+
client = ZaiSampler(model="glm-4.5", api_key=os.getenv("ZAI_API_KEY"), stream=True)
102+
messages = [
103+
{"role": "user", "content": "Hi? Tell me a joke."},
104+
]
105+
response = client(messages)
106+
print(response)
107+
except Exception as e:
108+
print(f"Fatal error: {e}\nTraceback: {traceback.format_exc()}")
109+
sys.exit(1)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "zai-sdk"
3-
version = "0.0.2"
3+
version = "0.0.3"
44
description = "A SDK library for accessing big model apis from Z.ai"
55
authors = ["Z.ai"]
66
readme = "README.md"

src/zai/_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,13 @@ class ZaiClient(BaseClient):
218218
def default_base_url(self):
219219
return 'https://api.z.ai/api/paas/v4'
220220

221+
@property
222+
@override
223+
def auth_headers(self) -> dict[str, str]:
224+
headers = super().auth_headers
225+
headers['Accept-Language'] = 'en-US,en'
226+
return headers
227+
221228

222229
class ZhipuAiClient(BaseClient):
223230
@property

src/zai/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
__title__ = 'Z.ai'
2-
__version__ = '0.0.2'
2+
__version__ = '0.0.3'

0 commit comments

Comments
 (0)