diff --git a/pydantic_ai_slim/pydantic_ai/_otel_messages.py b/pydantic_ai_slim/pydantic_ai/_otel_messages.py new file mode 100644 index 000000000..62e03a590 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/_otel_messages.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import Literal + +from pydantic import JsonValue +from typing_extensions import NotRequired, TypeAlias, TypedDict + + +class TextPart(TypedDict): + type: Literal['text'] + content: NotRequired[str] + + +class ToolCallPart(TypedDict): + type: Literal['tool_call'] + id: str + name: str + arguments: NotRequired[JsonValue] + + +class ToolCallResponsePart(TypedDict): + type: Literal['tool_call_response'] + id: str + name: str + result: NotRequired[JsonValue] + + +class MediaUrlPart(TypedDict): + type: Literal['image-url', 'audio-url', 'video-url', 'document-url'] + url: NotRequired[str] + + +class BinaryDataPart(TypedDict): + type: Literal['binary'] + media_type: str + binary_content: NotRequired[str] + + +class ThinkingPart(TypedDict): + type: Literal['thinking'] + content: NotRequired[str] + + +MessagePart: TypeAlias = 'TextPart | ToolCallPart | ToolCallResponsePart | MediaUrlPart | BinaryDataPart | ThinkingPart' + + +Role = Literal['system', 'user', 'assistant'] + + +class ChatMessage(TypedDict): + role: Role + parts: list[MessagePart] + + +InputMessages: TypeAlias = list[ChatMessage] + + +class OutputMessage(ChatMessage): + finish_reason: NotRequired[str] + + +OutputMessages: TypeAlias = list[OutputMessage] diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 28447187e..5ccb6e84e 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -13,7 +13,7 @@ from opentelemetry._events import Event # pyright: ignore[reportPrivateImportUsage] from typing_extensions import TypeAlias, deprecated -from . import _utils +from . 
import _otel_messages, _utils from ._utils import ( generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, @@ -82,6 +82,9 @@ def otel_event(self, settings: InstrumentationSettings) -> Event: body={'role': 'system', **({'content': self.content} if settings.include_content else {})}, ) + def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: + return [_otel_messages.TextPart(type='text', **{'content': self.content} if settings.include_content else {})] + __repr__ = _utils.dataclasses_no_defaults_repr @@ -504,25 +507,38 @@ class UserPromptPart: """Part type identifier, this is available on all parts as a discriminator.""" def otel_event(self, settings: InstrumentationSettings) -> Event: - content: str | list[dict[str, Any] | str] | dict[str, Any] - if isinstance(self.content, str): - content = self.content if settings.include_content else {'kind': 'text'} - else: - content = [] - for part in self.content: - if isinstance(part, str): - content.append(part if settings.include_content else {'kind': 'text'}) - elif isinstance(part, (ImageUrl, AudioUrl, DocumentUrl, VideoUrl)): - content.append({'kind': part.kind, **({'url': part.url} if settings.include_content else {})}) - elif isinstance(part, BinaryContent): - converted_part = {'kind': part.kind, 'media_type': part.media_type} - if settings.include_content and settings.include_binary_content: - converted_part['binary_content'] = base64.b64encode(part.data).decode() - content.append(converted_part) - else: - content.append({'kind': part.kind}) # pragma: no cover + content = [{'kind': part.pop('type'), **part} for part in self.otel_message_parts(settings)] + content = [ + part['content'] if part == {'kind': 'text', 'content': part.get('content')} else part for part in content + ] + if content in ([{'kind': 'text'}], [self.content]): + content = content[0] return Event('gen_ai.user.message', body={'content': content, 'role': 'user'}) + def 
otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: + parts: list[_otel_messages.MessagePart] = [] + content = [self.content] if isinstance(self.content, str) else self.content + for part in content: + if isinstance(part, str): + parts.append( + _otel_messages.TextPart(type='text', **({'content': part} if settings.include_content else {})) + ) + elif isinstance(part, (ImageUrl, AudioUrl, DocumentUrl, VideoUrl)): + parts.append( + _otel_messages.MediaUrlPart( + type=part.kind, + **{'url': part.url} if settings.include_content else {}, + ) + ) + elif isinstance(part, BinaryContent): + converted_part = _otel_messages.BinaryDataPart(type='binary', media_type=part.media_type) + if settings.include_content and settings.include_binary_content: + converted_part['binary_content'] = base64.b64encode(part.data).decode() + parts.append(converted_part) + else: + parts.append({'type': part.kind}) # pragma: no cover + return parts + __repr__ = _utils.dataclasses_no_defaults_repr @@ -576,6 +592,18 @@ def otel_event(self, settings: InstrumentationSettings) -> Event: }, ) + def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: + from .models.instrumented import InstrumentedModel + + return [ + _otel_messages.ToolCallResponsePart( + type='tool_call_response', + id=self.tool_call_id, + name=self.tool_name, + **({'result': InstrumentedModel.serialize_any(self.content)} if settings.include_content else {}), + ) + ] + def has_content(self) -> bool: """Return `True` if the tool return has content.""" return self.content is not None # pragma: no cover @@ -669,6 +697,19 @@ def otel_event(self, settings: InstrumentationSettings) -> Event: }, ) + def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: + if self.tool_name is None: + return [_otel_messages.TextPart(type='text', **({'content': self.model_response()} if settings.include_content else {}))] + else: + return [ +
_otel_messages.ToolCallResponsePart( + type='tool_call_response', + id=self.tool_call_id, + name=self.tool_name, + **({'result': self.model_response()} if settings.include_content else {}), + ) + ] + __repr__ = _utils.dataclasses_no_defaults_repr @@ -894,6 +935,36 @@ def new_event_body(): return result + def otel_message_parts(self, settings: InstrumentationSettings) -> list[_otel_messages.MessagePart]: + parts: list[_otel_messages.MessagePart] = [] + for part in self.parts: + if isinstance(part, TextPart): + parts.append( + _otel_messages.TextPart( + type='text', + **({'content': part.content} if settings.include_content else {}), + ) + ) + elif isinstance(part, ThinkingPart): + parts.append( + _otel_messages.ThinkingPart( + type='thinking', + **({'content': part.content} if settings.include_content else {}), + ) + ) + elif isinstance(part, ToolCallPart): + call_part = _otel_messages.ToolCallPart(type='tool_call', id=part.tool_call_id, name=part.tool_name) + if settings.include_content and part.args is not None: + from .models.instrumented import InstrumentedModel + + if isinstance(part.args, str): + call_part['arguments'] = part.args + else: + call_part['arguments'] = {k: InstrumentedModel.serialize_any(v) for k, v in part.args.items()} + + parts.append(call_part) + return parts + @property @deprecated('`vendor_details` is deprecated, use `provider_details` instead') def vendor_details(self) -> dict[str, Any] | None: diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index b7bf965b9..bbaca7179 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -1,10 +1,11 @@ from __future__ import annotations +import itertools import json from collections.abc import AsyncIterator, Iterator, Mapping from contextlib import asynccontextmanager, contextmanager from dataclasses import dataclass, field -from typing import Any, Callable, Literal 
+from typing import Any, Callable, Literal, cast from urllib.parse import urlparse from opentelemetry._events import ( @@ -18,8 +19,14 @@ from opentelemetry.util.types import AttributeValue from pydantic import TypeAdapter +from .. import _otel_messages from .._run_context import RunContext -from ..messages import ModelMessage, ModelRequest, ModelResponse +from ..messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + SystemPromptPart, +) from ..settings import ModelSettings from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse from .wrapper import WrapperModel @@ -80,6 +87,8 @@ class InstrumentationSettings: event_logger: EventLogger = field(repr=False) event_mode: Literal['attributes', 'logs'] = 'attributes' include_binary_content: bool = True + include_content: bool = True + version: Literal[1, 2] = 1 def __init__( self, @@ -90,6 +99,7 @@ def __init__( event_logger_provider: EventLoggerProvider | None = None, include_binary_content: bool = True, include_content: bool = True, + version: Literal[1, 2] = 1, ): """Create instrumentation options. @@ -109,6 +119,7 @@ def __init__( include_binary_content: Whether to include binary content in the instrumentation events. include_content: Whether to include prompts, completions, and tool call arguments and responses in the instrumentation events. 
+ version: Version of the data format used to record messages. Version 1 (the default) records messages as OpenTelemetry events (emitted as controlled by `event_mode`); version 2 instead sets the `gen_ai.input.messages` and `gen_ai.output.messages` span attributes, plus `gen_ai.system_instructions` when instructions are present. """ from pydantic_ai import __version__ @@ -122,6 +133,7 @@ def __init__( self.event_mode = event_mode self.include_binary_content = include_binary_content self.include_content = include_content + self.version = version # As specified in the OpenTelemetry GenAI metrics spec: # https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage @@ -179,6 +191,86 @@ def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]: event.body = InstrumentedModel.serialize_any(event.body) return events + def messages_to_otel_messages(self, messages: list[ModelMessage]) -> list[_otel_messages.ChatMessage]: + result: list[_otel_messages.ChatMessage] = [] + for message in messages: + if isinstance(message, ModelRequest): + for is_system, group in itertools.groupby(message.parts, key=lambda p: isinstance(p, SystemPromptPart)): + message_parts: list[_otel_messages.MessagePart] = [] + for part in group: + if hasattr(part, 'otel_message_parts'): + message_parts.extend(part.otel_message_parts(self)) + result.append( + _otel_messages.ChatMessage(role='system' if is_system else 'user', parts=message_parts) + ) + elif isinstance(message, ModelResponse): # pragma: no branch + result.append(_otel_messages.ChatMessage(role='assistant', parts=message.otel_message_parts(self))) + return result + + def handle_messages(self, input_messages: list[ModelMessage], response: ModelResponse, system: str, span: Span): + if self.version == 1: + events = self.messages_to_otel_events(input_messages) + for event in self.messages_to_otel_events([response]): + events.append( + Event( + 'gen_ai.choice', + body={ + # TODO finish_reason + 'index': 0, + 'message': event.body, + }, + ) + ) + for event in events: + event.attributes = { + GEN_AI_SYSTEM_ATTRIBUTE: system, + **(event.attributes or {}), + } + self._emit_events(span, events) + else: + output_message = cast(_otel_messages.OutputMessage, self.messages_to_otel_messages([response])[0]) + if
response.provider_details and 'finish_reason' in response.provider_details: + output_message['finish_reason'] = response.provider_details['finish_reason'] + instructions = InstrumentedModel._get_instructions(input_messages) # pyright: ignore [reportPrivateUsage] + attributes = { + 'gen_ai.input.messages': json.dumps(self.messages_to_otel_messages(input_messages)), + 'gen_ai.output.messages': json.dumps([output_message]), + 'logfire.json_schema': json.dumps( + { + 'type': 'object', + 'properties': { + 'gen_ai.input.messages': {'type': 'array'}, + 'gen_ai.output.messages': {'type': 'array'}, + 'model_request_parameters': {'type': 'object'}, + }, + } + ), + } + if instructions is not None: + attributes['gen_ai.system_instructions'] = instructions + span.set_attributes(attributes) + + def _emit_events(self, span: Span, events: list[Event]) -> None: + if self.event_mode == 'logs': + for event in events: + self.event_logger.emit(event) + else: + attr_name = 'events' + span.set_attributes( + { + attr_name: json.dumps([InstrumentedModel.event_to_dict(event) for event in events]), + 'logfire.json_schema': json.dumps( + { + 'type': 'object', + 'properties': { + attr_name: {'type': 'array'}, + 'model_request_parameters': {'type': 'object'}, + }, + } + ), + } + ) + GEN_AI_SYSTEM_ATTRIBUTE = 'gen_ai.system' GEN_AI_REQUEST_MODEL_ATTRIBUTE = 'gen_ai.request.model' @@ -269,7 +361,7 @@ def finish(response: ModelResponse): # FallbackModel updates these span attributes. 
attributes.update(getattr(span, 'attributes', {})) request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE] - system = attributes[GEN_AI_SYSTEM_ATTRIBUTE] + system = cast(str, attributes[GEN_AI_SYSTEM_ATTRIBUTE]) response_model = response.model_name or request_model @@ -297,18 +389,7 @@ def _record_metrics(): if not span.is_recording(): return - events = self.instrumentation_settings.messages_to_otel_events(messages) - for event in self.instrumentation_settings.messages_to_otel_events([response]): - events.append( - Event( - 'gen_ai.choice', - body={ - # TODO finish_reason - 'index': 0, - 'message': event.body, - }, - ) - ) + self.instrumentation_settings.handle_messages(messages, response, system, span) span.set_attributes( { **response.usage.opentelemetry_attributes(), @@ -316,12 +397,6 @@ def _record_metrics(): } ) span.update_name(f'{operation} {request_model}') - for event in events: - event.attributes = { - GEN_AI_SYSTEM_ATTRIBUTE: system, - **(event.attributes or {}), - } - self._emit_events(span, events) yield finish finally: @@ -330,27 +405,6 @@ def _record_metrics(): # to prevent them from being redundantly recorded in the span itself by logfire. 
record_metrics() - def _emit_events(self, span: Span, events: list[Event]) -> None: - if self.instrumentation_settings.event_mode == 'logs': - for event in events: - self.instrumentation_settings.event_logger.emit(event) - else: - attr_name = 'events' - span.set_attributes( - { - attr_name: json.dumps([self.event_to_dict(event) for event in events]), - 'logfire.json_schema': json.dumps( - { - 'type': 'object', - 'properties': { - attr_name: {'type': 'array'}, - 'model_request_parameters': {'type': 'object'}, - }, - } - ), - } - ) - @staticmethod def model_attributes(model: Model): attributes: dict[str, AttributeValue] = { diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index b5704c7b9..3d0aefe7c 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -3,9 +3,9 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from datetime import datetime +from typing import Literal import pytest -from dirty_equals import IsJson from inline_snapshot import snapshot from logfire_api import DEFAULT_LOGFIRE_INSTANCE from opentelemetry._events import NoOpEventLoggerProvider @@ -84,6 +84,7 @@ async def request( ], usage=RequestUsage(input_tokens=100, output_tokens=200), model_name='my_model_123', + provider_details=dict(finish_reason='stop', foo='bar'), ) @asynccontextmanager @@ -525,13 +526,17 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): ) -async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): - model = InstrumentedModel(MyModel(), InstrumentationSettings(event_mode='attributes')) +@pytest.mark.parametrize('instrumentation_version', [1, 2]) +async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire, instrumentation_version: Literal[1, 2]): + model = InstrumentedModel( + MyModel(), InstrumentationSettings(event_mode='attributes', version=instrumentation_version) + ) assert model.system == 'my_system' assert 
model.model_name == 'my_model' messages = [ ModelRequest( + instructions='instructions', parts=[ SystemPromptPart('system_prompt'), UserPromptPart('user_prompt'), @@ -539,7 +544,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): RetryPromptPart('retry_prompt1', tool_name='tool4', tool_call_id='tool_call_4'), RetryPromptPart('retry_prompt2'), {}, # test unexpected parts # type: ignore - ] + ], ), ModelResponse(parts=[TextPart('text3')]), ] @@ -555,117 +560,229 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): ), ) - assert capfire.exporter.exported_spans_as_dict() == snapshot( - [ - { - 'name': 'chat my_model', - 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, - 'parent': None, - 'start_time': 1000000000, - 'end_time': 2000000000, - 'attributes': { - 'gen_ai.operation.name': 'chat', - 'gen_ai.system': 'my_system', - 'gen_ai.request.model': 'my_model', - 'server.address': 'example.com', - 'server.port': 8000, - 'model_request_parameters': '{"function_tools": [], "builtin_tools": [], "output_mode": "text", "output_object": null, "output_tools": [], "allow_text_output": true}', - 'gen_ai.request.temperature': 1, - 'logfire.msg': 'chat my_model', - 'logfire.span_type': 'span', - 'gen_ai.response.model': 'my_model_123', - 'gen_ai.usage.input_tokens': 100, - 'gen_ai.usage.output_tokens': 200, - 'events': IsJson( - snapshot( - [ - { - 'event.name': 'gen_ai.system.message', - 'content': 'system_prompt', - 'role': 'system', - 'gen_ai.message.index': 0, - 'gen_ai.system': 'my_system', - }, - { - 'event.name': 'gen_ai.user.message', - 'content': 'user_prompt', - 'role': 'user', - 'gen_ai.message.index': 0, - 'gen_ai.system': 'my_system', - }, - { - 'event.name': 'gen_ai.tool.message', - 'content': 'tool_return_content', - 'role': 'tool', - 'name': 'tool3', - 'id': 'tool_call_3', - 'gen_ai.message.index': 0, - 'gen_ai.system': 'my_system', - }, - { - 'event.name': 'gen_ai.tool.message', - 'content': """\ + 
if instrumentation_version == 1: + assert capfire.exporter.exported_spans_as_dict(parse_json_attributes=True) == snapshot( + [ + { + 'name': 'chat my_model', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'gen_ai.operation.name': 'chat', + 'gen_ai.system': 'my_system', + 'gen_ai.request.model': 'my_model', + 'server.address': 'example.com', + 'server.port': 8000, + 'model_request_parameters': { + 'function_tools': [], + 'builtin_tools': [], + 'output_mode': 'text', + 'output_object': None, + 'output_tools': [], + 'allow_text_output': True, + }, + 'gen_ai.request.temperature': 1, + 'logfire.msg': 'chat my_model', + 'logfire.span_type': 'span', + 'gen_ai.response.model': 'my_model_123', + 'gen_ai.usage.input_tokens': 100, + 'gen_ai.usage.output_tokens': 200, + 'events': [ + { + 'content': 'instructions', + 'role': 'system', + 'gen_ai.system': 'my_system', + 'event.name': 'gen_ai.system.message', + }, + { + 'event.name': 'gen_ai.system.message', + 'content': 'system_prompt', + 'role': 'system', + 'gen_ai.message.index': 0, + 'gen_ai.system': 'my_system', + }, + { + 'event.name': 'gen_ai.user.message', + 'content': 'user_prompt', + 'role': 'user', + 'gen_ai.message.index': 0, + 'gen_ai.system': 'my_system', + }, + { + 'event.name': 'gen_ai.tool.message', + 'content': 'tool_return_content', + 'role': 'tool', + 'name': 'tool3', + 'id': 'tool_call_3', + 'gen_ai.message.index': 0, + 'gen_ai.system': 'my_system', + }, + { + 'event.name': 'gen_ai.tool.message', + 'content': """\ retry_prompt1 Fix the errors and try again.\ """, - 'role': 'tool', - 'name': 'tool4', - 'id': 'tool_call_4', - 'gen_ai.message.index': 0, - 'gen_ai.system': 'my_system', - }, - { - 'event.name': 'gen_ai.user.message', - 'content': """\ + 'role': 'tool', + 'name': 'tool4', + 'id': 'tool_call_4', + 'gen_ai.message.index': 0, + 'gen_ai.system': 'my_system', + }, + { + 'event.name': 
'gen_ai.user.message', + 'content': """\ Validation feedback: retry_prompt2 Fix the errors and try again.\ """, - 'role': 'user', - 'gen_ai.message.index': 0, - 'gen_ai.system': 'my_system', - }, - { - 'event.name': 'gen_ai.assistant.message', + 'role': 'user', + 'gen_ai.message.index': 0, + 'gen_ai.system': 'my_system', + }, + { + 'event.name': 'gen_ai.assistant.message', + 'role': 'assistant', + 'content': 'text3', + 'gen_ai.message.index': 1, + 'gen_ai.system': 'my_system', + }, + { + 'index': 0, + 'message': { 'role': 'assistant', - 'content': 'text3', - 'gen_ai.message.index': 1, - 'gen_ai.system': 'my_system', + 'content': [ + {'kind': 'text', 'text': 'text1'}, + {'kind': 'text', 'text': 'text2'}, + ], + 'tool_calls': [ + { + 'id': 'tool_call_1', + 'type': 'function', + 'function': {'name': 'tool1', 'arguments': 'args1'}, + }, + { + 'id': 'tool_call_2', + 'type': 'function', + 'function': {'name': 'tool2', 'arguments': {'args2': 3}}, + }, + ], }, - { - 'index': 0, - 'message': { - 'role': 'assistant', - 'content': [ - {'kind': 'text', 'text': 'text1'}, - {'kind': 'text', 'text': 'text2'}, - ], - 'tool_calls': [ - { - 'id': 'tool_call_1', - 'type': 'function', - 'function': {'name': 'tool1', 'arguments': 'args1'}, - }, - { - 'id': 'tool_call_2', - 'type': 'function', - 'function': {'name': 'tool2', 'arguments': {'args2': 3}}, - }, - ], + 'gen_ai.system': 'my_system', + 'event.name': 'gen_ai.choice', + }, + ], + 'logfire.json_schema': { + 'type': 'object', + 'properties': {'events': {'type': 'array'}, 'model_request_parameters': {'type': 'object'}}, + }, + }, + }, + ] + ) + else: + assert capfire.exporter.exported_spans_as_dict(parse_json_attributes=True) == snapshot( + [ + { + 'name': 'chat my_model', + 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, + 'parent': None, + 'start_time': 1000000000, + 'end_time': 2000000000, + 'attributes': { + 'gen_ai.operation.name': 'chat', + 'gen_ai.system': 'my_system', + 'gen_ai.request.model': 'my_model', + 
'server.address': 'example.com', + 'server.port': 8000, + 'model_request_parameters': { + 'function_tools': [], + 'builtin_tools': [], + 'output_mode': 'text', + 'output_object': None, + 'output_tools': [], + 'allow_text_output': True, + }, + 'gen_ai.request.temperature': 1, + 'logfire.msg': 'chat my_model', + 'logfire.span_type': 'span', + 'gen_ai.input.messages': [ + { + 'role': 'system', + 'parts': [ + {'type': 'text', 'content': 'system_prompt'}, + ], + }, + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'user_prompt'}, + { + 'type': 'tool_call_response', + 'id': 'tool_call_3', + 'name': 'tool3', + 'result': 'tool_return_content', }, - 'gen_ai.system': 'my_system', - 'event.name': 'gen_ai.choice', - }, - ] - ) - ), - 'logfire.json_schema': '{"type": "object", "properties": {"events": {"type": "array"}, "model_request_parameters": {"type": "object"}}}', + { + 'type': 'tool_call_response', + 'id': 'tool_call_4', + 'name': 'tool4', + 'result': """\ +retry_prompt1 + +Fix the errors and try again.\ +""", + }, + { + 'type': 'text', + 'content': """\ +Validation feedback: +retry_prompt2 + +Fix the errors and try again.\ +""", + }, + ], + }, + {'role': 'assistant', 'parts': [{'type': 'text', 'content': 'text3'}]}, + ], + 'gen_ai.output.messages': [ + { + 'role': 'assistant', + 'parts': [ + {'type': 'text', 'content': 'text1'}, + {'type': 'tool_call', 'id': 'tool_call_1', 'name': 'tool1', 'arguments': 'args1'}, + { + 'type': 'tool_call', + 'id': 'tool_call_2', + 'name': 'tool2', + 'arguments': {'args2': 3}, + }, + {'type': 'text', 'content': 'text2'}, + ], + 'finish_reason': 'stop', + } + ], + 'gen_ai.response.model': 'my_model_123', + 'gen_ai.system_instructions': 'instructions', + 'gen_ai.usage.input_tokens': 100, + 'gen_ai.usage.output_tokens': 200, + 'logfire.json_schema': { + 'type': 'object', + 'properties': { + 'gen_ai.input.messages': {'type': 'array'}, + 'gen_ai.output.messages': {'type': 'array'}, + 'model_request_parameters': {'type': 
'object'}, + }, + }, + }, }, - }, - ] - ) + ] + ) def test_messages_to_otel_events_serialization_errors(): @@ -695,6 +812,25 @@ def __repr__(self): 'event.name': 'gen_ai.tool.message', }, ] + assert settings.messages_to_otel_messages(messages) == snapshot( + [ + { + 'role': 'assistant', + 'parts': [{'type': 'tool_call', 'id': 'tool_call_id', 'name': 'tool', 'arguments': {'arg': 'Foo()'}}], + }, + { + 'role': 'user', + 'parts': [ + { + 'type': 'tool_call_response', + 'id': 'return_tool_call_id', + 'name': 'tool', + 'result': 'Unable to serialize: error!', + } + ], + }, + ] + ) def test_messages_to_otel_events_instructions(): @@ -715,6 +851,12 @@ def test_messages_to_otel_events_instructions(): }, ] ) + assert settings.messages_to_otel_messages(messages) == snapshot( + [ + {'role': 'user', 'parts': [{'type': 'text', 'content': 'user_prompt'}]}, + {'role': 'assistant', 'parts': [{'type': 'text', 'content': 'text1'}]}, + ] + ) def test_messages_to_otel_events_instructions_multiple_messages(): @@ -737,6 +879,13 @@ def test_messages_to_otel_events_instructions_multiple_messages(): {'content': 'user_prompt2', 'role': 'user', 'gen_ai.message.index': 2, 'event.name': 'gen_ai.user.message'}, ] ) + assert settings.messages_to_otel_messages(messages) == snapshot( + [ + {'role': 'user', 'parts': [{'type': 'text', 'content': 'user_prompt'}]}, + {'role': 'assistant', 'parts': [{'type': 'text', 'content': 'text1'}]}, + {'role': 'user', 'parts': [{'type': 'text', 'content': 'user_prompt2'}]}, + ] + ) def test_messages_to_otel_events_image_url(document_content: BinaryContent): @@ -817,6 +966,60 @@ def test_messages_to_otel_events_image_url(document_content: BinaryContent): }, ] ) + assert settings.messages_to_otel_messages(messages) == snapshot( + [ + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'user_prompt'}, + {'type': 'image-url', 'url': 'https://example.com/image.png'}, + ], + }, + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'user_prompt2'}, 
+ {'type': 'audio-url', 'url': 'https://example.com/audio.mp3'}, + ], + }, + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'user_prompt3'}, + {'type': 'document-url', 'url': 'https://example.com/document.pdf'}, + ], + }, + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'user_prompt4'}, + {'type': 'video-url', 'url': 'https://example.com/video.mp4'}, + ], + }, + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'user_prompt5'}, + {'type': 'image-url', 'url': 'https://example.com/image2.png'}, + {'type': 'audio-url', 'url': 'https://example.com/audio2.mp3'}, + {'type': 'document-url', 'url': 'https://example.com/document2.pdf'}, + {'type': 'video-url', 'url': 'https://example.com/video2.mp4'}, + ], + }, + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'user_prompt6'}, + { + 'type': 'binary', + 'media_type': 'application/pdf', + 'binary_content': IsStr(), + }, + ], + }, + {'role': 'assistant', 'parts': [{'type': 'text', 'content': 'text1'}]}, + ] + ) def test_messages_to_otel_events_without_binary_content(document_content: BinaryContent): @@ -834,6 +1037,17 @@ def test_messages_to_otel_events_without_binary_content(document_content: Binary } ] ) + assert settings.messages_to_otel_messages(messages) == snapshot( + [ + { + 'role': 'user', + 'parts': [ + {'type': 'text', 'content': 'user_prompt6'}, + {'type': 'binary', 'media_type': 'application/pdf'}, + ], + } + ] + ) def test_messages_without_content(document_content: BinaryContent): @@ -928,6 +1142,34 @@ def test_messages_without_content(document_content: BinaryContent): }, ] ) + assert settings.messages_to_otel_messages(messages) == snapshot( + [ + {'role': 'system', 'parts': [{'type': 'text'}]}, + {'role': 'assistant', 'parts': [{'type': 'text'}]}, + { + 'role': 'user', + 'parts': [ + {'type': 'text'}, + {'type': 'video-url'}, + {'type': 'image-url'}, + {'type': 'audio-url'}, + {'type': 'document-url'}, + {'type': 'binary', 'media_type': 
'application/pdf'}, + ], + }, + { + 'role': 'assistant', + 'parts': [ + {'type': 'text'}, + {'type': 'tool_call', 'id': IsStr(), 'name': 'my_tool'}, + ], + }, + {'role': 'user', 'parts': [{'type': 'tool_call_response', 'id': 'tool_call_1', 'name': 'tool'}]}, + {'role': 'user', 'parts': [{'type': 'tool_call_response', 'id': 'tool_call_2', 'name': 'tool'}]}, + {'role': 'user', 'parts': [{'type': 'text'}, {'type': 'binary', 'media_type': 'application/pdf'}]}, + {'role': 'user', 'parts': [{'type': 'text'}]}, + ] + ) def test_message_with_thinking_parts(): @@ -963,3 +1205,20 @@ def test_message_with_thinking_parts(): }, ] ) + assert settings.messages_to_otel_messages(messages) == snapshot( + [ + { + 'role': 'assistant', + 'parts': [ + {'type': 'text', 'content': 'text1'}, + {'type': 'thinking', 'content': 'thinking1'}, + {'type': 'text', 'content': 'text2'}, + ], + }, + {'role': 'assistant', 'parts': [{'type': 'thinking', 'content': 'thinking2'}]}, + { + 'role': 'assistant', + 'parts': [{'type': 'thinking', 'content': 'thinking3'}, {'type': 'text', 'content': 'text3'}], + }, + ] + ) diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 2bb02f45c..f20a4bde0 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -75,6 +75,7 @@ def get_summary() -> LogfireSummary: False, InstrumentationSettings(event_mode='attributes'), InstrumentationSettings(event_mode='logs'), + InstrumentationSettings(version=2), ], ) def test_logfire( @@ -241,39 +242,69 @@ async def my_ret(x: int) -> str: ] ) - attribute_mode_attributes = {k: chat_span_attributes.pop(k) for k in ['events']} - assert attribute_mode_attributes == snapshot( - { - 'events': IsJson( - snapshot( - [ - { - 'event.name': 'gen_ai.user.message', - 'content': 'Hello', - 'role': 'user', - 'gen_ai.message.index': 0, - 'gen_ai.system': 'test', - }, - { - 'event.name': 'gen_ai.choice', - 'index': 0, - 'message': { + messages_attributes = { + k: chat_span_attributes.pop(k) + for k in ['events', 
'gen_ai.input.messages', 'gen_ai.output.messages'] + if k in chat_span_attributes + } + if 'events' in messages_attributes: + assert messages_attributes == snapshot( + { + 'events': IsJson( + snapshot( + [ + { + 'event.name': 'gen_ai.user.message', + 'content': 'Hello', + 'role': 'user', + 'gen_ai.message.index': 0, + 'gen_ai.system': 'test', + }, + { + 'event.name': 'gen_ai.choice', + 'index': 0, + 'message': { + 'role': 'assistant', + 'tool_calls': [ + { + 'id': IsStr(), + 'type': 'function', + 'function': {'name': 'my_ret', 'arguments': {'x': 0}}, + } + ], + }, + 'gen_ai.system': 'test', + }, + ] + ) + ), + } + ) + else: + assert messages_attributes == snapshot( + { + 'gen_ai.input.messages': IsJson( + snapshot([{'role': 'user', 'parts': [{'type': 'text', 'content': 'Hello'}]}]) + ), + 'gen_ai.output.messages': IsJson( + snapshot( + [ + { 'role': 'assistant', - 'tool_calls': [ + 'parts': [ { + 'type': 'tool_call', 'id': IsStr(), - 'type': 'function', - 'function': {'name': 'my_ret', 'arguments': {'x': 0}}, + 'name': 'my_ret', + 'arguments': {'x': 0}, } ], - }, - 'gen_ai.system': 'test', - }, - ] - ) - ), - } - ) + } + ] + ) + ), + } + ) assert chat_span_attributes == snapshot( { @@ -317,15 +348,24 @@ async def my_ret(x: int) -> str: @pytest.mark.skipif(not logfire_installed, reason='logfire not installed') -def test_instructions_with_structured_output(get_logfire_summary: Callable[[], LogfireSummary]) -> None: +@pytest.mark.parametrize( + 'instrument', + [ + InstrumentationSettings(version=1), + InstrumentationSettings(version=2), + ], +) +def test_instructions_with_structured_output( + get_logfire_summary: Callable[[], LogfireSummary], instrument: InstrumentationSettings +) -> None: @dataclass class MyOutput: content: str - my_agent = Agent(model=TestModel(), instructions='Here are some instructions', instrument=True) + my_agent = Agent(model=TestModel(), instructions='Here are some instructions', instrument=instrument) result = my_agent.run_sync('Hello', 
output_type=MyOutput) - assert result.output == snapshot(MyOutput(content='a')) + assert result.output == MyOutput(content='a') summary = get_logfire_summary() assert summary.attributes[0] == snapshot( @@ -385,8 +425,8 @@ class MyOutput: } ) chat_span_attributes = summary.attributes[1] - assert chat_span_attributes['events'] == snapshot( - IsJson( + if instrument.version == 1: + assert chat_span_attributes['events'] == IsJson( snapshot( [ { @@ -420,7 +460,27 @@ class MyOutput: ] ) ) - ) + else: + assert chat_span_attributes['gen_ai.input.messages'] == IsJson( + snapshot([{'role': 'user', 'parts': [{'type': 'text', 'content': 'Hello'}]}]) + ) + assert chat_span_attributes['gen_ai.output.messages'] == IsJson( + snapshot( + [ + { + 'role': 'assistant', + 'parts': [ + { + 'type': 'tool_call', + 'id': IsStr(), + 'name': 'final_result', + 'arguments': {'content': 'a'}, + } + ], + } + ] + ) + ) @pytest.mark.skipif(not logfire_installed, reason='logfire not installed')