Skip to content

Commit 82e09d6

Browse files
wellenzhengzhengweijun
andauthored
feat: modeldump and usage support (#25)
Co-authored-by: zhengweijun <[email protected]>
1 parent 49ca5af commit 82e09d6

File tree

15 files changed

+363
-281
lines changed

15 files changed

+363
-281
lines changed

examples/function_call_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def parse_function_call(model_response, messages):
4141
tools=tools,
4242
)
4343
print(response.choices[0].message)
44-
messages.append(response.choices[0].message.model_dump())
44+
messages.append(response.choices[0].message)
4545

4646
messages = []
4747
tools = [
@@ -104,7 +104,7 @@ def parse_function_call(model_response, messages):
104104
tools=tools,
105105
)
106106
print(response.choices[0].message)
107-
messages.append(response.choices[0].message.model_dump())
107+
messages.append(response.choices[0].message)
108108

109109
parse_function_call(response, messages)
110110

@@ -115,6 +115,6 @@ def parse_function_call(model_response, messages):
115115
tools=tools,
116116
)
117117
print(response.choices[0].message)
118-
messages.append(response.choices[0].message.model_dump())
118+
messages.append(response.choices[0].message)
119119

120120
parse_function_call(response, messages)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "zai-sdk"
3-
version = "0.0.3"
3+
version = "0.0.3.1"
44
description = "A SDK library for accessing big model apis from Z.ai"
55
authors = ["Z.ai"]
66
readme = "README.md"

src/zai/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
__title__ = 'Z.ai'
2-
__version__ = '0.0.3'
2+
__version__ = '0.0.3.1'

src/zai/api_resource/chat/completions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
make_request_options,
1919
maybe_transform,
2020
)
21+
from zai.core._base_models import BaseModel
2122
from zai.types.chat.chat_completion import Completion
2223
from zai.types.chat.chat_completion_chunk import ChatCompletionChunk
2324
from zai.types.chat.code_geex import code_geex_params
@@ -112,7 +113,9 @@ def create(
112113
logger.debug(f'temperature:{temperature}, top_p:{top_p}')
113114
if isinstance(messages, List):
114115
for item in messages:
115-
if item.get('content'):
116+
if isinstance(item, BaseModel) and hasattr(item, 'content'):
117+
item.content = drop_prefix_image_data(item.content)
118+
elif isinstance(item, dict) and item.get('content'):
116119
item['content'] = drop_prefix_image_data(item['content'])
117120

118121
body = deepcopy_minimal(

src/zai/core/_base_models.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,64 @@ def to_json(
160160
warnings=warnings,
161161
)
162162

163+
def get(self, key: str, default: Any = None) -> Any:
164+
"""Get the value of an attribute by name, with an optional default value.
165+
166+
This method allows you to access model attributes by their string name,
167+
similar to how dict.get() works.
168+
169+
Args:
170+
key: The name of the attribute to get
171+
default: The value to return if the attribute doesn't exist or is None.
172+
Defaults to None.
173+
174+
Returns:
175+
The value of the attribute if it exists, otherwise the default value.
176+
177+
Examples:
178+
>>> model = MyModel(name="test", age=25)
179+
>>> model.get("name") # Returns "test"
180+
>>> model.get("nonexistent") # Returns None
181+
>>> model.get("nonexistent", "default") # Returns "default"
182+
"""
183+
try:
184+
value = getattr(self, key)
185+
return value if value is not None else default
186+
except AttributeError:
187+
return default
188+
189+
def __json__(self) -> dict[str, Any]:
190+
"""Custom JSON serialization method.
191+
192+
This method is called by JSON encoders that support the __json__ protocol,
193+
making BaseModel objects directly serializable without requiring model_dump().
194+
195+
Returns:
196+
A dictionary representation of the model suitable for JSON serialization.
197+
"""
198+
return self.model_dump(by_alias=True, exclude_unset=True)
199+
200+
def __reduce_ex__(self, protocol):
201+
"""Support for pickle serialization by returning the dict representation."""
202+
return (self.__class__.model_validate, (self.model_dump(by_alias=True, exclude_unset=True),))
203+
204+
def __iter__(self):
205+
"""Make BaseModel iterable to support dict() conversion."""
206+
data = self.model_dump(by_alias=True, exclude_unset=True)
207+
return iter(data.items())
208+
209+
def keys(self):
210+
"""Return keys for dict-like interface."""
211+
return self.model_dump(by_alias=True, exclude_unset=True).keys()
212+
213+
def values(self):
214+
"""Return values for dict-like interface."""
215+
return self.model_dump(by_alias=True, exclude_unset=True).values()
216+
217+
def items(self):
218+
"""Return items for dict-like interface."""
219+
return self.model_dump(by_alias=True, exclude_unset=True).items()
220+
163221
@override
164222
def __str__(self) -> str:
165223
# mypy complains about an invalid self arg

src/zai/core/_http_client.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from ._response import APIResponse, BaseAPIResponse, extract_response_type
6666
from ._streaming import StreamResponse
6767
from ._utils import flatten, is_given, is_mapping
68+
from ._json_encoder import json_dumps
6869

6970
log: logging.Logger = logging.getLogger(__name__)
7071

@@ -355,6 +356,9 @@ def _build_request(self, options: FinalRequestOptions) -> httpx.Request:
355356
else:
356357
raise RuntimeError(f'Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`')
357358

359+
# Convert BaseModel objects to dicts before passing to httpx
360+
json_data = self._prepare_json_data(json_data)
361+
358362
content_type = headers.get('Content-Type')
359363
# multipart/form-data; boundary=---abc--
360364
if headers.get('Content-Type') == 'multipart/form-data':
@@ -377,6 +381,24 @@ def _build_request(self, options: FinalRequestOptions) -> httpx.Request:
377381
**kwargs,
378382
)
379383

384+
def _prepare_json_data(self, json_data: Any) -> Any:
385+
"""Prepare JSON data for httpx by converting BaseModel objects to dicts."""
386+
from ._base_models import BaseModel
387+
388+
if json_data is None:
389+
return None
390+
391+
if isinstance(json_data, BaseModel):
392+
return json_data.model_dump(by_alias=True, exclude_unset=True)
393+
394+
if isinstance(json_data, list):
395+
return [self._prepare_json_data(item) for item in json_data]
396+
397+
if isinstance(json_data, dict):
398+
return {key: self._prepare_json_data(value) for key, value in json_data.items()}
399+
400+
return json_data
401+
380402
def _object_to_formfata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
381403
items = []
382404

src/zai/core/_json_encoder.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""JSON encoding utilities for BaseModel objects."""
2+
3+
import json
4+
from typing import Any
5+
6+
from ._base_models import BaseModel
7+
8+
9+
class ZAIJSONEncoder(json.JSONEncoder):
10+
"""Custom JSON encoder that handles BaseModel objects."""
11+
12+
def default(self, obj: Any) -> Any:
13+
"""Override default method to handle BaseModel objects."""
14+
if isinstance(obj, BaseModel):
15+
return obj.model_dump(by_alias=True, exclude_unset=True)
16+
return super().default(obj)
17+
18+
19+
def json_dumps(obj: Any, **kwargs) -> str:
20+
"""
21+
JSON dumps with support for BaseModel objects.
22+
23+
Args:
24+
obj: Object to serialize
25+
**kwargs: Additional arguments to pass to json.dumps()
26+
27+
Returns:
28+
JSON string representation
29+
"""
30+
return json.dumps(obj, cls=ZAIJSONEncoder, **kwargs)
31+
32+
33+
def json_loads(s: str, **kwargs) -> Any:
34+
"""
35+
JSON loads with consistent interface.
36+
37+
Args:
38+
s: JSON string to deserialize
39+
**kwargs: Additional arguments to pass to json.loads()
40+
41+
Returns:
42+
Deserialized object
43+
"""
44+
return json.loads(s, **kwargs)

src/zai/types/chat/chat_completion_chunk.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,26 @@ class Choice(BaseModel):
9494
finish_reason: Optional[str] = None
9595
index: int
9696

97+
class PromptTokensDetails(BaseModel):
98+
"""
99+
Detailed breakdown of token usage for the input prompt
100+
101+
Attributes:
102+
cached_tokens: Number of tokens reused from cache
103+
"""
104+
105+
cached_tokens: int
106+
107+
108+
class CompletionTokensDetails(BaseModel):
109+
"""
110+
Detailed breakdown of token usage for the model completion
111+
112+
Attributes:
113+
reasoning_tokens: Number of tokens used for reasoning steps
114+
"""
115+
116+
reasoning_tokens: int
97117

98118
class CompletionUsage(BaseModel):
99119
"""
@@ -106,7 +126,9 @@ class CompletionUsage(BaseModel):
106126
"""
107127

108128
prompt_tokens: int
129+
prompt_tokens_details: Optional[PromptTokensDetails] = None
109130
completion_tokens: int
131+
completion_tokens_details: Optional[CompletionTokensDetails] = None
110132
total_tokens: int
111133

112134

tests/integration_tests/asr.wav

346 KB
Binary file not shown.

tests/integration_tests/asr1.wav

-85 KB
Binary file not shown.

0 commit comments

Comments
 (0)