Skip to content

Commit adc8122

Browse files
committed
wip
1 parent 5def693 commit adc8122

File tree

10 files changed

+1057
-64
lines changed

10 files changed

+1057
-64
lines changed

src/agents/agent.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,33 @@ class AgentBase:
9494
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
9595
"""Configuration for MCP servers."""
9696

97+
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
98+
"""Fetches the available tools from the MCP servers."""
99+
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
100+
return await MCPUtil.get_all_function_tools(
101+
self.mcp_servers, convert_schemas_to_strict, run_context, self
102+
)
103+
104+
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
105+
"""All agent tools, including MCP tools and function tools."""
106+
mcp_tools = await self.get_mcp_tools(run_context)
107+
108+
async def _check_tool_enabled(tool: Tool) -> bool:
109+
if not isinstance(tool, FunctionTool):
110+
return True
111+
112+
attr = tool.is_enabled
113+
if isinstance(attr, bool):
114+
return attr
115+
res = attr(run_context, self)
116+
if inspect.isawaitable(res):
117+
return bool(await res)
118+
return bool(res)
119+
120+
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
121+
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
122+
return [*mcp_tools, *enabled]
123+
97124

98125
@dataclass
99126
class Agent(AgentBase, Generic[TContext]):
@@ -262,30 +289,3 @@ async def get_prompt(
262289
) -> ResponsePromptParam | None:
263290
"""Get the prompt for the agent."""
264291
return await PromptUtil.to_model_input(self.prompt, run_context, self)
265-
266-
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
267-
"""Fetches the available tools from the MCP servers."""
268-
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
269-
return await MCPUtil.get_all_function_tools(
270-
self.mcp_servers, convert_schemas_to_strict, run_context, self
271-
)
272-
273-
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
274-
"""All agent tools, including MCP tools and function tools."""
275-
mcp_tools = await self.get_mcp_tools(run_context)
276-
277-
async def _check_tool_enabled(tool: Tool) -> bool:
278-
if not isinstance(tool, FunctionTool):
279-
return True
280-
281-
attr = tool.is_enabled
282-
if isinstance(attr, bool):
283-
return attr
284-
res = attr(run_context, self)
285-
if inspect.isawaitable(res):
286-
return bool(await res)
287-
return bool(res)
288-
289-
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
290-
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
291-
return [*mcp_tools, *enabled]

src/agents/model_settings.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
class _OmitTypeAnnotation:
1818
@classmethod
1919
def __get_pydantic_core_schema__(
20-
cls,
21-
_source_type: Any,
22-
_handler: GetCoreSchemaHandler,
20+
cls,
21+
_source_type: Any,
22+
_handler: GetCoreSchemaHandler,
2323
) -> core_schema.CoreSchema:
2424
def validate_from_none(value: None) -> _Omit:
2525
return _Omit()
@@ -39,12 +39,14 @@ def validate_from_none(value: None) -> _Omit:
3939
from_none_schema,
4040
]
4141
),
42-
serialization=core_schema.plain_serializer_function_ser_schema(
43-
lambda instance: None
44-
),
42+
serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
4543
)
44+
45+
4646
Omit = Annotated[_Omit, _OmitTypeAnnotation]
4747
Headers: TypeAlias = Mapping[str, Union[str, Omit]]
48+
ToolChoice: TypeAlias = Literal["auto", "required", "none"] | str | None
49+
4850

4951
@dataclass
5052
class ModelSettings:

src/agents/realtime/__init__.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,51 @@
11
from .agent import RealtimeAgent, RealtimeAgentHooks, RealtimeRunHooks
2+
from .config import APIKeyOrKeyFunc
3+
from .events import (
4+
RealtimeAgentEndEvent,
5+
RealtimeAgentStartEvent,
6+
RealtimeAudio,
7+
RealtimeAudioEnd,
8+
RealtimeAudioInterrupted,
9+
RealtimeError,
10+
RealtimeHandoffEvent,
11+
RealtimeHistoryAdded,
12+
RealtimeHistoryUpdated,
13+
RealtimeRawTransportEvent,
14+
RealtimeSessionEvent,
15+
RealtimeToolEnd,
16+
RealtimeToolStart,
17+
)
18+
from .session import RealtimeSession
19+
from .transport import (
20+
RealtimeModelName,
21+
RealtimeSessionTransport,
22+
RealtimeTransportConnectionOptions,
23+
RealtimeTransportListener,
24+
)
225

3-
__all__ = ["RealtimeAgent", "RealtimeAgentHooks", "RealtimeRunHooks"]
26+
__all__ = [
27+
"RealtimeAgent",
28+
"RealtimeAgentHooks",
29+
"RealtimeRunHooks",
30+
"RealtimeSession",
31+
"RealtimeSessionListener",
32+
"RealtimeSessionListenerFunc",
33+
"APIKeyOrKeyFunc",
34+
"RealtimeModelName",
35+
"RealtimeSessionTransport",
36+
"RealtimeTransportListener",
37+
"RealtimeTransportConnectionOptions",
38+
"RealtimeSessionEvent",
39+
"RealtimeAgentStartEvent",
40+
"RealtimeAgentEndEvent",
41+
"RealtimeHandoffEvent",
42+
"RealtimeToolStart",
43+
"RealtimeToolEnd",
44+
"RealtimeRawTransportEvent",
45+
"RealtimeAudioEnd",
46+
"RealtimeAudio",
47+
"RealtimeAudioInterrupted",
48+
"RealtimeError",
49+
"RealtimeHistoryUpdated",
50+
"RealtimeHistoryAdded",
51+
]

src/agents/realtime/agent.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import dataclasses
54
import inspect
65
from collections.abc import Awaitable
@@ -10,9 +9,7 @@
109
from ..agent import AgentBase
1110
from ..lifecycle import AgentHooksBase, RunHooksBase
1211
from ..logger import logger
13-
from ..mcp import MCPUtil
1412
from ..run_context import RunContextWrapper, TContext
15-
from ..tool import FunctionTool, Tool
1613
from ..util._types import MaybeAwaitable
1714

1815
RealtimeAgentHooks = AgentHooksBase[TContext, "RealtimeAgent[TContext]"]
@@ -81,30 +78,3 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
8178
logger.error(f"Instructions must be a string or a function, got {self.instructions}")
8279

8380
return None
84-
85-
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
86-
"""Fetches the available tools from the MCP servers."""
87-
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
88-
return await MCPUtil.get_all_function_tools(
89-
self.mcp_servers, convert_schemas_to_strict, run_context, self
90-
)
91-
92-
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
93-
"""All agent tools, including MCP tools and function tools."""
94-
mcp_tools = await self.get_mcp_tools(run_context)
95-
96-
async def _check_tool_enabled(tool: Tool) -> bool:
97-
if not isinstance(tool, FunctionTool):
98-
return True
99-
100-
attr = tool.is_enabled
101-
if isinstance(attr, bool):
102-
return attr
103-
res = attr(run_context, self)
104-
if inspect.isawaitable(res):
105-
return bool(await res)
106-
return bool(res)
107-
108-
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
109-
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
110-
return [*mcp_tools, *enabled]

src/agents/realtime/config.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from __future__ import annotations
2+
3+
import inspect
4+
from typing import (
5+
Callable,
6+
Literal,
7+
Union,
8+
)
9+
10+
from typing_extensions import NotRequired, TypeAlias, TypedDict
11+
12+
from ..model_settings import ToolChoice
13+
from ..tool import FunctionTool
14+
from ..util._types import MaybeAwaitable
15+
16+
17+
class RealtimeClientMessage(TypedDict, total=False):
18+
type: str # explicitly required
19+
# All additional keys are permitted because total=False
20+
21+
22+
class UserInputText(TypedDict):
23+
type: Literal["input_text"]
24+
text: str
25+
26+
27+
class RealtimeUserInputMessage(TypedDict):
28+
type: Literal["message"]
29+
role: Literal["user"]
30+
content: list[UserInputText]
31+
32+
33+
RealtimeUserInput: TypeAlias = Union[str, RealtimeUserInputMessage]
34+
35+
36+
RealtimeAudioFormat: TypeAlias = Union[Literal["pcm16", "g711_ulaw", "g711_alaw"], str]
37+
38+
39+
class RealtimeInputAudioTranscriptionConfig(TypedDict, total=False):
40+
language: NotRequired[str]
41+
model: NotRequired[Literal["gpt-4o-transcribe", "gpt-4o-mini-transcribe", "whisper-1"] | str]
42+
prompt: NotRequired[str]
43+
44+
45+
class RealtimeTurnDetectionConfig(TypedDict, total=False):
46+
"""Turn detection config. Allows extra vendor keys if needed."""
47+
48+
type: NotRequired[Literal["semantic_vad", "server_vad"]]
49+
create_response: NotRequired[bool]
50+
eagerness: NotRequired[Literal["auto", "low", "medium", "high"]]
51+
interrupt_response: NotRequired[bool]
52+
prefix_padding_ms: NotRequired[int]
53+
silence_duration_ms: NotRequired[int]
54+
threshold: NotRequired[float]
55+
56+
57+
class RealtimeSessionConfig(TypedDict):
58+
api_key: NotRequired[APIKeyOrKeyFunc]
59+
model: NotRequired[str]
60+
instructions: NotRequired[str]
61+
modalities: NotRequired[list[Literal["text", "audio"]]]
62+
voice: NotRequired[str]
63+
64+
input_audio_format: NotRequired[RealtimeAudioFormat]
65+
output_audio_format: NotRequired[RealtimeAudioFormat]
66+
input_audio_transcription: NotRequired[RealtimeInputAudioTranscriptionConfig]
67+
turn_detection: NotRequired[RealtimeTurnDetectionConfig]
68+
69+
tool_choice: NotRequired[ToolChoice]
70+
tools: NotRequired[list[FunctionTool]]
71+
72+
73+
APIKeyOrKeyFunc = str | Callable[[], MaybeAwaitable[str]]
74+
"""Either an API key or a function that returns an API key."""
75+
76+
77+
async def get_api_key(key: APIKeyOrKeyFunc | None) -> str | None:
78+
"""Get the API key from the key or key function."""
79+
if key is None:
80+
return None
81+
elif isinstance(key, str):
82+
return key
83+
84+
result = key()
85+
if inspect.isawaitable(result):
86+
return await result
87+
return result
88+
89+
# TODO (rm) Add tracing support
90+
# tracing: NotRequired[RealtimeTracingConfig | None]

0 commit comments

Comments
 (0)