Support text, JSON, XML and YAML DocumentUrl and BinaryContent on OpenAI #2851
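A minimal usage sketch of what this enables (the model name, prompt and URLs are illustrative assumptions, not taken from the PR): text-like documents such as plain text, JSON, XML and YAML can be passed to OpenAI chat models as `DocumentUrl` or `BinaryContent`, and are inlined into the prompt as text blocks.

```python
from pydantic_ai import Agent, BinaryContent, DocumentUrl

agent = Agent('openai:gpt-4o')  # illustrative model name

result = agent.run_sync([
    'Summarize the attached configuration files.',
    # Text-like media types (text/*, JSON, XML, YAML) are inlined as text blocks.
    DocumentUrl(url='https://example.com/conf.yaml', media_type='application/yaml'),
    BinaryContent(data=b'{"a": 1}', media_type='application/json'),
])
print(result.output)
```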
In `pydantic_ai.models.openai`:

@@ -727,60 +727,112 @@ async def _map_user_message(self, message: ModelRequest) -> AsyncIterable[chat.C

Before, `_map_user_prompt` handled every content type inline:

    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
        content: str | list[ChatCompletionContentPartParam]
        if isinstance(part.content, str):
            content = part.content
        else:
            content = []
            for item in part.content:
                if isinstance(item, str):
                    content.append(ChatCompletionContentPartTextParam(text=item, type='text'))
                elif isinstance(item, ImageUrl):
                    image_url = ImageURL(url=item.url)
                    content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                elif isinstance(item, BinaryContent):
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
                    if item.is_image:
                        image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
                        content.append(ChatCompletionContentPartImageParam(image_url=image_url, type='image_url'))
                    elif item.is_audio:
                        assert item.format in ('wav', 'mp3')
                        audio = InputAudio(data=base64_encoded, format=item.format)
                        content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
                    elif item.is_document:
                        content.append(
                            File(
                                file=FileFile(
                                    file_data=f'data:{item.media_type};base64,{base64_encoded}',
                                    filename=f'filename.{item.format}',
                                ),
                                type='file',
                            )
                        )
                    else:  # pragma: no cover
                        raise RuntimeError(f'Unsupported binary content type: {item.media_type}')
                elif isinstance(item, AudioUrl):
                    downloaded_item = await download_item(item, data_format='base64', type_format='extension')
                    assert downloaded_item['data_type'] in (
                        'wav',
                        'mp3',
                    ), f'Unsupported audio format: {downloaded_item["data_type"]}'
                    audio = InputAudio(data=downloaded_item['data'], format=downloaded_item['data_type'])
                    content.append(ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio'))
                elif isinstance(item, DocumentUrl):
                    downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension')
                    file = File(
                        file=FileFile(
                            file_data=downloaded_item['data'], filename=f'filename.{downloaded_item["data_type"]}'
                        ),
                        type='file',
                    )
                    content.append(file)
                elif isinstance(item, VideoUrl):  # pragma: no cover
                    raise NotImplementedError('VideoUrl is not supported for OpenAI')
                else:
                    assert_never(item)
        return chat.ChatCompletionUserMessageParam(role='user', content=content)

After, `_map_user_prompt` delegates to `_map_user_prompt_items`, which maps each item through `_map_single_item` and per-type handlers, and text-like documents are inlined as text blocks:

    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> chat.ChatCompletionUserMessageParam:
        if isinstance(part.content, str):
            return chat.ChatCompletionUserMessageParam(role='user', content=part.content)
        content_parts = await OpenAIChatModel._map_user_prompt_items(part.content)
        return chat.ChatCompletionUserMessageParam(role='user', content=content_parts)

    @staticmethod
    async def _map_user_prompt_items(items: Sequence[object]) -> list[ChatCompletionContentPartParam]:
        result: list[ChatCompletionContentPartParam] = []
        for item in items:
            result.extend(await OpenAIChatModel._map_single_item(item))
        return result

    @staticmethod
    async def _map_single_item(item: object) -> list[ChatCompletionContentPartParam]:
        if isinstance(item, str):
            return [ChatCompletionContentPartTextParam(text=item, type='text')]
        elif isinstance(item, ImageUrl):
            return OpenAIChatModel._handle_image_url(item) or []
        elif isinstance(item, BinaryContent):
            return await OpenAIChatModel._handle_binary_content(item) or []
        elif isinstance(item, AudioUrl):
            return await OpenAIChatModel._handle_audio_url(item) or []
        elif isinstance(item, DocumentUrl):
            return await OpenAIChatModel._handle_document_url(item) or []
        elif isinstance(item, VideoUrl):
            raise NotImplementedError('VideoUrl is not supported for OpenAI')
        else:
            # Fallback: unknown type — return empty parts to avoid type-checker Never error
            return []

    @staticmethod
    def _handle_image_url(item: ImageUrl) -> list[ChatCompletionContentPartParam]:
        image_url = ImageURL(url=item.url)
        return [ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')]

    @staticmethod
    async def _handle_binary_content(item: BinaryContent) -> list[ChatCompletionContentPartParam]:
        if OpenAIChatModel._is_text_like_media_type(item.media_type):
            # Inline text-like binary content as a text block
            text = item.data.decode('utf-8')
            media_type = item.media_type
            inline = OpenAIChatModel._inline_file_block(media_type, text, identifier=item.identifier)
            return [ChatCompletionContentPartTextParam(text=inline, type='text')]
        base64_encoded = base64.b64encode(item.data).decode('utf-8')
        if item.is_image:
            image_url = ImageURL(url=f'data:{item.media_type};base64,{base64_encoded}')
            return [ChatCompletionContentPartImageParam(image_url=image_url, type='image_url')]
        if item.is_audio:
            assert item.format in ('wav', 'mp3')
            audio = InputAudio(data=base64_encoded, format=item.format)
            return [ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio')]
        if item.is_document:
            return [
                File(
                    file=FileFile(
                        file_data=f'data:{item.media_type};base64,{base64_encoded}',
                        filename=f'filename.{item.format}',
                    ),
                    type='file',
                )
            ]
        return []

    @staticmethod
    async def _handle_audio_url(item: AudioUrl) -> list[ChatCompletionContentPartParam]:
        downloaded_item = await download_item(item, data_format='base64', type_format='extension')
        assert downloaded_item['data_type'] in ('wav', 'mp3'), (
            f'Unsupported audio format: {downloaded_item["data_type"]}'
        )
        audio = InputAudio(data=downloaded_item['data'], format=downloaded_item['data_type'])
        return [ChatCompletionContentPartInputAudioParam(input_audio=audio, type='input_audio')]

    @staticmethod
    async def _handle_document_url(item: DocumentUrl) -> list[ChatCompletionContentPartParam]:
        if OpenAIChatModel._is_text_like_media_type(item.media_type):
            downloaded_text = await download_item(item, data_format='text', type_format='extension')
            inline = OpenAIChatModel._inline_file_block(
                item.media_type, downloaded_text['data'], identifier=item.identifier
            )
            return [ChatCompletionContentPartTextParam(text=inline, type='text')]
        downloaded_item = await download_item(item, data_format='base64_uri', type_format='extension')
        return [
            File(
                file=FileFile(
                    file_data=downloaded_item['data'],
                    filename=f'filename.{downloaded_item["data_type"]}',
                ),
                type='file',
            )
        ]

    @staticmethod
    def _is_text_like_media_type(media_type: str) -> bool:
        return (
            media_type.startswith('text/')
            or media_type == 'application/json'
            or media_type.endswith('+json')
            or media_type == 'application/xml'
            or media_type.endswith('+xml')
            or media_type in ('application/x-yaml', 'application/yaml')
        )

    @staticmethod
    def _inline_file_block(media_type: str, text: str, identifier: str | None) -> str:
        id_attr = f' id="{identifier}"' if identifier else ''
        return ''.join(['-----BEGIN FILE', id_attr, ' type="', media_type, '"-----\n', text, '\n-----END FILE-----'])
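To make the text-like classification concrete, here is a rough standalone restatement of the predicate with a few spot checks (an illustration, not the library's API):

```python
def is_text_like(media_type: str) -> bool:
    # Mirrors _is_text_like_media_type above: text/*, JSON (+json), XML (+xml), YAML.
    return (
        media_type.startswith('text/')
        or media_type in ('application/json', 'application/xml', 'application/x-yaml', 'application/yaml')
        or media_type.endswith(('+json', '+xml'))
    )

assert is_text_like('text/csv')
assert is_text_like('application/ld+json')
assert is_text_like('image/svg+xml')        # the +xml suffix counts as text-like
assert not is_text_like('application/pdf')  # still sent as a file part
assert not is_text_like('image/png')        # still sent as an image part
```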
A review suggestion on `_inline_file_block` proposed building the block with `'\n'.join` instead:

    return '\n'.join([
        f'-----BEGIN FILE{id_attr} type="{media_type}"-----',
        text,
        f'-----END FILE {id_attr}-----',
    ])
Done
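For reference, this is roughly what the inlined text part contains with the `''.join` version from this diff (the identifier and content are made up):

```python
media_type = 'application/json'
identifier = 'doc1'  # hypothetical identifier
text = '{"a": 1}'

id_attr = f' id="{identifier}"' if identifier else ''
framed = ''.join(['-----BEGIN FILE', id_attr, ' type="', media_type, '"-----\n', text, '\n-----END FILE-----'])
print(framed)
# -----BEGIN FILE id="doc1" type="application/json"-----
# {"a": 1}
# -----END FILE-----
```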
In the OpenAI model tests, `VideoUrl` is added to the imports:

@@ -29,6 +29,7 @@
    ToolCallPart,
    ToolReturnPart,
    UserPromptPart,
    VideoUrl,
)
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
@@ -822,6 +823,104 @@ async def test_document_url_input(allow_model_requests: None, openai_api_key: st
    assert result.output == snapshot('The document contains the text "Dummy PDF file" on its single page.')

@pytest.mark.parametrize(
    'media_type,body',
    [
        ('text/plain', 'Hello'),
        ('application/json', '{"a":1}'),
        ('application/xml', '<a>1</a>'),
        ('application/yaml', 'a: 1'),
    ],
)
async def test_openai_binary_content_text_like_is_inlined(
    media_type: str, body: str, openai_api_key: str, allow_model_requests: None
) -> None:
    # Arrange input
    bin_content = BinaryContent(data=body.encode(), media_type=media_type)
    identifier = bin_content.identifier

    # Capture mapped OpenAI messages via public request() API
    captured: list[list[dict[str, Any]]] = []

    async def fake_create(*args: Any, **kwargs: Any):
        captured.append(kwargs['messages'])
        raise RuntimeError('stop-after-capture')

    model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
    # Monkeypatch the client's create method
    model.client.chat.completions.create = fake_create

    # Act
    with pytest.raises(RuntimeError, match='stop-after-capture'):
        await model.request([ModelRequest(parts=[UserPromptPart(content=[bin_content])])], {}, ModelRequestParameters())

    # Assert on the mapped user message content
    user_msgs = captured[0]
    # Find the user message
    user = next(m for m in user_msgs if m.get('role') == 'user')
    parts = cast(list[dict[str, Any]], user['content'])
    assert parts[0]['type'] == 'text'
    text = parts[0]['text']
    assert text.startswith(f'-----BEGIN FILE id="{identifier}" type="{media_type}"-----')
    assert text.rstrip().endswith('-----END FILE-----')

@pytest.mark.parametrize(
    'url,media_type,data_type,body',
    [
        ('https://example.com/file.txt', 'text/plain', 'txt', 'hello'),
        ('https://example.com/data.csv', 'text/csv', 'csv', 'a,b\n1,2'),
        ('https://example.com/data.json', 'application/json', 'json', '{"a":1}'),
        ('https://example.com/data.xml', 'application/xml', 'xml', '<a>1</a>'),
        ('https://example.com/readme.md', 'text/markdown', 'md', '# Title'),
        ('https://example.com/conf.yaml', 'application/yaml', 'yaml', 'a: 1'),
    ],
)
async def test_openai_document_url_text_like_is_inlined(
    monkeypatch: pytest.MonkeyPatch,
    url: str,
    media_type: str,
    data_type: str,
    body: str,
    openai_api_key: str,
    allow_model_requests: None,
) -> None:
    async def fake_download_item(
        item: Any, data_format: str = 'text', type_format: str = 'extension'
    ) -> dict[str, str]:
        assert data_format == 'text'
        return {'data': body, 'data_type': data_type}

    monkeypatch.setattr('pydantic_ai.models.openai.download_item', fake_download_item)

    document_url = DocumentUrl(url=url, media_type=media_type)
    identifier = document_url.identifier

    captured: list[list[dict[str, Any]]] = []

    async def fake_create(*args: Any, **kwargs: Any):
        captured.append(kwargs['messages'])
        raise RuntimeError('stop-after-capture')

    model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
    model.client.chat.completions.create = fake_create

    with pytest.raises(RuntimeError, match='stop-after-capture'):
        await model.request(
            [ModelRequest(parts=[UserPromptPart(content=[document_url])])],
            {},
            ModelRequestParameters(),
        )

    user_msgs = captured[0]
    user = next(m for m in user_msgs if m.get('role') == 'user')
    parts = cast(list[dict[str, Any]], user['content'])
    assert parts[0]['type'] == 'text'
    text = parts[0]['text']
    assert text.startswith(f'-----BEGIN FILE id="{identifier}" type="{media_type}"-----')
    assert text.rstrip().endswith('-----END FILE-----')

@pytest.mark.vcr()
async def test_image_url_tool_response(allow_model_requests: None, openai_api_key: str):
    m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))

@@ -2929,3 +3028,64 @@ def test_deprecated_openai_model(openai_api_key: str):

    provider = OpenAIProvider(api_key=openai_api_key)
    OpenAIModel('gpt-4o', provider=provider)  # type: ignore[reportDeprecated]

@pytest.mark.vcr()
async def test_openai_video_url_raises_not_implemented(openai_api_key: str, allow_model_requests: None) -> None:
    model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
    with pytest.raises(NotImplementedError):
        await model.request(
            [ModelRequest(parts=[UserPromptPart(content=[VideoUrl(url='https://example.com/file.mp4')])])],
            {},
            ModelRequestParameters(),
        )

async def test_openai_map_single_item_unknown_returns_empty_branch(
    openai_api_key: str, allow_model_requests: None
) -> None:
    # Use BinaryContent with unsupported media_type to exercise empty mapping via public API
    captured: list[list[dict[str, Any]]] = []

    async def fake_create(*args: Any, **kwargs: Any):
        captured.append(kwargs['messages'])
        raise RuntimeError('stop-after-capture')

    model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
    model.client.chat.completions.create = fake_create

    bc = BinaryContent(data=b'data', media_type='application/octet-stream')
    with pytest.raises(RuntimeError, match='stop-after-capture'):
        await model.request([ModelRequest(parts=[UserPromptPart(content=[bc])])], {}, ModelRequestParameters())

    user_msgs = captured[0]
    user = next(m for m in user_msgs if m.get('role') == 'user')
    parts = cast(list[dict[str, Any]], user['content'])
    assert parts == []

async def test_openai_binary_content_unsupported_type(openai_api_key: str, allow_model_requests: None) -> None:
    # Covers the unknown-item fallback path (content that is neither str, URL nor BinaryContent) via public API
    captured: list[list[dict[str, Any]]] = []

    async def fake_create(*args: Any, **kwargs: Any):
        captured.append(kwargs['messages'])
        raise RuntimeError('stop-after-capture')

    model = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
    model.client.chat.completions.create = fake_create

    class Location(TypedDict):
        city: str
        country: str

    unsupported = Location(city='Paris', country='France')
    with pytest.raises(RuntimeError, match='stop-after-capture'):
        await model.request([ModelRequest(parts=[UserPromptPart(content=[unsupported])])], {}, ModelRequestParameters())  # type: ignore[reportPrivateUsage]

    user_msgs = captured[0]
    user = next(m for m in user_msgs if m.get('role') == 'user')
    parts = cast(list[dict[str, Any]], user['content'])
    assert parts == []
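Taken together, the mapping code and these tests imply the captured user message for a text-like item has roughly this shape (illustrative only; the identifier is whatever `BinaryContent`/`DocumentUrl` generates):

```python
expected_user_message = {
    'role': 'user',
    'content': [
        {
            'type': 'text',
            'text': '-----BEGIN FILE id="<identifier>" type="application/json"-----\n{"a":1}\n-----END FILE-----',
        },
    ],
}
```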