diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index eb8c16d1f..edd0d898b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -84,7 +84,5 @@ jobs: enable-cache: true - name: Install dependencies run: make sync - - name: Install Python 3.9 dependencies - run: UV_PROJECT_ENVIRONMENT=.venv_39 uv sync --all-extras --all-packages --group dev - name: Run tests run: make old_version_tests diff --git a/.gitignore b/.gitignore index 2e9b92379..c0c4b3254 100644 --- a/.gitignore +++ b/.gitignore @@ -100,7 +100,8 @@ celerybeat.pid *.sage.py # Environments -.env +.python-version +.env* .venv env/ venv/ diff --git a/Makefile b/Makefile index 470d97c14..506f198a9 100644 --- a/Makefile +++ b/Makefile @@ -39,7 +39,8 @@ snapshots-create: uv run pytest --inline-snapshot=create .PHONY: old_version_tests -old_version_tests: +old_version_tests: + UV_PROJECT_ENVIRONMENT=.venv_39 uv sync --python 3.9 --all-extras --all-packages --group dev UV_PROJECT_ENVIRONMENT=.venv_39 uv run --python 3.9 -m pytest .PHONY: build-docs diff --git a/examples/realtime/app/README.md b/examples/realtime/app/README.md index cb5519a79..420134bba 100644 --- a/examples/realtime/app/README.md +++ b/examples/realtime/app/README.md @@ -29,14 +29,19 @@ To use the same UI with your own agents, edit `agent.py` and ensure get_starting 1. Click **Connect** to establish a realtime session 2. Audio capture starts automatically - just speak naturally 3. Click the **Mic On/Off** button to mute/unmute your microphone -4. Watch the conversation unfold in the left pane -5. Monitor raw events in the right pane (click to expand/collapse) -6. Click **Disconnect** when done +4. To send an image, enter an optional prompt and click **🖼️ Send Image** (select a file) +5. Watch the conversation unfold in the left pane (image thumbnails are shown) +6. Monitor raw events in the right pane (click to expand/collapse) +7. Click **Disconnect** when done ## Architecture - **Backend**: FastAPI server with WebSocket connections for real-time communication - **Session Management**: Each connection gets a unique session with the OpenAI Realtime API +- **Image Inputs**: The UI uploads images and the server forwards a + `conversation.item.create` event with `input_image` (plus optional `input_text`), + followed by `response.create` to start the model response. The messages pane + renders image bubbles for `input_image` content. 
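  For reference, a minimal sketch of the structured message the server builds for an uploaded image; it mirrors the image handler in `server.py` further down in this diff, and the `data_url` value is a placeholder:

```python
from agents.realtime.config import RealtimeUserInputMessage

data_url = "data:image/jpeg;base64,..."  # placeholder produced by the web UI
user_msg: RealtimeUserInputMessage = {
    "type": "message",
    "role": "user",
    "content": [
        {"type": "input_image", "image_url": data_url, "detail": "high"},
        {"type": "input_text", "text": "Please describe this image."},
    ],
}
# The server passes this to the realtime session (session.send_message), and the
# model responds once response.create is issued, as described above.
```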
- **Audio Processing**: 24kHz mono audio capture and playback - **Event Handling**: Full event stream processing with transcript generation - **Frontend**: Vanilla JavaScript with clean, responsive CSS diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py index 26c544dd2..d4ff47e80 100644 --- a/examples/realtime/app/server.py +++ b/examples/realtime/app/server.py @@ -12,6 +12,8 @@ from typing_extensions import assert_never from agents.realtime import RealtimeRunner, RealtimeSession, RealtimeSessionEvent +from agents.realtime.config import RealtimeUserInputMessage +from agents.realtime.model_inputs import RealtimeModelSendRawMessage # Import TwilioHandler class - handle both module and package use cases if TYPE_CHECKING: @@ -64,6 +66,34 @@ async def send_audio(self, session_id: str, audio_bytes: bytes): if session_id in self.active_sessions: await self.active_sessions[session_id].send_audio(audio_bytes) + async def send_client_event(self, session_id: str, event: dict[str, Any]): + """Send a raw client event to the underlying realtime model.""" + session = self.active_sessions.get(session_id) + if not session: + return + await session.model.send_event( + RealtimeModelSendRawMessage( + message={ + "type": event["type"], + "other_data": {k: v for k, v in event.items() if k != "type"}, + } + ) + ) + + async def send_user_message(self, session_id: str, message: RealtimeUserInputMessage): + """Send a structured user message via the higher-level API (supports input_image).""" + session = self.active_sessions.get(session_id) + if not session: + return + await session.send_message(message) # delegates to RealtimeModelSendUserInput path + + async def interrupt(self, session_id: str) -> None: + """Interrupt current model playback/response for a session.""" + session = self.active_sessions.get(session_id) + if not session: + return + await session.interrupt() + async def _process_events(self, session_id: str): try: session = self.active_sessions[session_id] @@ -101,7 +131,11 @@ async def _serialize_event(self, event: RealtimeSessionEvent) -> dict[str, Any]: elif event.type == "history_updated": base_event["history"] = [item.model_dump(mode="json") for item in event.history] elif event.type == "history_added": - pass + # Provide the added item so the UI can render incrementally. + try: + base_event["item"] = event.item.model_dump(mode="json") + except Exception: + base_event["item"] = None elif event.type == "guardrail_tripped": base_event["guardrail_results"] = [ {"name": result.guardrail.name} for result in event.guardrail_results @@ -134,6 +168,7 @@ async def lifespan(app: FastAPI): @app.websocket("/ws/{session_id}") async def websocket_endpoint(websocket: WebSocket, session_id: str): await manager.connect(websocket, session_id) + image_buffers: dict[str, dict[str, Any]] = {} try: while True: data = await websocket.receive_text() @@ -144,6 +179,124 @@ async def websocket_endpoint(websocket: WebSocket, session_id: str): int16_data = message["data"] audio_bytes = struct.pack(f"{len(int16_data)}h", *int16_data) await manager.send_audio(session_id, audio_bytes) + elif message["type"] == "image": + logger.info("Received image message from client (session %s).", session_id) + # Build a conversation.item.create with input_image (and optional input_text) + data_url = message.get("data_url") + prompt_text = message.get("text") or "Please describe this image." 
+ if data_url: + logger.info( + "Forwarding image (structured message) to Realtime API (len=%d).", + len(data_url), + ) + user_msg: RealtimeUserInputMessage = { + "type": "message", + "role": "user", + "content": ( + [ + {"type": "input_image", "image_url": data_url, "detail": "high"}, + {"type": "input_text", "text": prompt_text}, + ] + if prompt_text + else [ + {"type": "input_image", "image_url": data_url, "detail": "high"} + ] + ), + } + await manager.send_user_message(session_id, user_msg) + # Acknowledge to client UI + await websocket.send_text( + json.dumps( + { + "type": "client_info", + "info": "image_enqueued", + "size": len(data_url), + } + ) + ) + else: + await websocket.send_text( + json.dumps( + { + "type": "error", + "error": "No data_url for image message.", + } + ) + ) + elif message["type"] == "commit_audio": + # Force close the current input audio turn + await manager.send_client_event(session_id, {"type": "input_audio_buffer.commit"}) + elif message["type"] == "image_start": + img_id = str(message.get("id")) + image_buffers[img_id] = { + "text": message.get("text") or "Please describe this image.", + "chunks": [], + } + await websocket.send_text( + json.dumps({"type": "client_info", "info": "image_start_ack", "id": img_id}) + ) + elif message["type"] == "image_chunk": + img_id = str(message.get("id")) + chunk = message.get("chunk", "") + if img_id in image_buffers: + image_buffers[img_id]["chunks"].append(chunk) + if len(image_buffers[img_id]["chunks"]) % 10 == 0: + await websocket.send_text( + json.dumps( + { + "type": "client_info", + "info": "image_chunk_ack", + "id": img_id, + "count": len(image_buffers[img_id]["chunks"]), + } + ) + ) + elif message["type"] == "image_end": + img_id = str(message.get("id")) + buf = image_buffers.pop(img_id, None) + if buf is None: + await websocket.send_text( + json.dumps({"type": "error", "error": "Unknown image id for image_end."}) + ) + else: + data_url = "".join(buf["chunks"]) if buf["chunks"] else None + prompt_text = buf["text"] + if data_url: + logger.info( + "Forwarding chunked image (structured message) to Realtime API (len=%d).", + len(data_url), + ) + user_msg2: RealtimeUserInputMessage = { + "type": "message", + "role": "user", + "content": ( + [ + {"type": "input_image", "image_url": data_url, "detail": "high"}, + {"type": "input_text", "text": prompt_text}, + ] + if prompt_text + else [ + {"type": "input_image", "image_url": data_url, "detail": "high"} + ] + ), + } + await manager.send_user_message(session_id, user_msg2) + await websocket.send_text( + json.dumps( + { + "type": "client_info", + "info": "image_enqueued", + "id": img_id, + "size": len(data_url), + } + ) + ) + else: + await websocket.send_text( + json.dumps({"type": "error", "error": "Empty image."}) + ) + elif message["type"] == "interrupt": + await manager.interrupt(session_id) except WebSocketDisconnect: await manager.disconnect(session_id) @@ -160,4 +313,10 @@ async def read_index(): if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + uvicorn.run( + app, + host="0.0.0.0", + port=8000, + # Increased WebSocket frame size to comfortably handle image data URLs. 
+ ws_max_size=16 * 1024 * 1024, + ) diff --git a/examples/realtime/app/static/app.js b/examples/realtime/app/static/app.js index 3ec8fcc99..6858428c6 100644 --- a/examples/realtime/app/static/app.js +++ b/examples/realtime/app/static/app.js @@ -8,26 +8,33 @@ class RealtimeDemo { this.processor = null; this.stream = null; this.sessionId = this.generateSessionId(); - + // Audio playback queue this.audioQueue = []; this.isPlayingAudio = false; this.playbackAudioContext = null; this.currentAudioSource = null; - + this.currentAudioGain = null; // per-chunk gain for smooth fades + this.playbackFadeSec = 0.02; // ~20ms fade to reduce clicks + this.messageNodes = new Map(); // item_id -> DOM node + this.seenItemIds = new Set(); // item_id set for append-only syncing + this.initializeElements(); this.setupEventListeners(); } - + initializeElements() { this.connectBtn = document.getElementById('connectBtn'); this.muteBtn = document.getElementById('muteBtn'); + this.imageBtn = document.getElementById('imageBtn'); + this.imageInput = document.getElementById('imageInput'); + this.imagePrompt = document.getElementById('imagePrompt'); this.status = document.getElementById('status'); this.messagesContent = document.getElementById('messagesContent'); this.eventsContent = document.getElementById('eventsContent'); this.toolsContent = document.getElementById('toolsContent'); } - + setupEventListeners() { this.connectBtn.addEventListener('click', () => { if (this.isConnected) { @@ -36,52 +43,99 @@ class RealtimeDemo { this.connect(); } }); - + this.muteBtn.addEventListener('click', () => { this.toggleMute(); }); + + // Image upload + this.imageBtn.addEventListener('click', (e) => { + e.preventDefault(); + e.stopPropagation(); + console.log('Send Image clicked'); + // Programmatically open the hidden file input + this.imageInput.click(); + }); + + this.imageInput.addEventListener('change', async (e) => { + console.log('Image input change fired'); + const file = e.target.files && e.target.files[0]; + if (!file) return; + await this._handlePickedFile(file); + this.imageInput.value = ''; + }); + + this._handlePickedFile = async (file) => { + try { + const dataUrl = await this.prepareDataURL(file); + const promptText = (this.imagePrompt && this.imagePrompt.value) || ''; + // Send to server; server forwards to Realtime API. + // Use chunked frames to avoid WS frame limits. + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + console.log('Interrupting and sending image (chunked) to server WebSocket'); + // Stop any current audio locally and tell model to interrupt + this.stopAudioPlayback(); + this.ws.send(JSON.stringify({ type: 'interrupt' })); + const id = 'img_' + Math.random().toString(36).slice(2); + const CHUNK = 60_000; // ~60KB per frame + this.ws.send(JSON.stringify({ type: 'image_start', id, text: promptText })); + for (let i = 0; i < dataUrl.length; i += CHUNK) { + const chunk = dataUrl.slice(i, i + CHUNK); + this.ws.send(JSON.stringify({ type: 'image_chunk', id, chunk })); + } + this.ws.send(JSON.stringify({ type: 'image_end', id })); + } else { + console.warn('Not connected; image will not be sent. 
Click Connect first.'); + } + // Add to UI immediately for better feedback + console.log('Adding local user image bubble'); + this.addUserImageMessage(dataUrl, promptText); + } catch (err) { + console.error('Failed to process image:', err); + } + }; } - + generateSessionId() { return 'session_' + Math.random().toString(36).substr(2, 9); } - + async connect() { try { this.ws = new WebSocket(`ws://localhost:8000/ws/${this.sessionId}`); - + this.ws.onopen = () => { this.isConnected = true; this.updateConnectionUI(); this.startContinuousCapture(); }; - + this.ws.onmessage = (event) => { const data = JSON.parse(event.data); this.handleRealtimeEvent(data); }; - + this.ws.onclose = () => { this.isConnected = false; this.updateConnectionUI(); }; - + this.ws.onerror = (error) => { console.error('WebSocket error:', error); }; - + } catch (error) { console.error('Failed to connect:', error); } } - + disconnect() { if (this.ws) { this.ws.close(); } this.stopContinuousCapture(); } - + updateConnectionUI() { if (this.isConnected) { this.connectBtn.textContent = 'Disconnect'; @@ -97,12 +151,12 @@ class RealtimeDemo { this.muteBtn.disabled = true; } } - + toggleMute() { this.isMuted = !this.isMuted; this.updateMuteUI(); } - + updateMuteUI() { if (this.isMuted) { this.muteBtn.textContent = '🔇 Mic Off'; @@ -115,90 +169,128 @@ class RealtimeDemo { } } } - + + readFileAsDataURL(file) { + return new Promise((resolve, reject) => { + const reader = new FileReader(); + reader.onload = () => resolve(reader.result); + reader.onerror = reject; + reader.readAsDataURL(file); + }); + } + + async prepareDataURL(file) { + const original = await this.readFileAsDataURL(file); + try { + const img = new Image(); + img.decoding = 'async'; + const loaded = new Promise((res, rej) => { + img.onload = () => res(); + img.onerror = rej; + }); + img.src = original; + await loaded; + + const maxDim = 1024; + const maxSide = Math.max(img.width, img.height); + const scale = maxSide > maxDim ? (maxDim / maxSide) : 1; + const w = Math.max(1, Math.round(img.width * scale)); + const h = Math.max(1, Math.round(img.height * scale)); + + const canvas = document.createElement('canvas'); + canvas.width = w; canvas.height = h; + const ctx = canvas.getContext('2d'); + ctx.drawImage(img, 0, 0, w, h); + return canvas.toDataURL('image/jpeg', 0.85); + } catch (e) { + console.warn('Image resize failed; sending original', e); + return original; + } + } + async startContinuousCapture() { if (!this.isConnected || this.isCapturing) return; - + // Check if getUserMedia is available if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) { throw new Error('getUserMedia not available. 
Please use HTTPS or localhost.'); } - + try { - this.stream = await navigator.mediaDevices.getUserMedia({ + this.stream = await navigator.mediaDevices.getUserMedia({ audio: { sampleRate: 24000, channelCount: 1, echoCancellation: true, noiseSuppression: true - } + } }); - - this.audioContext = new AudioContext({ sampleRate: 24000 }); + + this.audioContext = new AudioContext({ sampleRate: 24000, latencyHint: 'interactive' }); const source = this.audioContext.createMediaStreamSource(this.stream); - + // Create a script processor to capture audio data this.processor = this.audioContext.createScriptProcessor(4096, 1, 1); source.connect(this.processor); this.processor.connect(this.audioContext.destination); - + this.processor.onaudioprocess = (event) => { if (!this.isMuted && this.ws && this.ws.readyState === WebSocket.OPEN) { const inputBuffer = event.inputBuffer.getChannelData(0); const int16Buffer = new Int16Array(inputBuffer.length); - + // Convert float32 to int16 for (let i = 0; i < inputBuffer.length; i++) { int16Buffer[i] = Math.max(-32768, Math.min(32767, inputBuffer[i] * 32768)); } - + this.ws.send(JSON.stringify({ type: 'audio', data: Array.from(int16Buffer) })); } }; - + this.isCapturing = true; this.updateMuteUI(); - + } catch (error) { console.error('Failed to start audio capture:', error); } } - + stopContinuousCapture() { if (!this.isCapturing) return; - + this.isCapturing = false; - + if (this.processor) { this.processor.disconnect(); this.processor = null; } - + if (this.audioContext) { this.audioContext.close(); this.audioContext = null; } - + if (this.stream) { this.stream.getTracks().forEach(track => track.stop()); this.stream = null; } - + this.updateMuteUI(); } - + handleRealtimeEvent(event) { // Add to raw events pane this.addRawEvent(event); - + // Add to tools panel if it's a tool or handoff event if (event.type === 'tool_start' || event.type === 'tool_end' || event.type === 'handoff') { this.addToolEvent(event); } - + // Handle specific event types switch (event.type) { case 'audio': @@ -207,115 +299,214 @@ class RealtimeDemo { case 'audio_interrupted': this.stopAudioPlayback(); break; + case 'input_audio_timeout_triggered': + // Ask server to commit the input buffer to expedite model response + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify({ type: 'commit_audio' })); + } + break; case 'history_updated': - this.updateMessagesFromHistory(event.history); + this.syncMissingFromHistory(event.history); + this.updateLastMessageFromHistory(event.history); + break; + case 'history_added': + // Append just the new item without clearing the thread. 
+ if (event.item) { + this.addMessageFromItem(event.item); + } break; } } - - - updateMessagesFromHistory(history) { - console.log('updateMessagesFromHistory called with:', history); - - // Clear all existing messages - this.messagesContent.innerHTML = ''; - - // Add messages from history - if (history && Array.isArray(history)) { - console.log('Processing history array with', history.length, 'items'); - history.forEach((item, index) => { - console.log(`History item ${index}:`, item); - if (item.type === 'message') { - const role = item.role; - let content = ''; - - console.log(`Message item - role: ${role}, content:`, item.content); - - if (item.content && Array.isArray(item.content)) { - // Extract text from content array - item.content.forEach(contentPart => { - console.log('Content part:', contentPart); - if (contentPart.type === 'text' && contentPart.text) { - content += contentPart.text; - } else if (contentPart.type === 'input_text' && contentPart.text) { - content += contentPart.text; - } else if (contentPart.type === 'input_audio' && contentPart.transcript) { - content += contentPart.transcript; - } else if (contentPart.type === 'audio' && contentPart.transcript) { - content += contentPart.transcript; - } - }); - } - - console.log(`Final content for ${role}:`, content); - - if (content.trim()) { - this.addMessage(role, content.trim()); - console.log(`Added message: ${role} - ${content.trim()}`); + updateLastMessageFromHistory(history) { + if (!history || !Array.isArray(history) || history.length === 0) return; + // Find the last message item in history + let last = null; + for (let i = history.length - 1; i >= 0; i--) { + const it = history[i]; + if (it && it.type === 'message') { last = it; break; } + } + if (!last) return; + const itemId = last.item_id; + + // Extract a text representation (for assistant transcript updates) + let text = ''; + if (Array.isArray(last.content)) { + for (const part of last.content) { + if (!part || typeof part !== 'object') continue; + if (part.type === 'text' && part.text) text += part.text; + else if (part.type === 'input_text' && part.text) text += part.text; + else if ((part.type === 'input_audio' || part.type === 'audio') && part.transcript) text += part.transcript; + } + } + + const node = this.messageNodes.get(itemId); + if (!node) { + // If we haven't rendered this item yet, append it now. + this.addMessageFromItem(last); + return; + } + + // Update only the text content of the bubble, preserving any images already present. + const bubble = node.querySelector('.message-bubble'); + if (bubble && text && text.trim()) { + // If there's an , keep it and only update the trailing caption/text node. 
+ const hasImg = !!bubble.querySelector('img'); + if (hasImg) { + // Ensure there is a caption div after the image + let cap = bubble.querySelector('.image-caption'); + if (!cap) { + cap = document.createElement('div'); + cap.className = 'image-caption'; + cap.style.marginTop = '0.5rem'; + bubble.appendChild(cap); + } + cap.textContent = text.trim(); + } else { + bubble.textContent = text.trim(); + } + this.scrollToBottom(); + } + } + + syncMissingFromHistory(history) { + if (!history || !Array.isArray(history)) return; + for (const item of history) { + if (!item || item.type !== 'message') continue; + const id = item.item_id; + if (!id) continue; + if (!this.seenItemIds.has(id)) { + this.addMessageFromItem(item); + } + } + } + + addMessageFromItem(item) { + try { + if (!item || item.type !== 'message') return; + const role = item.role; + let content = ''; + let imageUrls = []; + + if (Array.isArray(item.content)) { + for (const contentPart of item.content) { + if (!contentPart || typeof contentPart !== 'object') continue; + if (contentPart.type === 'text' && contentPart.text) { + content += contentPart.text; + } else if (contentPart.type === 'input_text' && contentPart.text) { + content += contentPart.text; + } else if (contentPart.type === 'input_audio' && contentPart.transcript) { + content += contentPart.transcript; + } else if (contentPart.type === 'audio' && contentPart.transcript) { + content += contentPart.transcript; + } else if (contentPart.type === 'input_image') { + const url = contentPart.image_url || contentPart.url; + if (typeof url === 'string' && url) imageUrls.push(url); } - } else { - console.log(`Skipping non-message item of type: ${item.type}`); } - }); - } else { - console.log('History is not an array or is null/undefined'); + } + + let node = null; + if (imageUrls.length > 0) { + for (const url of imageUrls) { + node = this.addImageMessage(role, url, content.trim()); + } + } else if (content && content.trim()) { + node = this.addMessage(role, content.trim()); + } + if (node && item.item_id) { + this.messageNodes.set(item.item_id, node); + this.seenItemIds.add(item.item_id); + } + } catch (e) { + console.error('Failed to add message from item:', e, item); } - - this.scrollToBottom(); } - + addMessage(type, content) { const messageDiv = document.createElement('div'); messageDiv.className = `message ${type}`; - + const bubbleDiv = document.createElement('div'); bubbleDiv.className = 'message-bubble'; bubbleDiv.textContent = content; - + messageDiv.appendChild(bubbleDiv); this.messagesContent.appendChild(messageDiv); this.scrollToBottom(); - + return messageDiv; } - + + addImageMessage(role, imageUrl, caption = '') { + const messageDiv = document.createElement('div'); + messageDiv.className = `message ${role}`; + + const bubbleDiv = document.createElement('div'); + bubbleDiv.className = 'message-bubble'; + + const img = document.createElement('img'); + img.src = imageUrl; + img.alt = 'Uploaded image'; + img.style.maxWidth = '220px'; + img.style.borderRadius = '8px'; + img.style.display = 'block'; + + bubbleDiv.appendChild(img); + if (caption) { + const cap = document.createElement('div'); + cap.textContent = caption; + cap.style.marginTop = '0.5rem'; + bubbleDiv.appendChild(cap); + } + + messageDiv.appendChild(bubbleDiv); + this.messagesContent.appendChild(messageDiv); + this.scrollToBottom(); + + return messageDiv; + } + + addUserImageMessage(imageUrl, caption = '') { + return this.addImageMessage('user', imageUrl, caption); + } + addRawEvent(event) { const eventDiv = 
document.createElement('div'); eventDiv.className = 'event'; - + const headerDiv = document.createElement('div'); headerDiv.className = 'event-header'; headerDiv.innerHTML = ` ${event.type} `; - + const contentDiv = document.createElement('div'); contentDiv.className = 'event-content collapsed'; contentDiv.textContent = JSON.stringify(event, null, 2); - + headerDiv.addEventListener('click', () => { const isCollapsed = contentDiv.classList.contains('collapsed'); contentDiv.classList.toggle('collapsed'); headerDiv.querySelector('span:last-child').textContent = isCollapsed ? '▲' : '▼'; }); - + eventDiv.appendChild(headerDiv); eventDiv.appendChild(contentDiv); this.eventsContent.appendChild(eventDiv); - + // Auto-scroll events pane this.eventsContent.scrollTop = this.eventsContent.scrollHeight; } - + addToolEvent(event) { const eventDiv = document.createElement('div'); eventDiv.className = 'event'; - + let title = ''; let description = ''; let eventClass = ''; - + if (event.type === 'handoff') { title = `🔄 Handoff`; description = `From ${event.from} to ${event.to}`; @@ -329,7 +520,7 @@ class RealtimeDemo { description = `${event.tool}: ${event.output || 'No output'}`; eventClass = 'tool'; } - + eventDiv.innerHTML = `
@@ -339,53 +530,58 @@ class RealtimeDemo { ${new Date().toLocaleTimeString()}
`; - + this.toolsContent.appendChild(eventDiv); - + // Auto-scroll tools pane this.toolsContent.scrollTop = this.toolsContent.scrollHeight; } - + async playAudio(audioBase64) { try { if (!audioBase64 || audioBase64.length === 0) { console.warn('Received empty audio data, skipping playback'); return; } - + // Add to queue this.audioQueue.push(audioBase64); - + // Start processing queue if not already playing if (!this.isPlayingAudio) { this.processAudioQueue(); } - + } catch (error) { console.error('Failed to play audio:', error); } } - + async processAudioQueue() { if (this.isPlayingAudio || this.audioQueue.length === 0) { return; } - + this.isPlayingAudio = true; - + // Initialize audio context if needed if (!this.playbackAudioContext) { - this.playbackAudioContext = new AudioContext({ sampleRate: 24000 }); + this.playbackAudioContext = new AudioContext({ sampleRate: 24000, latencyHint: 'interactive' }); } - + + // Ensure context is running (autoplay policies can suspend it) + if (this.playbackAudioContext.state === 'suspended') { + try { await this.playbackAudioContext.resume(); } catch {} + } + while (this.audioQueue.length > 0) { const audioBase64 = this.audioQueue.shift(); await this.playAudioChunk(audioBase64); } - + this.isPlayingAudio = false; } - + async playAudioChunk(audioBase64) { return new Promise((resolve, reject) => { try { @@ -395,67 +591,99 @@ class RealtimeDemo { for (let i = 0; i < binaryString.length; i++) { bytes[i] = binaryString.charCodeAt(i); } - + const int16Array = new Int16Array(bytes.buffer); - + if (int16Array.length === 0) { console.warn('Audio chunk has no samples, skipping'); resolve(); return; } - + const float32Array = new Float32Array(int16Array.length); - + // Convert int16 to float32 for (let i = 0; i < int16Array.length; i++) { float32Array[i] = int16Array[i] / 32768.0; } - + const audioBuffer = this.playbackAudioContext.createBuffer(1, float32Array.length, 24000); audioBuffer.getChannelData(0).set(float32Array); - + const source = this.playbackAudioContext.createBufferSource(); source.buffer = audioBuffer; - source.connect(this.playbackAudioContext.destination); - - // Store reference to current source + + // Per-chunk gain with short fade-in/out to avoid clicks + const gainNode = this.playbackAudioContext.createGain(); + const now = this.playbackAudioContext.currentTime; + const fade = Math.min(this.playbackFadeSec, Math.max(0.005, audioBuffer.duration / 8)); + try { + gainNode.gain.cancelScheduledValues(now); + gainNode.gain.setValueAtTime(0.0, now); + gainNode.gain.linearRampToValueAtTime(1.0, now + fade); + const endTime = now + audioBuffer.duration; + gainNode.gain.setValueAtTime(1.0, Math.max(now + fade, endTime - fade)); + gainNode.gain.linearRampToValueAtTime(0.0001, endTime); + } catch {} + + source.connect(gainNode); + gainNode.connect(this.playbackAudioContext.destination); + + // Store references to allow smooth stop on interruption this.currentAudioSource = source; - + this.currentAudioGain = gainNode; + source.onended = () => { this.currentAudioSource = null; + this.currentAudioGain = null; resolve(); }; source.start(); - + } catch (error) { console.error('Failed to play audio chunk:', error); reject(error); } }); } - + stopAudioPlayback() { console.log('Stopping audio playback due to interruption'); - - // Stop current audio source if playing - if (this.currentAudioSource) { + + // Smoothly ramp down before stopping to avoid clicks + if (this.currentAudioSource && this.playbackAudioContext) { try { - this.currentAudioSource.stop(); - 
this.currentAudioSource = null; + const now = this.playbackAudioContext.currentTime; + const fade = Math.max(0.01, this.playbackFadeSec); + if (this.currentAudioGain) { + try { + this.currentAudioGain.gain.cancelScheduledValues(now); + // Capture current value to ramp from it + const current = this.currentAudioGain.gain.value ?? 1.0; + this.currentAudioGain.gain.setValueAtTime(current, now); + this.currentAudioGain.gain.linearRampToValueAtTime(0.0001, now + fade); + } catch {} + } + // Stop after the fade completes + setTimeout(() => { + try { this.currentAudioSource && this.currentAudioSource.stop(); } catch {} + this.currentAudioSource = null; + this.currentAudioGain = null; + }, Math.ceil(fade * 1000)); } catch (error) { console.error('Error stopping audio source:', error); } } - + // Clear the audio queue this.audioQueue = []; - + // Reset playback state this.isPlayingAudio = false; - + console.log('Audio playback stopped and queue cleared'); } - + scrollToBottom() { this.messagesContent.scrollTop = this.messagesContent.scrollHeight; } @@ -464,4 +692,4 @@ class RealtimeDemo { // Initialize the demo when the page loads document.addEventListener('DOMContentLoaded', () => { new RealtimeDemo(); -}); \ No newline at end of file +}); diff --git a/examples/realtime/app/static/favicon.ico b/examples/realtime/app/static/favicon.ico new file mode 100644 index 000000000..e69de29bb diff --git a/examples/realtime/app/static/index.html b/examples/realtime/app/static/index.html index fbd0de46d..aacefbffb 100644 --- a/examples/realtime/app/static/index.html +++ b/examples/realtime/app/static/index.html @@ -204,6 +204,7 @@ background: #f8f9fa; display: flex; gap: 0.5rem; + align-items: center; } .mute-btn { @@ -265,6 +266,9 @@

Realtime Demo

+ + + Disconnected
@@ -292,4 +296,4 @@

Realtime Demo

- \ No newline at end of file + diff --git a/examples/realtime/cli/demo.py b/examples/realtime/cli/demo.py index e372e3ef5..4a7172e30 100644 --- a/examples/realtime/cli/demo.py +++ b/examples/realtime/cli/demo.py @@ -8,13 +8,23 @@ import sounddevice as sd from agents import function_tool -from agents.realtime import RealtimeAgent, RealtimeRunner, RealtimeSession, RealtimeSessionEvent +from agents.realtime import ( + RealtimeAgent, + RealtimePlaybackTracker, + RealtimeRunner, + RealtimeSession, + RealtimeSessionEvent, +) +from agents.realtime.model import RealtimeModelConfig # Audio configuration -CHUNK_LENGTH_S = 0.05 # 50ms +CHUNK_LENGTH_S = 0.04 # 40ms aligns with realtime defaults SAMPLE_RATE = 24000 FORMAT = np.int16 CHANNELS = 1 +ENERGY_THRESHOLD = 0.015 # RMS threshold for barge‑in while assistant is speaking +PREBUFFER_CHUNKS = 3 # initial jitter buffer (~120ms with 40ms chunks) +FADE_OUT_MS = 12 # short fade to avoid clicks when interrupting # Set up logging for OpenAI agents SDK # logging.basicConfig( @@ -49,29 +59,91 @@ def __init__(self) -> None: self.audio_player: sd.OutputStream | None = None self.recording = False + # Playback tracker lets the model know our real playback progress + self.playback_tracker = RealtimePlaybackTracker() + # Audio output state for callback system - self.output_queue: queue.Queue[Any] = queue.Queue(maxsize=10) # Buffer more chunks + # Store tuples: (samples_np, item_id, content_index) + # Use an unbounded queue to avoid drops that sound like skipped words. + self.output_queue: queue.Queue[Any] = queue.Queue(maxsize=0) self.interrupt_event = threading.Event() - self.current_audio_chunk: np.ndarray[Any, np.dtype[Any]] | None = None + self.current_audio_chunk: tuple[np.ndarray[Any, np.dtype[Any]], str, int] | None = None self.chunk_position = 0 + self.bytes_per_sample = np.dtype(FORMAT).itemsize + + # Jitter buffer and fade-out state + self.prebuffering = True + self.prebuffer_target_chunks = PREBUFFER_CHUNKS + self.fading = False + self.fade_total_samples = 0 + self.fade_done_samples = 0 + self.fade_samples = int(SAMPLE_RATE * (FADE_OUT_MS / 1000.0)) def _output_callback(self, outdata, frames: int, time, status) -> None: """Callback for audio output - handles continuous audio stream from server.""" if status: print(f"Output callback status: {status}") - # Check if we should clear the queue due to interrupt + # Handle interruption with a short fade-out to prevent clicks. if self.interrupt_event.is_set(): - # Clear the queue and current chunk state - while not self.output_queue.empty(): - try: - self.output_queue.get_nowait() - except queue.Empty: - break - self.current_audio_chunk = None - self.chunk_position = 0 - self.interrupt_event.clear() outdata.fill(0) + if self.current_audio_chunk is None: + # Nothing to fade, just flush everything and reset. 
+ while not self.output_queue.empty(): + try: + self.output_queue.get_nowait() + except queue.Empty: + break + self.prebuffering = True + self.interrupt_event.clear() + return + + # Prepare fade parameters + if not self.fading: + self.fading = True + self.fade_done_samples = 0 + # Remaining samples in the current chunk + remaining_in_chunk = len(self.current_audio_chunk[0]) - self.chunk_position + self.fade_total_samples = min(self.fade_samples, max(0, remaining_in_chunk)) + + samples, item_id, content_index = self.current_audio_chunk + samples_filled = 0 + while samples_filled < len(outdata) and self.fade_done_samples < self.fade_total_samples: + remaining_output = len(outdata) - samples_filled + remaining_fade = self.fade_total_samples - self.fade_done_samples + n = min(remaining_output, remaining_fade) + + src = samples[self.chunk_position : self.chunk_position + n].astype(np.float32) + # Linear ramp from current level down to 0 across remaining fade samples + idx = np.arange(self.fade_done_samples, self.fade_done_samples + n, dtype=np.float32) + gain = 1.0 - (idx / float(self.fade_total_samples)) + ramped = np.clip(src * gain, -32768.0, 32767.0).astype(np.int16) + outdata[samples_filled : samples_filled + n, 0] = ramped + + # Optionally report played bytes (ramped) to playback tracker + try: + self.playback_tracker.on_play_bytes( + item_id=item_id, item_content_index=content_index, bytes=ramped.tobytes() + ) + except Exception: + pass + + samples_filled += n + self.chunk_position += n + self.fade_done_samples += n + + # If fade completed, flush the remaining audio and reset state + if self.fade_done_samples >= self.fade_total_samples: + self.current_audio_chunk = None + self.chunk_position = 0 + while not self.output_queue.empty(): + try: + self.output_queue.get_nowait() + except queue.Empty: + break + self.fading = False + self.prebuffering = True + self.interrupt_event.clear() return # Fill output buffer from queue and current chunk @@ -82,6 +154,10 @@ def _output_callback(self, outdata, frames: int, time, status) -> None: # If we don't have a current chunk, try to get one from queue if self.current_audio_chunk is None: try: + # Respect a small jitter buffer before starting playback + if self.prebuffering and self.output_queue.qsize() < self.prebuffer_target_chunks: + break + self.prebuffering = False self.current_audio_chunk = self.output_queue.get_nowait() self.chunk_position = 0 except queue.Empty: @@ -92,20 +168,29 @@ def _output_callback(self, outdata, frames: int, time, status) -> None: # Copy data from current chunk to output buffer remaining_output = len(outdata) - samples_filled - remaining_chunk = len(self.current_audio_chunk) - self.chunk_position + samples, item_id, content_index = self.current_audio_chunk + remaining_chunk = len(samples) - self.chunk_position samples_to_copy = min(remaining_output, remaining_chunk) if samples_to_copy > 0: - chunk_data = self.current_audio_chunk[ - self.chunk_position : self.chunk_position + samples_to_copy - ] + chunk_data = samples[self.chunk_position : self.chunk_position + samples_to_copy] # More efficient: direct assignment for mono audio instead of reshape outdata[samples_filled : samples_filled + samples_to_copy, 0] = chunk_data samples_filled += samples_to_copy self.chunk_position += samples_to_copy + # Inform playback tracker about played bytes + try: + self.playback_tracker.on_play_bytes( + item_id=item_id, + item_content_index=content_index, + bytes=chunk_data.tobytes(), + ) + except Exception: + pass + # If we've used up the 
entire chunk, reset for next iteration - if self.chunk_position >= len(self.current_audio_chunk): + if self.chunk_position >= len(samples): self.current_audio_chunk = None self.chunk_position = 0 @@ -125,7 +210,18 @@ async def run(self) -> None: try: runner = RealtimeRunner(agent) - async with await runner.run() as session: + # Attach playback tracker and enable server‑side interruptions + auto response. + model_config: RealtimeModelConfig = { + "playback_tracker": self.playback_tracker, + "initial_model_settings": { + "turn_detection": { + "type": "semantic_vad", + "interrupt_response": True, + "create_response": True, + }, + }, + } + async with await runner.run(model_config=model_config) as session: self.session = session print("Connected. Starting audio recording...") @@ -170,6 +266,14 @@ async def capture_audio(self) -> None: read_size = int(SAMPLE_RATE * CHUNK_LENGTH_S) try: + # Simple energy-based barge-in: if user speaks while audio is playing, interrupt. + def rms_energy(samples: np.ndarray[Any, np.dtype[Any]]) -> float: + if samples.size == 0: + return 0.0 + # Normalize int16 to [-1, 1] + x = samples.astype(np.float32) / 32768.0 + return float(np.sqrt(np.mean(x * x))) + while self.recording: # Check if there's enough data to read if self.audio_stream.read_available < read_size: @@ -182,8 +286,19 @@ async def capture_audio(self) -> None: # Convert numpy array to bytes audio_bytes = data.tobytes() - # Send audio to session - await self.session.send_audio(audio_bytes) + # Smart barge‑in: if assistant audio is playing, send only if mic has speech. + assistant_playing = ( + self.current_audio_chunk is not None or not self.output_queue.empty() + ) + if assistant_playing: + # Compute RMS energy to detect speech while assistant is talking + samples = data.reshape(-1) + if rms_energy(samples) >= ENERGY_THRESHOLD: + # Locally flush queued assistant audio for snappier interruption. + self.interrupt_event.set() + await self.session.send_audio(audio_bytes) + else: + await self.session.send_audio(audio_bytes) # Yield control back to event loop await asyncio.sleep(0) @@ -212,23 +327,14 @@ async def _on_event(self, event: RealtimeSessionEvent) -> None: elif event.type == "audio_end": print("Audio ended") elif event.type == "audio": - # Enqueue audio for callback-based playback + # Enqueue audio for callback-based playback with metadata np_audio = np.frombuffer(event.audio.data, dtype=np.int16) - try: - self.output_queue.put_nowait(np_audio) - except queue.Full: - # Queue is full - only drop if we have significant backlog - # This prevents aggressive dropping that could cause choppiness - if self.output_queue.qsize() > 8: # Keep some buffer - try: - self.output_queue.get_nowait() - self.output_queue.put_nowait(np_audio) - except queue.Empty: - pass - # If queue isn't too full, just skip this chunk to avoid blocking + # Non-blocking put; queue is unbounded, so drops won’t occur. + self.output_queue.put_nowait((np_audio, event.item_id, event.content_index)) elif event.type == "audio_interrupted": print("Audio interrupted") - # Signal the output callback to clear its queue and state + # Begin graceful fade + flush in the audio callback and rebuild jitter buffer. 
+ self.prebuffering = True self.interrupt_event.set() elif event.type == "error": print(f"Error: {event.error}") @@ -237,7 +343,7 @@ async def _on_event(self, event: RealtimeSessionEvent) -> None: elif event.type == "history_added": pass # Skip these frequent events elif event.type == "raw_model_event": - print(f"Raw model event: {_truncate_str(str(event.data), 50)}") + print(f"Raw model event: {_truncate_str(str(event.data), 200)}") else: print(f"Unknown event type: {event.type}") except Exception as e: diff --git a/pyproject.toml b/pyproject.toml index fb8ac4fb3..a026479a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires-python = ">=3.9" license = "MIT" authors = [{ name = "OpenAI", email = "support@openai.com" }] dependencies = [ - "openai>=1.104.1,<2", + "openai>=1.107.1,<2", "pydantic>=2.10, <3", "griffe>=1.5.6, <2", "typing-extensions>=4.12.2, <5", diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index 1af1a0bae..a574e48ea 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -369,9 +369,9 @@ def convert_message_to_openai( if message.role != "assistant": raise ModelBehaviorError(f"Unsupported role: {message.role}") - tool_calls: list[ - ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall - ] | None = ( + tool_calls: ( + list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] | None + ) = ( [LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls] if message.tool_calls else None diff --git a/src/agents/realtime/_util.py b/src/agents/realtime/_util.py index c8926edfb..52a3483e9 100644 --- a/src/agents/realtime/_util.py +++ b/src/agents/realtime/_util.py @@ -4,6 +4,6 @@ def calculate_audio_length_ms(format: RealtimeAudioFormat | None, audio_bytes: bytes) -> float: - if format and format.startswith("g711"): + if format and isinstance(format, str) and format.startswith("g711"): return (len(audio_bytes) / 8000) * 1000 return (len(audio_bytes) / 24 / 2) * 1000 diff --git a/src/agents/realtime/agent.py b/src/agents/realtime/agent.py index 29483ac27..c04053db4 100644 --- a/src/agents/realtime/agent.py +++ b/src/agents/realtime/agent.py @@ -6,6 +6,8 @@ from dataclasses import dataclass, field from typing import Any, Callable, Generic, cast +from agents.prompts import Prompt + from ..agent import AgentBase from ..guardrail import OutputGuardrail from ..handoffs import Handoff @@ -55,6 +57,11 @@ class RealtimeAgent(AgentBase, Generic[TContext]): return a string. """ + prompt: Prompt | None = None + """A prompt object. Prompts allow you to dynamically configure the instructions, tools + and other config for an agent outside of your code. Only usable with OpenAI models. 
+ """ + handoffs: list[RealtimeAgent[Any] | Handoff[TContext, RealtimeAgent[Any]]] = field( default_factory=list ) diff --git a/src/agents/realtime/audio_formats.py b/src/agents/realtime/audio_formats.py new file mode 100644 index 000000000..d9757d244 --- /dev/null +++ b/src/agents/realtime/audio_formats.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from openai.types.realtime.realtime_audio_formats import ( + AudioPCM, + AudioPCMA, + AudioPCMU, + RealtimeAudioFormats, +) + +from ..logger import logger + + +def to_realtime_audio_format( + input_audio_format: str | RealtimeAudioFormats | None, +) -> RealtimeAudioFormats | None: + format: RealtimeAudioFormats | None = None + if input_audio_format is not None: + if isinstance(input_audio_format, str): + if input_audio_format in ["pcm16", "audio/pcm", "pcm"]: + format = AudioPCM(type="audio/pcm", rate=24000) + elif input_audio_format in ["g711_ulaw", "audio/pcmu", "pcmu"]: + format = AudioPCMU(type="audio/pcmu") + elif input_audio_format in ["g711_alaw", "audio/pcma", "pcma"]: + format = AudioPCMA(type="audio/pcma") + else: + logger.debug(f"Unknown input_audio_format: {input_audio_format}") + else: + format = input_audio_format + return format diff --git a/src/agents/realtime/config.py b/src/agents/realtime/config.py index 36254012b..8b70c872f 100644 --- a/src/agents/realtime/config.py +++ b/src/agents/realtime/config.py @@ -6,8 +6,13 @@ Union, ) +from openai.types.realtime.realtime_audio_formats import ( + RealtimeAudioFormats as OpenAIRealtimeAudioFormats, +) from typing_extensions import NotRequired, TypeAlias, TypedDict +from agents.prompts import Prompt + from ..guardrail import OutputGuardrail from ..handoffs import Handoff from ..model_settings import ToolChoice @@ -15,6 +20,8 @@ RealtimeModelName: TypeAlias = Union[ Literal[ + "gpt-realtime", + "gpt-realtime-2025-08-28", "gpt-4o-realtime-preview", "gpt-4o-mini-realtime-preview", "gpt-4o-realtime-preview-2025-06-03", @@ -91,6 +98,9 @@ class RealtimeSessionModelSettings(TypedDict): instructions: NotRequired[str] """System instructions for the model.""" + prompt: NotRequired[Prompt] + """The prompt to use for the model.""" + modalities: NotRequired[list[Literal["text", "audio"]]] """The modalities the model should support.""" @@ -100,10 +110,10 @@ class RealtimeSessionModelSettings(TypedDict): speed: NotRequired[float] """The speed of the model's responses.""" - input_audio_format: NotRequired[RealtimeAudioFormat] + input_audio_format: NotRequired[RealtimeAudioFormat | OpenAIRealtimeAudioFormats] """The format for input audio streams.""" - output_audio_format: NotRequired[RealtimeAudioFormat] + output_audio_format: NotRequired[RealtimeAudioFormat | OpenAIRealtimeAudioFormats] """The format for output audio streams.""" input_audio_transcription: NotRequired[RealtimeInputAudioTranscriptionConfig] @@ -177,6 +187,14 @@ class RealtimeUserInputText(TypedDict): """The text content from the user.""" +class RealtimeUserInputImage(TypedDict, total=False): + """An image input from the user (Realtime).""" + + type: Literal["input_image"] + image_url: str + detail: NotRequired[Literal["auto", "low", "high"] | str] + + class RealtimeUserInputMessage(TypedDict): """A message input from the user.""" @@ -186,8 +204,8 @@ class RealtimeUserInputMessage(TypedDict): role: Literal["user"] """The role identifier for user messages.""" - content: list[RealtimeUserInputText] - """List of text content items in the message.""" + content: list[RealtimeUserInputText | RealtimeUserInputImage] + """List of 
content items (text and image) in the message.""" RealtimeUserInput: TypeAlias = Union[str, RealtimeUserInputMessage] diff --git a/src/agents/realtime/items.py b/src/agents/realtime/items.py index f8a288145..58106fad8 100644 --- a/src/agents/realtime/items.py +++ b/src/agents/realtime/items.py @@ -34,6 +34,22 @@ class InputAudio(BaseModel): model_config = ConfigDict(extra="allow") +class InputImage(BaseModel): + """Image input content for realtime messages.""" + + type: Literal["input_image"] = "input_image" + """The type identifier for image input.""" + + image_url: str | None = None + """Data/remote URL string (data:... or https:...).""" + + detail: str | None = None + """Optional detail hint (e.g., 'auto', 'high', 'low').""" + + # Allow extra data (e.g., `detail`) + model_config = ConfigDict(extra="allow") + + class AssistantText(BaseModel): """Text content from the assistant in realtime responses.""" @@ -100,7 +116,7 @@ class UserMessageItem(BaseModel): role: Literal["user"] = "user" """The role identifier for user messages.""" - content: list[Annotated[InputText | InputAudio, Field(discriminator="type")]] + content: list[Annotated[InputText | InputAudio | InputImage, Field(discriminator="type")]] """List of content items, can be text or audio.""" # Allow extra data diff --git a/src/agents/realtime/model_inputs.py b/src/agents/realtime/model_inputs.py index df09e6697..9d7ab143d 100644 --- a/src/agents/realtime/model_inputs.py +++ b/src/agents/realtime/model_inputs.py @@ -24,12 +24,26 @@ class RealtimeModelInputTextContent(TypedDict): text: str +class RealtimeModelInputImageContent(TypedDict, total=False): + """An image to be sent to the model. + + The Realtime API expects `image_url` to be a string data/remote URL. + """ + + type: Literal["input_image"] + image_url: str + """String URL (data:... 
or https:...).""" + + detail: NotRequired[str] + """Optional detail hint such as 'high', 'low', or 'auto'.""" + + class RealtimeModelUserInputMessage(TypedDict): """A message to be sent to the model.""" type: Literal["message"] role: Literal["user"] - content: list[RealtimeModelInputTextContent] + content: list[RealtimeModelInputTextContent | RealtimeModelInputImageContent] RealtimeModelUserInput: TypeAlias = Union[str, RealtimeModelUserInputMessage] diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py index b9048a1ec..4d6cf398c 100644 --- a/src/agents/realtime/openai_realtime.py +++ b/src/agents/realtime/openai_realtime.py @@ -5,59 +5,87 @@ import inspect import json import os +from collections.abc import Mapping from datetime import datetime -from typing import Annotated, Any, Callable, Literal, Union +from typing import Annotated, Any, Callable, Literal, Union, cast import pydantic import websockets -from openai.types.beta.realtime.conversation_item import ( +from openai.types.realtime import realtime_audio_config as _rt_audio_config +from openai.types.realtime.conversation_item import ( ConversationItem, ConversationItem as OpenAIConversationItem, ) -from openai.types.beta.realtime.conversation_item_content import ( - ConversationItemContent as OpenAIConversationItemContent, -) -from openai.types.beta.realtime.conversation_item_create_event import ( +from openai.types.realtime.conversation_item_create_event import ( ConversationItemCreateEvent as OpenAIConversationItemCreateEvent, ) -from openai.types.beta.realtime.conversation_item_retrieve_event import ( +from openai.types.realtime.conversation_item_retrieve_event import ( ConversationItemRetrieveEvent as OpenAIConversationItemRetrieveEvent, ) -from openai.types.beta.realtime.conversation_item_truncate_event import ( +from openai.types.realtime.conversation_item_truncate_event import ( ConversationItemTruncateEvent as OpenAIConversationItemTruncateEvent, ) -from openai.types.beta.realtime.input_audio_buffer_append_event import ( +from openai.types.realtime.input_audio_buffer_append_event import ( InputAudioBufferAppendEvent as OpenAIInputAudioBufferAppendEvent, ) -from openai.types.beta.realtime.input_audio_buffer_commit_event import ( +from openai.types.realtime.input_audio_buffer_commit_event import ( InputAudioBufferCommitEvent as OpenAIInputAudioBufferCommitEvent, ) -from openai.types.beta.realtime.realtime_client_event import ( +from openai.types.realtime.realtime_audio_formats import ( + AudioPCM, + AudioPCMA, + AudioPCMU, +) +from openai.types.realtime.realtime_client_event import ( RealtimeClientEvent as OpenAIRealtimeClientEvent, ) -from openai.types.beta.realtime.realtime_server_event import ( +from openai.types.realtime.realtime_conversation_item_assistant_message import ( + RealtimeConversationItemAssistantMessage, +) +from openai.types.realtime.realtime_conversation_item_function_call_output import ( + RealtimeConversationItemFunctionCallOutput, +) +from openai.types.realtime.realtime_conversation_item_system_message import ( + RealtimeConversationItemSystemMessage, +) +from openai.types.realtime.realtime_conversation_item_user_message import ( + Content, + RealtimeConversationItemUserMessage, +) +from openai.types.realtime.realtime_function_tool import ( + RealtimeFunctionTool as OpenAISessionFunction, +) +from openai.types.realtime.realtime_server_event import ( RealtimeServerEvent as OpenAIRealtimeServerEvent, ) -from openai.types.beta.realtime.response_audio_delta_event import 
ResponseAudioDeltaEvent -from openai.types.beta.realtime.response_cancel_event import ( +from openai.types.realtime.realtime_session_create_request import ( + RealtimeSessionCreateRequest as OpenAISessionCreateRequest, +) +from openai.types.realtime.realtime_tracing_config import ( + TracingConfiguration as OpenAITracingConfiguration, +) +from openai.types.realtime.realtime_transcription_session_create_request import ( + RealtimeTranscriptionSessionCreateRequest as OpenAIRealtimeTranscriptionSessionCreateRequest, +) +from openai.types.realtime.response_audio_delta_event import ResponseAudioDeltaEvent +from openai.types.realtime.response_cancel_event import ( ResponseCancelEvent as OpenAIResponseCancelEvent, ) -from openai.types.beta.realtime.response_create_event import ( +from openai.types.realtime.response_create_event import ( ResponseCreateEvent as OpenAIResponseCreateEvent, ) -from openai.types.beta.realtime.session_update_event import ( - Session as OpenAISessionObject, - SessionTool as OpenAISessionTool, - SessionTracing as OpenAISessionTracing, - SessionTracingTracingConfiguration as OpenAISessionTracingConfiguration, +from openai.types.realtime.session_update_event import ( SessionUpdateEvent as OpenAISessionUpdateEvent, ) -from pydantic import BaseModel, Field, TypeAdapter +from openai.types.responses.response_prompt import ResponsePrompt +from pydantic import Field, TypeAdapter from typing_extensions import assert_never from websockets.asyncio.client import ClientConnection from agents.handoffs import Handoff +from agents.prompts import Prompt from agents.realtime._default_tracker import ModelAudioTracker +from agents.realtime.audio_formats import to_realtime_audio_format from agents.tool import FunctionTool, Tool from agents.util._types import MaybeAwaitable @@ -103,17 +131,23 @@ RealtimeModelSendUserInput, ) +# Avoid direct imports of non-exported names by referencing via module +OpenAIRealtimeAudioConfig = _rt_audio_config.RealtimeAudioConfig +OpenAIRealtimeAudioInput = _rt_audio_config.RealtimeAudioConfigInput # type: ignore[attr-defined] +OpenAIRealtimeAudioOutput = _rt_audio_config.RealtimeAudioConfigOutput # type: ignore[attr-defined] + + _USER_AGENT = f"Agents/Python {__version__}" DEFAULT_MODEL_SETTINGS: RealtimeSessionModelSettings = { "voice": "ash", - "modalities": ["text", "audio"], + "modalities": ["audio"], "input_audio_format": "pcm16", "output_audio_format": "pcm16", "input_audio_transcription": { "model": "gpt-4o-mini-transcribe", }, - "turn_detection": {"type": "semantic_vad"}, + "turn_detection": {"type": "semantic_vad", "interrupt_response": True}, } @@ -129,19 +163,8 @@ async def get_api_key(key: str | Callable[[], MaybeAwaitable[str]] | None) -> st return os.getenv("OPENAI_API_KEY") -class _InputAudioBufferTimeoutTriggeredEvent(BaseModel): - type: Literal["input_audio_buffer.timeout_triggered"] - event_id: str - audio_start_ms: int - audio_end_ms: int - item_id: str - - AllRealtimeServerEvents = Annotated[ - Union[ - OpenAIRealtimeServerEvent, - _InputAudioBufferTimeoutTriggeredEvent, - ], + Union[OpenAIRealtimeServerEvent,], Field(discriminator="type"), ] @@ -155,11 +178,16 @@ def get_server_event_type_adapter() -> TypeAdapter[AllRealtimeServerEvents]: return ServerEventTypeAdapter +# Note: Avoid a module-level union alias for Python 3.9 compatibility. +# Using a union at runtime (e.g., A | B) in a type alias triggers evaluation +# during import on 3.9. We instead inline the union in annotations below. 
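# A minimal, self-contained sketch of the Python 3.9 constraint described in the
# comment above; the names RequestA, RequestB, SessionRequest, and normalize are
# hypothetical and only illustrate the pattern.
from __future__ import annotations  # defers evaluation of annotations only

from typing import Union


class RequestA: ...


class RequestB: ...


# A module-level alias is executed at import time, so on Python 3.9 the next line
# would raise TypeError (PEP 604 `|` unions need Python 3.10+ at runtime):
#     SessionRequest = RequestA | RequestB
# Spelling the alias with typing.Union keeps 3.9 working:
SessionRequest = Union[RequestA, RequestB]


# Inlining the union in an annotation is also fine on 3.9: with the __future__
# import, annotations are stored as strings and never evaluated at runtime.
def normalize(session: RequestA | RequestB) -> SessionRequest:
    return session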
+ + class OpenAIRealtimeWebSocketModel(RealtimeModel): """A model that uses OpenAI's WebSocket API.""" def __init__(self) -> None: - self.model = "gpt-4o-realtime-preview" # Default model + self.model = "gpt-realtime" # Default model self._websocket: ClientConnection | None = None self._websocket_task: asyncio.Task[None] | None = None self._listeners: list[RealtimeModelListener] = [] @@ -168,7 +196,7 @@ def __init__(self) -> None: self._ongoing_response: bool = False self._tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None = None self._playback_tracker: RealtimePlaybackTracker | None = None - self._created_session: OpenAISessionObject | None = None + self._created_session: OpenAISessionCreateRequest | None = None self._server_event_type_adapter = get_server_event_type_adapter() async def connect(self, options: RealtimeModelConfig) -> None: @@ -199,12 +227,7 @@ async def connect(self, options: RealtimeModelConfig) -> None: if not api_key: raise UserError("API key is required but was not provided.") - headers.update( - { - "Authorization": f"Bearer {api_key}", - "OpenAI-Beta": "realtime=v1", - } - ) + headers.update({"Authorization": f"Bearer {api_key}"}) self._websocket = await websockets.connect( url, user_agent_header=_USER_AGENT, @@ -222,7 +245,11 @@ async def _send_tracing_config( converted_tracing_config = _ConversionHelper.convert_tracing_config(tracing_config) await self._send_raw_message( OpenAISessionUpdateEvent( - session=OpenAISessionObject(tracing=converted_tracing_config), + session=OpenAISessionCreateRequest( + model=self.model, + type="realtime", + tracing=converted_tracing_config, + ), type="session.update", ) ) @@ -304,8 +331,8 @@ async def send_event(self, event: RealtimeModelSendEvent) -> None: async def _send_raw_message(self, event: OpenAIRealtimeClientEvent) -> None: """Send a raw message to the model.""" assert self._websocket is not None, "Not connected" - - await self._websocket.send(event.model_dump_json(exclude_none=True, exclude_unset=True)) + payload = event.model_dump_json(exclude_none=True, exclude_unset=True) + await self._websocket.send(payload) async def _send_user_input(self, event: RealtimeModelSendUserInput) -> None: converted = _ConversionHelper.convert_user_input_to_item_create(event) @@ -398,10 +425,13 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None: f"content index: {current_item_content_index}" ) + session = self._created_session automatic_response_cancellation_enabled = ( - self._created_session - and self._created_session.turn_detection - and self._created_session.turn_detection.interrupt_response + session + and session.audio is not None + and session.audio.input is not None + and session.audio.input.turn_detection is not None + and session.audio.input.turn_detection.interrupt_response is True, ) if not automatic_response_cancellation_enabled: await self._cancel_response() @@ -495,40 +525,103 @@ async def _cancel_response(self) -> None: async def _handle_ws_event(self, event: dict[str, Any]): await self._emit_event(RealtimeModelRawServerEvent(data=event)) + # The public interface definedo on this Agents SDK side (e.g., RealtimeMessageItem) + # must be the same even after the GA migration, so this part does the conversion + if isinstance(event, dict) and event.get("type") in ( + "response.output_item.added", + "response.output_item.done", + ): + item = event.get("item") + if isinstance(item, dict) and item.get("type") == "message": + raw_content = item.get("content") or [] + converted_content: 
list[dict[str, Any]] = [] + for part in raw_content: + if not isinstance(part, dict): + continue + if part.get("type") == "audio": + converted_content.append( + { + "type": "audio", + "audio": part.get("audio"), + "transcript": part.get("transcript"), + } + ) + elif part.get("type") == "text": + converted_content.append({"type": "text", "text": part.get("text")}) + status = item.get("status") + if status not in ("in_progress", "completed", "incomplete"): + is_done = event.get("type") == "response.output_item.done" + status = "completed" if is_done else "in_progress" + # Explicitly type the adapter for mypy + type_adapter: TypeAdapter[RealtimeMessageItem] = TypeAdapter(RealtimeMessageItem) + message_item: RealtimeMessageItem = type_adapter.validate_python( + { + "item_id": item.get("id", ""), + "type": "message", + "role": item.get("role", "assistant"), + "content": converted_content, + "status": status, + } + ) + await self._emit_event(RealtimeModelItemUpdatedEvent(item=message_item)) + return + try: if "previous_item_id" in event and event["previous_item_id"] is None: event["previous_item_id"] = "" # TODO (rm) remove parsed: AllRealtimeServerEvents = self._server_event_type_adapter.validate_python(event) except pydantic.ValidationError as e: logger.error(f"Failed to validate server event: {event}", exc_info=True) - await self._emit_event( - RealtimeModelErrorEvent( - error=e, - ) - ) + await self._emit_event(RealtimeModelErrorEvent(error=e)) return except Exception as e: event_type = event.get("type", "unknown") if isinstance(event, dict) else "unknown" logger.error(f"Failed to validate server event: {event}", exc_info=True) - await self._emit_event( - RealtimeModelExceptionEvent( - exception=e, - context=f"Failed to validate server event: {event_type}", - ) + exception_event = RealtimeModelExceptionEvent( + exception=e, + context=f"Failed to validate server event: {event_type}", ) + await self._emit_event(exception_event) return - if parsed.type == "response.audio.delta": + if parsed.type == "response.output_audio.delta": await self._handle_audio_delta(parsed) - elif parsed.type == "response.audio.done": - await self._emit_event( - RealtimeModelAudioDoneEvent( - item_id=parsed.item_id, - content_index=parsed.content_index, - ) + elif parsed.type == "response.output_audio.done": + audio_done_event = RealtimeModelAudioDoneEvent( + item_id=parsed.item_id, + content_index=parsed.content_index, ) + await self._emit_event(audio_done_event) elif parsed.type == "input_audio_buffer.speech_started": - await self._send_interrupt(RealtimeModelSendInterrupt()) + # On VAD speech start, immediately stop local playback so the user can + # barge‑in without overlapping assistant audio. + last_audio = self._audio_state_tracker.get_last_audio_item() + if last_audio is not None: + item_id, content_index = last_audio + await self._emit_event( + RealtimeModelAudioInterruptedEvent(item_id=item_id, content_index=content_index) + ) + + # Reset trackers so subsequent playback state queries don't + # reference audio that has been interrupted client‑side. + self._audio_state_tracker.on_interrupted() + if self._playback_tracker: + self._playback_tracker.on_interrupted() + + # If server isn't configured to auto‑interrupt/cancel, cancel the + # response to prevent further audio. 
+ session = self._created_session + automatic_response_cancellation_enabled = ( + session + and session.audio is not None + and session.audio.input is not None + and session.audio.input.turn_detection is not None + and session.audio.input.turn_detection.interrupt_response is True, + ) + if not automatic_response_cancellation_enabled: + await self._cancel_response() + # Avoid sending conversation.item.truncate here; when GA is set to + # interrupt on VAD start, the server will handle truncation. elif parsed.type == "response.created": self._ongoing_response = True await self._emit_event(RealtimeModelTurnStartedEvent()) @@ -537,15 +630,16 @@ async def _handle_ws_event(self, event: dict[str, Any]): await self._emit_event(RealtimeModelTurnEndedEvent()) elif parsed.type == "session.created": await self._send_tracing_config(self._tracing_config) - self._update_created_session(parsed.session) # type: ignore + self._update_created_session(parsed.session) elif parsed.type == "session.updated": - self._update_created_session(parsed.session) # type: ignore + self._update_created_session(parsed.session) elif parsed.type == "error": await self._emit_event(RealtimeModelErrorEvent(error=parsed.error)) elif parsed.type == "conversation.item.deleted": await self._emit_event(RealtimeModelItemDeletedEvent(item_id=parsed.item_id)) elif ( - parsed.type == "conversation.item.created" + parsed.type == "conversation.item.added" + or parsed.type == "conversation.item.created" or parsed.type == "conversation.item.retrieved" ): previous_item_id = ( @@ -570,7 +664,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): item_id=parsed.item_id, transcript=parsed.transcript ) ) - elif parsed.type == "response.audio_transcript.delta": + elif parsed.type == "response.output_audio_transcript.delta": await self._emit_event( RealtimeModelTranscriptDeltaEvent( item_id=parsed.item_id, delta=parsed.delta, response_id=parsed.response_id @@ -578,7 +672,7 @@ async def _handle_ws_event(self, event: dict[str, Any]): ) elif ( parsed.type == "conversation.item.input_audio_transcription.delta" - or parsed.type == "response.text.delta" + or parsed.type == "response.output_text.delta" or parsed.type == "response.function_call_arguments.delta" ): # No support for partials yet @@ -597,12 +691,107 @@ async def _handle_ws_event(self, event: dict[str, Any]): ) ) - def _update_created_session(self, session: OpenAISessionObject) -> None: - self._created_session = session - if session.output_audio_format: - self._audio_state_tracker.set_audio_format(session.output_audio_format) - if self._playback_tracker: - self._playback_tracker.set_audio_format(session.output_audio_format) + def _update_created_session( + self, + session: OpenAISessionCreateRequest + | OpenAIRealtimeTranscriptionSessionCreateRequest + | Mapping[str, object] + | pydantic.BaseModel, + ) -> None: + # Only store/playback-format information for realtime sessions (not transcription-only) + normalized_session = self._normalize_session_payload(session) + if not normalized_session: + return + + self._created_session = normalized_session + normalized_format = self._extract_audio_format(normalized_session) + if normalized_format is None: + return + + self._audio_state_tracker.set_audio_format(normalized_format) + if self._playback_tracker: + self._playback_tracker.set_audio_format(normalized_format) + + @staticmethod + def _normalize_session_payload( + session: OpenAISessionCreateRequest + | OpenAIRealtimeTranscriptionSessionCreateRequest + | Mapping[str, object] + | 
pydantic.BaseModel, + ) -> OpenAISessionCreateRequest | None: + if isinstance(session, OpenAISessionCreateRequest): + return session + + if isinstance(session, OpenAIRealtimeTranscriptionSessionCreateRequest): + return None + + session_payload: Mapping[str, object] + if isinstance(session, pydantic.BaseModel): + session_payload = cast(Mapping[str, object], session.model_dump()) + elif isinstance(session, Mapping): + session_payload = session + else: + return None + + if OpenAIRealtimeWebSocketModel._is_transcription_session(session_payload): + return None + + try: + return OpenAISessionCreateRequest.model_validate(session_payload) + except pydantic.ValidationError: + return None + + @staticmethod + def _is_transcription_session(payload: Mapping[str, object]) -> bool: + try: + OpenAIRealtimeTranscriptionSessionCreateRequest.model_validate(payload) + except pydantic.ValidationError: + return False + else: + return True + + @staticmethod + def _extract_audio_format(session: OpenAISessionCreateRequest) -> str | None: + audio = session.audio + if not audio or not audio.output or not audio.output.format: + return None + + return OpenAIRealtimeWebSocketModel._normalize_audio_format(audio.output.format) + + @staticmethod + def _normalize_audio_format(fmt: object) -> str: + if isinstance(fmt, AudioPCM): + return "pcm16" + if isinstance(fmt, AudioPCMU): + return "g711_ulaw" + if isinstance(fmt, AudioPCMA): + return "g711_alaw" + + fmt_type = OpenAIRealtimeWebSocketModel._read_format_type(fmt) + if isinstance(fmt_type, str) and fmt_type: + return fmt_type + + return str(fmt) + + @staticmethod + def _read_format_type(fmt: object) -> str | None: + if isinstance(fmt, str): + return fmt + + if isinstance(fmt, Mapping): + type_value = fmt.get("type") + return type_value if isinstance(type_value, str) else None + + if isinstance(fmt, pydantic.BaseModel): + type_value = fmt.model_dump().get("type") + return type_value if isinstance(type_value, str) else None + + try: + type_value = fmt.type # type: ignore[attr-defined] + except AttributeError: + return None + + return type_value if isinstance(type_value, str) else None async def _update_session_config(self, model_settings: RealtimeSessionModelSettings) -> None: session_config = self._get_session_config(model_settings) @@ -612,51 +801,95 @@ async def _update_session_config(self, model_settings: RealtimeSessionModelSetti def _get_session_config( self, model_settings: RealtimeSessionModelSettings - ) -> OpenAISessionObject: + ) -> OpenAISessionCreateRequest: """Get the session config.""" - return OpenAISessionObject( - instructions=model_settings.get("instructions", None), - model=( - model_settings.get("model_name", self.model) # type: ignore - or DEFAULT_MODEL_SETTINGS.get("model_name") - ), - voice=model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice")), - speed=model_settings.get("speed", None), - modalities=model_settings.get("modalities", DEFAULT_MODEL_SETTINGS.get("modalities")), - input_audio_format=model_settings.get( - "input_audio_format", - DEFAULT_MODEL_SETTINGS.get("input_audio_format"), # type: ignore - ), - output_audio_format=model_settings.get( - "output_audio_format", - DEFAULT_MODEL_SETTINGS.get("output_audio_format"), # type: ignore - ), - input_audio_transcription=model_settings.get( - "input_audio_transcription", - DEFAULT_MODEL_SETTINGS.get("input_audio_transcription"), # type: ignore - ), - turn_detection=model_settings.get( - "turn_detection", - DEFAULT_MODEL_SETTINGS.get("turn_detection"), # type: ignore - ), - 
tool_choice=model_settings.get( - "tool_choice", - DEFAULT_MODEL_SETTINGS.get("tool_choice"), # type: ignore - ), - tools=self._tools_to_session_tools( - tools=model_settings.get("tools", []), handoffs=model_settings.get("handoffs", []) + model_name = (model_settings.get("model_name") or self.model) or "gpt-realtime" + + voice = model_settings.get("voice", DEFAULT_MODEL_SETTINGS.get("voice")) + speed = model_settings.get("speed") + modalities = model_settings.get("modalities", DEFAULT_MODEL_SETTINGS.get("modalities")) + + input_audio_format = model_settings.get( + "input_audio_format", + DEFAULT_MODEL_SETTINGS.get("input_audio_format"), + ) + input_audio_transcription = model_settings.get( + "input_audio_transcription", + DEFAULT_MODEL_SETTINGS.get("input_audio_transcription"), + ) + turn_detection = model_settings.get( + "turn_detection", + DEFAULT_MODEL_SETTINGS.get("turn_detection"), + ) + output_audio_format = model_settings.get( + "output_audio_format", + DEFAULT_MODEL_SETTINGS.get("output_audio_format"), + ) + + input_audio_config = None + if any( + value is not None + for value in [input_audio_format, input_audio_transcription, turn_detection] + ): + input_audio_config = OpenAIRealtimeAudioInput( + format=to_realtime_audio_format(input_audio_format), + transcription=cast(Any, input_audio_transcription), + turn_detection=cast(Any, turn_detection), + ) + + output_audio_config = None + if any(value is not None for value in [output_audio_format, speed, voice]): + output_audio_config = OpenAIRealtimeAudioOutput( + format=to_realtime_audio_format(output_audio_format), + speed=speed, + voice=voice, + ) + + audio_config = None + if input_audio_config or output_audio_config: + audio_config = OpenAIRealtimeAudioConfig( + input=input_audio_config, + output=output_audio_config, + ) + + prompt: ResponsePrompt | None = None + if model_settings.get("prompt") is not None: + _passed_prompt: Prompt = model_settings["prompt"] + variables: dict[str, Any] | None = _passed_prompt.get("variables") + prompt = ResponsePrompt( + id=_passed_prompt["id"], + variables=variables, + version=_passed_prompt.get("version"), + ) + + # Construct full session object. `type` will be excluded at serialization time for updates. + return OpenAISessionCreateRequest( + model=model_name, + type="realtime", + instructions=model_settings.get("instructions"), + prompt=prompt, + output_modalities=modalities, + audio=audio_config, + max_output_tokens=cast(Any, model_settings.get("max_output_tokens")), + tool_choice=cast(Any, model_settings.get("tool_choice")), + tools=cast( + Any, + self._tools_to_session_tools( + tools=model_settings.get("tools", []), + handoffs=model_settings.get("handoffs", []), + ), ), ) def _tools_to_session_tools( self, tools: list[Tool], handoffs: list[Handoff] - ) -> list[OpenAISessionTool]: - converted_tools: list[OpenAISessionTool] = [] + ) -> list[OpenAISessionFunction]: + converted_tools: list[OpenAISessionFunction] = [] for tool in tools: if not isinstance(tool, FunctionTool): raise UserError(f"Tool {tool.name} is unsupported. 
Must be a function tool.") converted_tools.append( - OpenAISessionTool( + OpenAISessionFunction( name=tool.name, description=tool.description, parameters=tool.params_json_schema, @@ -666,7 +899,7 @@ def _tools_to_session_tools( for handoff in handoffs: converted_tools.append( - OpenAISessionTool( + OpenAISessionFunction( name=handoff.tool_name, description=handoff.tool_description, parameters=handoff.input_json_schema, @@ -682,15 +915,32 @@ class _ConversionHelper: def conversation_item_to_realtime_message_item( cls, item: ConversationItem, previous_item_id: str | None ) -> RealtimeMessageItem: + if not isinstance( + item, + ( + RealtimeConversationItemUserMessage, + RealtimeConversationItemAssistantMessage, + RealtimeConversationItemSystemMessage, + ), + ): + raise ValueError("Unsupported conversation item type for message conversion.") + content: list[dict[str, Any]] = [] + for each in item.content: + c = each.model_dump() + if each.type == "output_text": + # For backward-compatibility of assistant message items + c["type"] = "text" + elif each.type == "output_audio": + # For backward-compatibility of assistant message items + c["type"] = "audio" + content.append(c) return TypeAdapter(RealtimeMessageItem).validate_python( { "item_id": item.id or "", "previous_item_id": previous_item_id, "type": item.type, "role": item.role, - "content": ( - [content.model_dump() for content in item.content] if item.content else [] - ), + "content": content, "status": "in_progress", }, ) @@ -710,12 +960,12 @@ def try_convert_raw_message( @classmethod def convert_tracing_config( cls, tracing_config: RealtimeModelTracingConfig | Literal["auto"] | None - ) -> OpenAISessionTracing | None: + ) -> OpenAITracingConfiguration | Literal["auto"] | None: if tracing_config is None: return None elif tracing_config == "auto": return "auto" - return OpenAISessionTracingConfiguration( + return OpenAITracingConfiguration( group_id=tracing_config.get("group_id"), metadata=tracing_config.get("metadata"), workflow_name=tracing_config.get("workflow_name"), @@ -728,22 +978,53 @@ def convert_user_input_to_conversation_item( user_input = event.user_input if isinstance(user_input, dict): - return OpenAIConversationItem( + content: list[Content] = [] + for item in user_input.get("content", []): + try: + if not isinstance(item, dict): + continue + t = item.get("type") + if t == "input_text": + _txt = item.get("text") + text_val = _txt if isinstance(_txt, str) else None + content.append(Content(type="input_text", text=text_val)) + elif t == "input_image": + iu = item.get("image_url") + if isinstance(iu, str) and iu: + d = item.get("detail") + detail_val = cast( + Literal["auto", "low", "high"] | None, + d if isinstance(d, str) and d in ("auto", "low", "high") else None, + ) + if detail_val is None: + content.append( + Content( + type="input_image", + image_url=iu, + ) + ) + else: + content.append( + Content( + type="input_image", + image_url=iu, + detail=detail_val, + ) + ) + # ignore unknown types for forward-compat + except Exception: + # best-effort; skip malformed parts + continue + return RealtimeConversationItemUserMessage( type="message", role="user", - content=[ - OpenAIConversationItemContent( - type="input_text", - text=item.get("text"), - ) - for item in user_input.get("content", []) - ], + content=content, ) else: - return OpenAIConversationItem( + return RealtimeConversationItemUserMessage( type="message", role="user", - content=[OpenAIConversationItemContent(type="input_text", text=user_input)], + 
content=[Content(type="input_text", text=user_input)], ) @classmethod @@ -769,7 +1050,7 @@ def convert_audio_to_input_audio_buffer_append( def convert_tool_output(cls, event: RealtimeModelSendToolOutput) -> OpenAIRealtimeClientEvent: return OpenAIConversationItemCreateEvent( type="conversation.item.create", - item=OpenAIConversationItem( + item=RealtimeConversationItemFunctionCallOutput( type="function_call_output", output=event.output, call_id=event.tool_call.call_id, diff --git a/src/agents/realtime/session.py b/src/agents/realtime/session.py index 32c418fac..62adc529c 100644 --- a/src/agents/realtime/session.py +++ b/src/agents/realtime/session.py @@ -35,7 +35,16 @@ RealtimeToolStart, ) from .handoffs import realtime_handoff -from .items import AssistantAudio, InputAudio, InputText, RealtimeItem +from .items import ( + AssistantAudio, + AssistantMessageItem, + AssistantText, + InputAudio, + InputImage, + InputText, + RealtimeItem, + UserMessageItem, +) from .model import RealtimeModel, RealtimeModelConfig, RealtimeModelListener from .model_events import ( RealtimeModelEvent, @@ -230,10 +239,17 @@ async def on_event(self, event: RealtimeModelEvent) -> None: ) ) elif event.type == "input_audio_transcription_completed": + prev_len = len(self._history) self._history = RealtimeSession._get_new_history(self._history, event) - await self._put_event( - RealtimeHistoryUpdated(info=self._event_info, history=self._history) - ) + # If a new user item was appended (no existing item), + # emit history_added for incremental UIs. + if len(self._history) > prev_len and len(self._history) > 0: + new_item = self._history[-1] + await self._put_event(RealtimeHistoryAdded(info=self._event_info, item=new_item)) + else: + await self._put_event( + RealtimeHistoryUpdated(info=self._event_info, history=self._history) + ) elif event.type == "input_audio_timeout_triggered": await self._put_event( RealtimeInputAudioTimeoutTriggered( @@ -248,6 +264,13 @@ async def on_event(self, event: RealtimeModelEvent) -> None: self._item_guardrail_run_counts[item_id] = 0 self._item_transcripts[item_id] += event.delta + self._history = self._get_new_history( + self._history, + AssistantMessageItem( + item_id=item_id, + content=[AssistantAudio(transcript=self._item_transcripts[item_id])], + ), + ) # Check if we should run guardrails based on debounce threshold current_length = len(self._item_transcripts[item_id]) @@ -297,7 +320,7 @@ async def on_event(self, event: RealtimeModelEvent) -> None: # If still missing and this is an assistant item, fall back to # accumulated transcript deltas tracked during the turn. - if not preserved and incoming_item.role == "assistant": + if incoming_item.role == "assistant": preserved = self._item_transcripts.get(incoming_item.item_id) if preserved: @@ -462,9 +485,9 @@ def _get_new_history( old_history: list[RealtimeItem], event: RealtimeModelInputAudioTranscriptionCompletedEvent | RealtimeItem, ) -> list[RealtimeItem]: - # Merge transcript into placeholder input_audio message. 
if isinstance(event, RealtimeModelInputAudioTranscriptionCompletedEvent): new_history: list[RealtimeItem] = [] + existing_item_found = False for item in old_history: if item.item_id == event.item_id and item.type == "message" and item.role == "user": content: list[InputText | InputAudio] = [] @@ -477,11 +500,18 @@ def _get_new_history( new_history.append( item.model_copy(update={"content": content, "status": "completed"}) ) + existing_item_found = True else: new_history.append(item) + + if existing_item_found is False: + new_history.append( + UserMessageItem( + item_id=event.item_id, content=[InputText(text=event.transcript)] + ) + ) return new_history - # Otherwise it's just a new item # TODO (rm) Add support for audio storage config # If the item already exists, update it @@ -490,8 +520,122 @@ def _get_new_history( ) if existing_index is not None: new_history = old_history.copy() - new_history[existing_index] = event + if event.type == "message" and event.content is not None and len(event.content) > 0: + existing_item = old_history[existing_index] + if existing_item.type == "message": + # Merge content preserving existing transcript/text when incoming entry is empty + if event.role == "assistant" and existing_item.role == "assistant": + assistant_existing_content = existing_item.content + assistant_incoming = event.content + assistant_new_content: list[AssistantText | AssistantAudio] = [] + for idx, ac in enumerate(assistant_incoming): + if idx >= len(assistant_existing_content): + assistant_new_content.append(ac) + continue + assistant_current = assistant_existing_content[idx] + if ac.type == "audio": + if ac.transcript is None: + assistant_new_content.append(assistant_current) + else: + assistant_new_content.append(ac) + else: # text + cur_text = ( + assistant_current.text + if isinstance(assistant_current, AssistantText) + else None + ) + if cur_text is not None and ac.text is None: + assistant_new_content.append(assistant_current) + else: + assistant_new_content.append(ac) + updated_assistant = event.model_copy( + update={"content": assistant_new_content} + ) + new_history[existing_index] = updated_assistant + elif event.role == "user" and existing_item.role == "user": + user_existing_content = existing_item.content + user_incoming = event.content + + # Start from incoming content (prefer latest fields) + user_new_content: list[InputText | InputAudio | InputImage] = list( + user_incoming + ) + + # Merge by type with special handling for images and transcripts + def _image_url_str(val: object) -> str | None: + if isinstance(val, InputImage): + return val.image_url or None + return None + + # 1) Preserve any existing images that are missing from the incoming payload + incoming_image_urls: set[str] = set() + for part in user_incoming: + if isinstance(part, InputImage): + u = _image_url_str(part) + if u: + incoming_image_urls.add(u) + + missing_images: list[InputImage] = [] + for part in user_existing_content: + if isinstance(part, InputImage): + u = _image_url_str(part) + if u and u not in incoming_image_urls: + missing_images.append(part) + + # Insert missing images at the beginning to keep them visible and stable + if missing_images: + user_new_content = missing_images + user_new_content + + # 2) For text/audio entries, preserve existing when incoming entry is empty + merged: list[InputText | InputAudio | InputImage] = [] + for idx, uc in enumerate(user_new_content): + if uc.type == "input_audio": + # Attempt to preserve transcript if empty + transcript = getattr(uc, "transcript", 
None) + if transcript is None and idx < len(user_existing_content): + prev = user_existing_content[idx] + if isinstance(prev, InputAudio) and prev.transcript is not None: + uc = uc.model_copy(update={"transcript": prev.transcript}) + merged.append(uc) + elif uc.type == "input_text": + text = getattr(uc, "text", None) + if (text is None or text == "") and idx < len( + user_existing_content + ): + prev = user_existing_content[idx] + if isinstance(prev, InputText) and prev.text: + uc = uc.model_copy(update={"text": prev.text}) + merged.append(uc) + else: + merged.append(uc) + + updated_user = event.model_copy(update={"content": merged}) + new_history[existing_index] = updated_user + elif event.role == "system" and existing_item.role == "system": + system_existing_content = existing_item.content + system_incoming = event.content + # Prefer existing non-empty text when incoming is empty + system_new_content: list[InputText] = [] + for idx, sc in enumerate(system_incoming): + if idx >= len(system_existing_content): + system_new_content.append(sc) + continue + system_current = system_existing_content[idx] + cur_text = system_current.text + if cur_text is not None and sc.text is None: + system_new_content.append(system_current) + else: + system_new_content.append(sc) + updated_system = event.model_copy(update={"content": system_new_content}) + new_history[existing_index] = updated_system + else: + # Role changed or mismatched; just replace + new_history[existing_index] = event + else: + # If the existing item is not a message, just replace it. + new_history[existing_index] = event return new_history + # Otherwise, insert it after the previous_item_id if that is set elif event.previous_item_id: # Insert the new item after the previous item @@ -628,6 +772,9 @@ async def _get_updated_model_settings_from_agent( # Start with the merged base settings from run and model configuration. updated_settings = self._base_model_settings.copy() + if agent.prompt is not None: + updated_settings["prompt"] = agent.prompt + instructions, tools, handoffs = await asyncio.gather( agent.get_system_prompt(self._context_wrapper), agent.get_all_tools(self._context_wrapper), diff --git a/src/agents/voice/input.py b/src/agents/voice/input.py index 8cbc8b735..d59ceea21 100644 --- a/src/agents/voice/input.py +++ b/src/agents/voice/input.py @@ -13,7 +13,7 @@ def _buffer_to_audio_file( - buffer: npt.NDArray[np.int16 | np.float32], + buffer: npt.NDArray[np.int16 | np.float32 | np.float64], frame_rate: int = DEFAULT_SAMPLE_RATE, sample_width: int = 2, channels: int = 1, diff --git a/src/agents/voice/models/openai_stt.py b/src/agents/voice/models/openai_stt.py index 19e91d9be..12333b025 100644 --- a/src/agents/voice/models/openai_stt.py +++ b/src/agents/voice/models/openai_stt.py @@ -278,7 +278,6 @@ async def _process_websocket_connection(self) -> None: "wss://api.openai.com/v1/realtime?intent=transcription", additional_headers={ "Authorization": f"Bearer {self._client.api_key}", - "OpenAI-Beta": "realtime=v1", "OpenAI-Log-Session": "1", }, ) as ws: diff --git a/tests/conftest.py b/tests/conftest.py index b73d734d1..1e11e086a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,17 @@ def setup_span_processor(): set_trace_processors([SPAN_PROCESSOR_TESTING]) +# Ensure a default OpenAI API key is present for tests that construct clients +# without explicitly configuring a key/client. Tests that need no key use +# monkeypatch.delenv("OPENAI_API_KEY", ...) to remove it locally. 
+@pytest.fixture(scope="session", autouse=True) +def ensure_openai_api_key(): + import os + + if not os.environ.get("OPENAI_API_KEY"): + os.environ["OPENAI_API_KEY"] = "test_key" + + # This fixture will run before each test @pytest.fixture(autouse=True) def clear_span_processor(): diff --git a/tests/realtime/test_audio_formats_unit.py b/tests/realtime/test_audio_formats_unit.py new file mode 100644 index 000000000..5c621d462 --- /dev/null +++ b/tests/realtime/test_audio_formats_unit.py @@ -0,0 +1,28 @@ +from openai.types.realtime.realtime_audio_formats import AudioPCM + +from agents.realtime.audio_formats import to_realtime_audio_format + + +def test_to_realtime_audio_format_from_strings(): + assert to_realtime_audio_format("pcm").type == "audio/pcm" # type: ignore[union-attr] + assert to_realtime_audio_format("pcm16").type == "audio/pcm" # type: ignore[union-attr] + assert to_realtime_audio_format("audio/pcm").type == "audio/pcm" # type: ignore[union-attr] + assert to_realtime_audio_format("pcmu").type == "audio/pcmu" # type: ignore[union-attr] + assert to_realtime_audio_format("audio/pcmu").type == "audio/pcmu" # type: ignore[union-attr] + assert to_realtime_audio_format("g711_ulaw").type == "audio/pcmu" # type: ignore[union-attr] + assert to_realtime_audio_format("pcma").type == "audio/pcma" # type: ignore[union-attr] + assert to_realtime_audio_format("audio/pcma").type == "audio/pcma" # type: ignore[union-attr] + assert to_realtime_audio_format("g711_alaw").type == "audio/pcma" # type: ignore[union-attr] + + +def test_to_realtime_audio_format_passthrough_and_unknown_logs(): + fmt = AudioPCM(type="audio/pcm", rate=24000) + # Passing a RealtimeAudioFormats should return the same instance + assert to_realtime_audio_format(fmt) is fmt + + # Unknown string returns None (and logs at debug level internally) + assert to_realtime_audio_format("something_else") is None + + +def test_to_realtime_audio_format_none(): + assert to_realtime_audio_format(None) is None diff --git a/tests/realtime/test_conversion_helpers.py b/tests/realtime/test_conversion_helpers.py index 2d84c8c49..535621f13 100644 --- a/tests/realtime/test_conversion_helpers.py +++ b/tests/realtime/test_conversion_helpers.py @@ -3,15 +3,14 @@ import base64 from unittest.mock import Mock -from openai.types.beta.realtime.conversation_item import ConversationItem -from openai.types.beta.realtime.conversation_item_create_event import ConversationItemCreateEvent -from openai.types.beta.realtime.conversation_item_truncate_event import ( - ConversationItemTruncateEvent, -) -from openai.types.beta.realtime.input_audio_buffer_append_event import InputAudioBufferAppendEvent -from openai.types.beta.realtime.session_update_event import ( - SessionTracingTracingConfiguration, +import pytest +from openai.types.realtime.conversation_item_create_event import ConversationItemCreateEvent +from openai.types.realtime.conversation_item_truncate_event import ConversationItemTruncateEvent +from openai.types.realtime.input_audio_buffer_append_event import InputAudioBufferAppendEvent +from openai.types.realtime.realtime_conversation_item_function_call_output import ( + RealtimeConversationItemFunctionCallOutput, ) +from pydantic import ValidationError from agents.realtime.config import RealtimeModelTracingConfig from agents.realtime.model_inputs import ( @@ -34,6 +33,8 @@ def test_try_convert_raw_message_valid_session_update(self): "type": "session.update", "other_data": { "session": { + "model": "gpt-realtime", + "type": "realtime", "modalities": ["text", 
"audio"], "voice": "ash", } @@ -125,7 +126,8 @@ def test_convert_tracing_config_dict_full(self): result = _ConversionHelper.convert_tracing_config(tracing_config) - assert isinstance(result, SessionTracingTracingConfiguration) + assert result is not None + assert result != "auto" assert result.group_id == "test-group" assert result.metadata == {"env": "test"} assert result.workflow_name == "test-workflow" @@ -138,7 +140,8 @@ def test_convert_tracing_config_dict_partial(self): result = _ConversionHelper.convert_tracing_config(tracing_config) - assert isinstance(result, SessionTracingTracingConfiguration) + assert result is not None + assert result != "auto" assert result.group_id == "test-group" assert result.metadata is None assert result.workflow_name is None @@ -149,7 +152,8 @@ def test_convert_tracing_config_empty_dict(self): result = _ConversionHelper.convert_tracing_config(tracing_config) - assert isinstance(result, SessionTracingTracingConfiguration) + assert result is not None + assert result != "auto" assert result.group_id is None assert result.metadata is None assert result.workflow_name is None @@ -164,7 +168,6 @@ def test_convert_user_input_to_conversation_item_string(self): result = _ConversionHelper.convert_user_input_to_conversation_item(event) - assert isinstance(result, ConversationItem) assert result.type == "message" assert result.role == "user" assert result.content is not None @@ -186,7 +189,6 @@ def test_convert_user_input_to_conversation_item_dict(self): result = _ConversionHelper.convert_user_input_to_conversation_item(event) - assert isinstance(result, ConversationItem) assert result.type == "message" assert result.role == "user" assert result.content is not None @@ -207,7 +209,6 @@ def test_convert_user_input_to_conversation_item_dict_empty_content(self): result = _ConversionHelper.convert_user_input_to_conversation_item(event) - assert isinstance(result, ConversationItem) assert result.type == "message" assert result.role == "user" assert result.content is not None @@ -221,7 +222,6 @@ def test_convert_user_input_to_item_create(self): assert isinstance(result, ConversationItemCreateEvent) assert result.type == "conversation.item.create" - assert isinstance(result.item, ConversationItem) assert result.item.type == "message" assert result.item.role == "user" @@ -287,10 +287,11 @@ def test_convert_tool_output(self): assert isinstance(result, ConversationItemCreateEvent) assert result.type == "conversation.item.create" - assert isinstance(result.item, ConversationItem) assert result.item.type == "function_call_output" - assert result.item.output == "Function executed successfully" - assert result.item.call_id == "call_123" + assert isinstance(result.item, RealtimeConversationItemFunctionCallOutput) + tool_output_item = result.item + assert tool_output_item.output == "Function executed successfully" + assert tool_output_item.call_id == "call_123" def test_convert_tool_output_no_call_id(self): """Test converting tool output with None call_id.""" @@ -303,11 +304,11 @@ def test_convert_tool_output_no_call_id(self): start_response=False, ) - result = _ConversionHelper.convert_tool_output(event) - - assert isinstance(result, ConversationItemCreateEvent) - assert result.type == "conversation.item.create" - assert result.item.call_id is None + with pytest.raises( + ValidationError, + match="1 validation error for RealtimeConversationItemFunctionCallOutput", + ): + _ConversionHelper.convert_tool_output(event) def test_convert_tool_output_empty_output(self): """Test converting 
tool output with empty output.""" @@ -323,6 +324,8 @@ def test_convert_tool_output_empty_output(self): result = _ConversionHelper.convert_tool_output(event) assert isinstance(result, ConversationItemCreateEvent) + assert result.type == "conversation.item.create" + assert isinstance(result.item, RealtimeConversationItemFunctionCallOutput) assert result.item.output == "" assert result.item.call_id == "call_456" diff --git a/tests/realtime/test_ga_session_update_normalization.py b/tests/realtime/test_ga_session_update_normalization.py new file mode 100644 index 000000000..7056e8c96 --- /dev/null +++ b/tests/realtime/test_ga_session_update_normalization.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import Any, cast + +import pytest +from websockets.asyncio.client import ClientConnection + +from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel + + +class _DummyWS: + def __init__(self) -> None: + self.sent: list[str] = [] + + async def send(self, data: str) -> None: + self.sent.append(data) + + +@pytest.mark.asyncio +async def test_no_auto_interrupt_on_vad_speech_started(monkeypatch: Any) -> None: + model = OpenAIRealtimeWebSocketModel() + + called = {"interrupt": False} + + async def _fake_interrupt(event: Any) -> None: + called["interrupt"] = True + + # Prevent network use; _websocket only needed for other paths + model._websocket = cast(ClientConnection, _DummyWS()) + monkeypatch.setattr(model, "_send_interrupt", _fake_interrupt) + + # This event previously triggered an interrupt; now it should be ignored + await model._handle_ws_event({"type": "input_audio_buffer.speech_started"}) + + assert called["interrupt"] is False diff --git a/tests/realtime/test_item_parsing.py b/tests/realtime/test_item_parsing.py index ba128f7fd..e8484a58f 100644 --- a/tests/realtime/test_item_parsing.py +++ b/tests/realtime/test_item_parsing.py @@ -1,5 +1,15 @@ -from openai.types.beta.realtime.conversation_item import ConversationItem -from openai.types.beta.realtime.conversation_item_content import ConversationItemContent +from openai.types.realtime.realtime_conversation_item_assistant_message import ( + Content as AssistantMessageContent, + RealtimeConversationItemAssistantMessage, +) +from openai.types.realtime.realtime_conversation_item_system_message import ( + Content as SystemMessageContent, + RealtimeConversationItemSystemMessage, +) +from openai.types.realtime.realtime_conversation_item_user_message import ( + Content as UserMessageContent, + RealtimeConversationItemUserMessage, +) from agents.realtime.items import ( AssistantMessageItem, @@ -11,14 +21,12 @@ def test_user_message_conversion() -> None: - item = ConversationItem( + item = RealtimeConversationItemUserMessage( id="123", type="message", role="user", content=[ - ConversationItemContent( - id=None, audio=None, text=None, transcript=None, type="input_text" - ) + UserMessageContent(type="input_text", text=None), ], ) @@ -28,14 +36,12 @@ def test_user_message_conversion() -> None: assert isinstance(converted, UserMessageItem) - item = ConversationItem( + item = RealtimeConversationItemUserMessage( id="123", type="message", role="user", content=[ - ConversationItemContent( - id=None, audio=None, text=None, transcript=None, type="input_audio" - ) + UserMessageContent(type="input_audio", audio=None), ], ) @@ -45,13 +51,11 @@ def test_user_message_conversion() -> None: def test_assistant_message_conversion() -> None: - item = ConversationItem( + item = RealtimeConversationItemAssistantMessage( id="123", 
type="message", role="assistant", - content=[ - ConversationItemContent(id=None, audio=None, text=None, transcript=None, type="text") - ], + content=[AssistantMessageContent(type="output_text", text=None)], ) converted: RealtimeMessageItem = _ConversionHelper.conversation_item_to_realtime_message_item( @@ -62,15 +66,11 @@ def test_assistant_message_conversion() -> None: def test_system_message_conversion() -> None: - item = ConversationItem( + item = RealtimeConversationItemSystemMessage( id="123", type="message", role="system", - content=[ - ConversationItemContent( - id=None, audio=None, text=None, transcript=None, type="input_text" - ) - ], + content=[SystemMessageContent(type="input_text", text=None)], ) converted: RealtimeMessageItem = _ConversionHelper.conversation_item_to_realtime_message_item( diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 08b8d878f..34352df44 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -1,15 +1,24 @@ -from typing import Any +from typing import Any, cast from unittest.mock import AsyncMock, Mock, patch import pytest import websockets +from agents import Agent from agents.exceptions import UserError +from agents.handoffs import handoff from agents.realtime.model_events import ( RealtimeModelAudioEvent, RealtimeModelErrorEvent, RealtimeModelToolCallEvent, ) +from agents.realtime.model_inputs import ( + RealtimeModelSendAudio, + RealtimeModelSendInterrupt, + RealtimeModelSendSessionUpdate, + RealtimeModelSendToolOutput, + RealtimeModelSendUserInput, +) from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel @@ -77,7 +86,7 @@ def mock_create_task_func(coro): assert ( call_args[1]["additional_headers"]["Authorization"] == "Bearer test-api-key-123" ) - assert call_args[1]["additional_headers"]["OpenAI-Beta"] == "realtime=v1" + assert call_args[1]["additional_headers"].get("OpenAI-Beta") is None # Verify task was created for message listening mock_create_task.assert_called_once() @@ -228,7 +237,7 @@ async def test_handle_invalid_event_schema_logs_error(self, model): mock_listener = AsyncMock() model.add_listener(mock_listener) - invalid_event = {"type": "response.audio.delta"} # Missing required fields + invalid_event = {"type": "response.output_audio.delta"} # Missing required fields await model._handle_ws_event(invalid_event) @@ -267,7 +276,7 @@ async def test_handle_audio_delta_event_success(self, model): # Valid audio delta event (minimal required fields for OpenAI spec) audio_event = { - "type": "response.audio.delta", + "type": "response.output_audio.delta", "event_id": "event_123", "response_id": "resp_123", "item_id": "item_456", @@ -293,6 +302,165 @@ async def test_handle_audio_delta_event_success(self, model): assert audio_state is not None assert audio_state.audio_length_ms > 0 # Should have some audio length + @pytest.mark.asyncio + async def test_backward_compat_output_item_added_and_done(self, model): + """response.output_item.added/done paths emit item updates.""" + listener = AsyncMock() + model.add_listener(listener) + + msg_added = { + "type": "response.output_item.added", + "item": { + "id": "m1", + "type": "message", + "role": "assistant", + "content": [ + {"type": "text", "text": "hello"}, + {"type": "audio", "audio": "...", "transcript": "hi"}, + ], + }, + } + await model._handle_ws_event(msg_added) + + msg_done = { + "type": "response.output_item.done", + "item": { + "id": "m1", + "type": "message", + "role": "assistant", + 
"content": [{"type": "text", "text": "bye"}], + }, + } + await model._handle_ws_event(msg_done) + + # Ensure we emitted item_updated events for both cases + types = [c[0][0].type for c in listener.on_event.call_args_list] + assert types.count("item_updated") >= 2 + + # Note: response.created/done require full OpenAI response payload which is + # out-of-scope for unit tests here; covered indirectly via other branches. + + @pytest.mark.asyncio + async def test_transcription_related_and_timeouts_and_speech_started(self, model, monkeypatch): + listener = AsyncMock() + model.add_listener(listener) + + # Prepare tracker state to simulate ongoing audio + model._audio_state_tracker.set_audio_format("pcm16") + model._audio_state_tracker.on_audio_delta("i1", 0, b"aaaa") + model._ongoing_response = True + + # Patch sending to avoid websocket dependency + monkeypatch.setattr( + model, + "_send_raw_message", + AsyncMock(), + ) + + # Speech started should emit interrupted and cancel the response + await model._handle_ws_event( + { + "type": "input_audio_buffer.speech_started", + "event_id": "es1", + "item_id": "i1", + "audio_start_ms": 0, + "audio_end_ms": 1, + } + ) + + # Output transcript delta + await model._handle_ws_event( + { + "type": "response.output_audio_transcript.delta", + "event_id": "e3", + "item_id": "i3", + "response_id": "r3", + "output_index": 0, + "content_index": 0, + "delta": "abc", + } + ) + + # Timeout triggered + await model._handle_ws_event( + { + "type": "input_audio_buffer.timeout_triggered", + "event_id": "e4", + "item_id": "i4", + "audio_start_ms": 0, + "audio_end_ms": 100, + } + ) + + # raw + interrupted, raw + transcript delta, raw + timeout + assert listener.on_event.call_count >= 6 + types = [call[0][0].type for call in listener.on_event.call_args_list] + assert "audio_interrupted" in types + assert "transcript_delta" in types + assert "input_audio_timeout_triggered" in types + + +class TestSendEventAndConfig(TestOpenAIRealtimeWebSocketModel): + @pytest.mark.asyncio + async def test_send_event_dispatch(self, model, monkeypatch): + send_raw = AsyncMock() + monkeypatch.setattr(model, "_send_raw_message", send_raw) + + await model.send_event(RealtimeModelSendUserInput(user_input="hi")) + await model.send_event(RealtimeModelSendAudio(audio=b"a", commit=False)) + await model.send_event(RealtimeModelSendAudio(audio=b"a", commit=True)) + await model.send_event( + RealtimeModelSendToolOutput( + tool_call=RealtimeModelToolCallEvent(name="t", call_id="c", arguments="{}"), + output="ok", + start_response=True, + ) + ) + await model.send_event(RealtimeModelSendInterrupt()) + await model.send_event(RealtimeModelSendSessionUpdate(session_settings={"voice": "nova"})) + + # user_input -> 2 raw messages (item.create + response.create) + # audio append -> 1, commit -> +1 + # tool output -> 1 + # interrupt -> 1 + # session update -> 1 + assert send_raw.await_count == 8 + + def test_add_remove_listener_and_tools_conversion(self, model): + listener = AsyncMock() + model.add_listener(listener) + model.add_listener(listener) + assert len(model._listeners) == 1 + model.remove_listener(listener) + assert len(model._listeners) == 0 + + # tools conversion rejects non function tools and includes handoffs + with pytest.raises(UserError): + from agents.tool import Tool + + class X: + name = "x" + + model._tools_to_session_tools(cast(list[Tool], [X()]), []) + + h = handoff(Agent(name="a")) + out = model._tools_to_session_tools([], [h]) + assert out[0].name.startswith("transfer_to_") + + def 
test_get_and_update_session_config(self, model): + settings = { + "model_name": "gpt-realtime", + "voice": "verse", + "output_audio_format": "g711_ulaw", + "modalities": ["audio"], + "input_audio_format": "pcm16", + "input_audio_transcription": {"model": "gpt-4o-mini-transcribe"}, + "turn_detection": {"type": "semantic_vad", "interrupt_response": True}, + } + cfg = model._get_session_config(settings) + assert cfg.audio is not None and cfg.audio.output is not None + assert cfg.audio.output.voice == "verse" + @pytest.mark.asyncio async def test_handle_error_event_success(self, model): """Test successful handling of error events.""" @@ -363,7 +531,7 @@ async def test_audio_timing_calculation_accuracy(self, model): # Send multiple audio deltas to test cumulative timing audio_deltas = [ { - "type": "response.audio.delta", + "type": "response.output_audio.delta", "event_id": "event_1", "response_id": "resp_1", "item_id": "item_1", @@ -372,7 +540,7 @@ async def test_audio_timing_calculation_accuracy(self, model): "delta": "dGVzdA==", # 4 bytes -> "test" }, { - "type": "response.audio.delta", + "type": "response.output_audio.delta", "event_id": "event_2", "response_id": "resp_1", "item_id": "item_1", diff --git a/tests/realtime/test_openai_realtime_conversions.py b/tests/realtime/test_openai_realtime_conversions.py new file mode 100644 index 000000000..2597b7dce --- /dev/null +++ b/tests/realtime/test_openai_realtime_conversions.py @@ -0,0 +1,103 @@ +from typing import cast + +import pytest +from openai.types.realtime.realtime_conversation_item_user_message import ( + RealtimeConversationItemUserMessage, +) +from openai.types.realtime.realtime_tracing_config import ( + TracingConfiguration, +) + +from agents import Agent +from agents.exceptions import UserError +from agents.handoffs import handoff +from agents.realtime.config import RealtimeModelTracingConfig +from agents.realtime.model_inputs import ( + RealtimeModelSendRawMessage, + RealtimeModelSendUserInput, + RealtimeModelUserInputMessage, +) +from agents.realtime.openai_realtime import ( + OpenAIRealtimeWebSocketModel, + _ConversionHelper, + get_api_key, +) +from agents.tool import Tool + + +@pytest.mark.asyncio +async def test_get_api_key_from_env(monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "env-key") + assert await get_api_key(None) == "env-key" + + +@pytest.mark.asyncio +async def test_get_api_key_from_callable_async(): + async def f(): + return "k" + + assert await get_api_key(f) == "k" + + +def test_try_convert_raw_message_invalid_returns_none(): + msg = RealtimeModelSendRawMessage(message={"type": "invalid.event", "other_data": {}}) + assert _ConversionHelper.try_convert_raw_message(msg) is None + + +def test_convert_user_input_to_conversation_item_dict_and_str(): + # Dict with mixed, including unknown parts (silently skipped) + dict_input_any = { + "type": "message", + "role": "user", + "content": [ + {"type": "input_text", "text": "hello"}, + {"type": "input_image", "image_url": "http://x/y.png", "detail": "auto"}, + {"type": "bogus", "x": 1}, + ], + } + event = RealtimeModelSendUserInput( + user_input=cast(RealtimeModelUserInputMessage, dict_input_any) + ) + item_any = _ConversionHelper.convert_user_input_to_conversation_item(event) + item = cast(RealtimeConversationItemUserMessage, item_any) + assert item.role == "user" + + # String input becomes input_text + event2 = RealtimeModelSendUserInput(user_input="hi") + item2_any = _ConversionHelper.convert_user_input_to_conversation_item(event2) + item2 = 
cast(RealtimeConversationItemUserMessage, item2_any) + assert item2.content[0].type == "input_text" + + +def test_convert_tracing_config_variants(): + from agents.realtime.openai_realtime import _ConversionHelper as CH + + assert CH.convert_tracing_config(None) is None + assert CH.convert_tracing_config("auto") == "auto" + cfg: RealtimeModelTracingConfig = { + "group_id": "g", + "metadata": {"k": "v"}, + "workflow_name": "wf", + } + oc_any = CH.convert_tracing_config(cfg) + oc = cast(TracingConfiguration, oc_any) + assert oc.group_id == "g" + assert oc.workflow_name == "wf" + + +def test_tools_to_session_tools_raises_on_non_function_tool(): + class NotFunctionTool: + def __init__(self): + self.name = "x" + + m = OpenAIRealtimeWebSocketModel() + with pytest.raises(UserError): + m._tools_to_session_tools(cast(list[Tool], [NotFunctionTool()]), []) + + +def test_tools_to_session_tools_includes_handoffs(): + a = Agent(name="a") + h = handoff(a) + m = OpenAIRealtimeWebSocketModel() + out = m._tools_to_session_tools([], [h]) + assert out[0].name is not None and out[0].name.startswith("transfer_to_") diff --git a/tests/realtime/test_playback_tracker_manual_unit.py b/tests/realtime/test_playback_tracker_manual_unit.py new file mode 100644 index 000000000..35adc1264 --- /dev/null +++ b/tests/realtime/test_playback_tracker_manual_unit.py @@ -0,0 +1,23 @@ +from agents.realtime.model import RealtimePlaybackTracker + + +def test_playback_tracker_on_play_bytes_and_state(): + tr = RealtimePlaybackTracker() + tr.set_audio_format("pcm16") # PCM path + + # 48k bytes -> (48000 / 24 / 2) * 1000 = 1,000,000ms per current util + tr.on_play_bytes("item1", 0, b"x" * 48000) + st = tr.get_state() + assert st["current_item_id"] == "item1" + assert st["elapsed_ms"] and abs(st["elapsed_ms"] - 1_000_000.0) < 1e-6 + + # Subsequent play on same item accumulates + tr.on_play_ms("item1", 0, 500.0) + st2 = tr.get_state() + assert st2["elapsed_ms"] and abs(st2["elapsed_ms"] - 1_000_500.0) < 1e-6 + + # Interruption clears state + tr.on_interrupted() + st3 = tr.get_state() + assert st3["current_item_id"] is None + assert st3["elapsed_ms"] is None diff --git a/tests/realtime/test_realtime_handoffs.py b/tests/realtime/test_realtime_handoffs.py index 07385fe20..a94c06bb0 100644 --- a/tests/realtime/test_realtime_handoffs.py +++ b/tests/realtime/test_realtime_handoffs.py @@ -1,11 +1,14 @@ """Tests for realtime handoff functionality.""" +from typing import Any from unittest.mock import Mock import pytest from agents import Agent +from agents.exceptions import ModelBehaviorError, UserError from agents.realtime import RealtimeAgent, realtime_handoff +from agents.run_context import RunContextWrapper def test_realtime_handoff_creation(): @@ -94,3 +97,58 @@ def test_type_annotations_work(): # This should be typed as Handoff[Any, RealtimeAgent[Any]] assert isinstance(handoff_obj, Handoff) + + +def test_realtime_handoff_invalid_param_counts_raise(): + rt = RealtimeAgent(name="x") + + # on_handoff with input_type but wrong param count + def bad2(a): # only one parameter + return None + + with pytest.raises(UserError): + realtime_handoff(rt, on_handoff=bad2, input_type=int) # type: ignore[arg-type] + + # on_handoff without input but wrong param count + def bad1(a, b): # two parameters + return None + + with pytest.raises(UserError): + realtime_handoff(rt, on_handoff=bad1) # type: ignore[arg-type] + + +@pytest.mark.asyncio +async def test_realtime_handoff_missing_input_json_raises_model_error(): + rt = RealtimeAgent(name="x") + + async def 
with_input(ctx: RunContextWrapper[Any], data: int): # simple non-object type + return None + + h = realtime_handoff(rt, on_handoff=with_input, input_type=int) + + with pytest.raises(ModelBehaviorError): + await h.on_invoke_handoff(RunContextWrapper(None), "null") + + +@pytest.mark.asyncio +async def test_realtime_handoff_is_enabled_async(monkeypatch): + rt = RealtimeAgent(name="x") + + async def is_enabled(ctx, agent): + return True + + h = realtime_handoff(rt, is_enabled=is_enabled) + + # Patch missing symbol in module to satisfy isinstance in closure + import agents.realtime.handoffs as rh + + if not hasattr(rh, "RealtimeAgent"): + from agents.realtime import RealtimeAgent as _RT + + rh.RealtimeAgent = _RT # type: ignore[attr-defined] + + from collections.abc import Awaitable + from typing import cast as _cast + + assert callable(h.is_enabled) + assert await _cast(Awaitable[bool], h.is_enabled(RunContextWrapper(None), rt)) diff --git a/tests/realtime/test_session.py b/tests/realtime/test_session.py index 66db03ef1..7ffb6d981 100644 --- a/tests/realtime/test_session.py +++ b/tests/realtime/test_session.py @@ -1,8 +1,10 @@ -from typing import cast -from unittest.mock import AsyncMock, Mock, PropertyMock +import asyncio +from typing import Any, cast +from unittest.mock import AsyncMock, Mock, PropertyMock, patch import pytest +from agents.exceptions import UserError from agents.guardrail import GuardrailFunctionOutput, OutputGuardrail from agents.handoffs import Handoff from agents.realtime.agent import RealtimeAgent @@ -46,12 +48,202 @@ RealtimeModelTurnEndedEvent, RealtimeModelTurnStartedEvent, ) -from agents.realtime.model_inputs import RealtimeModelSendSessionUpdate +from agents.realtime.model_inputs import ( + RealtimeModelSendAudio, + RealtimeModelSendInterrupt, + RealtimeModelSendSessionUpdate, + RealtimeModelSendUserInput, +) from agents.realtime.session import RealtimeSession from agents.tool import FunctionTool from agents.tool_context import ToolContext +class _DummyModel(RealtimeModel): + def __init__(self) -> None: + super().__init__() + self.events: list[Any] = [] + self.listeners: list[Any] = [] + + async def connect(self, options=None): # pragma: no cover - not used here + pass + + async def close(self): # pragma: no cover - not used here + pass + + async def send_event(self, event): + self.events.append(event) + + def add_listener(self, listener): + self.listeners.append(listener) + + def remove_listener(self, listener): + if listener in self.listeners: + self.listeners.remove(listener) + + +@pytest.mark.asyncio +async def test_property_and_send_helpers_and_enter_alias(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + # property + assert session.model is model + + # enter alias calls __aenter__ + async with await session.enter(): + # send helpers + await session.send_message("hi") + await session.send_audio(b"abc", commit=True) + await session.interrupt() + + # verify sent events + assert any(isinstance(e, RealtimeModelSendUserInput) for e in model.events) + assert any(isinstance(e, RealtimeModelSendAudio) and e.commit for e in model.events) + assert any(isinstance(e, RealtimeModelSendInterrupt) for e in model.events) + + +@pytest.mark.asyncio +async def test_aiter_cancel_breaks_loop_gracefully(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + async def consume(): + async for _ in session: + pass + + consumer = asyncio.create_task(consume()) 
+ await asyncio.sleep(0.01) + consumer.cancel() + # The iterator swallows CancelledError internally and exits cleanly + await consumer + + +@pytest.mark.asyncio +async def test_transcription_completed_adds_new_user_item(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + event = RealtimeModelInputAudioTranscriptionCompletedEvent(item_id="item1", transcript="hello") + await session.on_event(event) + + # Should have appended a new user item + assert len(session._history) == 1 + assert session._history[0].type == "message" + assert session._history[0].role == "user" + + +class _FakeAudio: + # Looks like an audio part but is not an InputAudio/AssistantAudio instance + type = "audio" + transcript = None + + +@pytest.mark.asyncio +async def test_item_updated_merge_exception_path_logs_error(monkeypatch): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + # existing assistant message with transcript to preserve + existing = AssistantMessageItem( + item_id="a1", role="assistant", content=[AssistantAudio(audio=None, transcript="t")] + ) + session._history = [existing] + + # incoming message with a deliberately bogus content entry to trigger assertion path + incoming = AssistantMessageItem( + item_id="a1", role="assistant", content=[AssistantAudio(audio=None, transcript=None)] + ) + incoming.content[0] = cast(Any, _FakeAudio()) + + with patch("agents.realtime.session.logger") as mock_logger: + await session.on_event(RealtimeModelItemUpdatedEvent(item=incoming)) + # error branch should be hit + assert mock_logger.error.called + + +@pytest.mark.asyncio +async def test_handle_tool_call_handoff_invalid_result_raises(): + model = _DummyModel() + target = RealtimeAgent(name="target") + + bad_handoff = Handoff( + tool_name="switch", + tool_description="", + input_json_schema={}, + on_invoke_handoff=AsyncMock(return_value=123), # invalid return + input_filter=None, + agent_name=target.name, + is_enabled=True, + ) + + agent = RealtimeAgent(name="agent", handoffs=[bad_handoff]) + session = RealtimeSession(model, agent, None) + + with pytest.raises(UserError): + await session._handle_tool_call( + RealtimeModelToolCallEvent(name="switch", call_id="c1", arguments="{}") + ) + + +@pytest.mark.asyncio +async def test_on_guardrail_task_done_emits_error_event(): + model = _DummyModel() + agent = RealtimeAgent(name="agent") + session = RealtimeSession(model, agent, None) + + async def failing_task(): + raise ValueError("task failed") + + task = asyncio.create_task(failing_task()) + # Wait for it to finish so exception() is available + try: + await task + except Exception: # noqa: S110 + pass + + session._on_guardrail_task_done(task) + + # Allow event task to enqueue + await asyncio.sleep(0.01) + + # Should have a RealtimeError queued + err = await session._event_queue.get() + assert isinstance(err, RealtimeError) + + +@pytest.mark.asyncio +async def test_get_handoffs_async_is_enabled(monkeypatch): + # Agent includes both a direct Handoff and a RealtimeAgent (auto-converted) + target = RealtimeAgent(name="target") + other = RealtimeAgent(name="other") + + async def is_enabled(ctx, agent): + return True + + # direct handoff with async is_enabled + direct = Handoff( + tool_name="to_target", + tool_description="", + input_json_schema={}, + on_invoke_handoff=AsyncMock(return_value=target), + input_filter=None, + agent_name=target.name, + is_enabled=is_enabled, + ) + + a = RealtimeAgent(name="a", 
+    session = RealtimeSession(_DummyModel(), a, None)
+
+    enabled = await RealtimeSession._get_handoffs(a, session._context_wrapper)
+    # Both should be enabled
+    assert len(enabled) == 2
+
+
 class MockRealtimeModel(RealtimeModel):
     def __init__(self):
         super().__init__()
diff --git a/tests/realtime/test_session_payload_and_formats.py b/tests/realtime/test_session_payload_and_formats.py
new file mode 100644
index 000000000..f3e72ae13
--- /dev/null
+++ b/tests/realtime/test_session_payload_and_formats.py
@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+from collections.abc import Mapping
+from typing import Any, cast
+
+import pydantic
+from openai.types.realtime.realtime_audio_config import RealtimeAudioConfig
+from openai.types.realtime.realtime_audio_formats import (
+    AudioPCM,
+    AudioPCMA,
+    AudioPCMU,
+)
+from openai.types.realtime.realtime_session_create_request import (
+    RealtimeSessionCreateRequest,
+)
+from openai.types.realtime.realtime_transcription_session_create_request import (
+    RealtimeTranscriptionSessionCreateRequest,
+)
+
+from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel as Model
+
+
+class _DummyModel(pydantic.BaseModel):
+    type: str
+
+
+def _session_with_output(fmt: Any | None) -> RealtimeSessionCreateRequest:
+    if fmt is None:
+        return RealtimeSessionCreateRequest(type="realtime", model="gpt-realtime")
+    return RealtimeSessionCreateRequest(
+        type="realtime",
+        model="gpt-realtime",
+        # Use dict for output to avoid importing non-exported symbols in tests
+        audio=RealtimeAudioConfig(output=cast(Any, {"format": fmt})),
+    )
+
+
+def test_normalize_session_payload_variants() -> None:
+    # Passthrough: already a realtime session model
+    rt = _session_with_output(AudioPCM(type="audio/pcm"))
+    assert Model._normalize_session_payload(rt) is rt
+
+    # Transcription session instance should be ignored
+    ts = RealtimeTranscriptionSessionCreateRequest(type="transcription")
+    assert Model._normalize_session_payload(ts) is None
+
+    # Transcription-like mapping should be ignored
+    transcription_mapping: Mapping[str, object] = {"type": "transcription"}
+    assert Model._normalize_session_payload(transcription_mapping) is None
+
+    # Valid realtime mapping should be converted to model
+    realtime_mapping: Mapping[str, object] = {"type": "realtime", "model": "gpt-realtime"}
+    as_model = Model._normalize_session_payload(realtime_mapping)
+    assert isinstance(as_model, RealtimeSessionCreateRequest)
+    assert as_model.type == "realtime"
+
+    # Invalid mapping returns None
+    invalid_mapping: Mapping[str, object] = {"type": "bogus"}
+    assert Model._normalize_session_payload(invalid_mapping) is None
+
+
+def test_extract_audio_format_from_session_objects() -> None:
+    # Known OpenAI audio format models -> normalized names
+    s_pcm = _session_with_output(AudioPCM(type="audio/pcm"))
+    assert Model._extract_audio_format(s_pcm) == "pcm16"
+
+    s_ulaw = _session_with_output(AudioPCMU(type="audio/pcmu"))
+    assert Model._extract_audio_format(s_ulaw) == "g711_ulaw"
+
+    s_alaw = _session_with_output(AudioPCMA(type="audio/pcma"))
+    assert Model._extract_audio_format(s_alaw) == "g711_alaw"
+
+    # Missing/None output format -> None
+    s_none = _session_with_output(None)
+    assert Model._extract_audio_format(s_none) is None
+
+
+def test_normalize_audio_format_fallbacks() -> None:
+    # String passthrough
+    assert Model._normalize_audio_format("pcm24") == "pcm24"
+
+    # Mapping with type field
+    assert Model._normalize_audio_format({"type": "g711_ulaw"}) == "g711_ulaw"
+
+    # Pydantic model with type field
+    assert Model._normalize_audio_format(_DummyModel(type="custom")) == "custom"
+
+    # Object with attribute 'type'
+    class HasType:
+        def __init__(self) -> None:
+            self.type = "weird"
+
+    assert Model._normalize_audio_format(HasType()) == "weird"
diff --git a/tests/realtime/test_tracing.py b/tests/realtime/test_tracing.py
index 69de79e83..60004ab0b 100644
--- a/tests/realtime/test_tracing.py
+++ b/tests/realtime/test_tracing.py
@@ -1,6 +1,11 @@
+from typing import cast
 from unittest.mock import AsyncMock, Mock, patch
 
 import pytest
+from openai.types.realtime.realtime_session_create_request import (
+    RealtimeSessionCreateRequest,
+)
+from openai.types.realtime.realtime_tracing_config import TracingConfiguration
 
 from agents.realtime.agent import RealtimeAgent
 from agents.realtime.model import RealtimeModel
@@ -95,15 +100,14 @@ async def async_websocket(*args, **kwargs):
         session_created_event = {
             "type": "session.created",
             "event_id": "event_123",
-            "session": {"id": "session_456"},
+            "session": {"id": "session_456", "type": "realtime", "model": "gpt-realtime"},
         }
 
         with patch.object(model, "_send_raw_message") as mock_send_raw_message:
             await model._handle_ws_event(session_created_event)
 
             # Should send session.update with tracing config
-            from openai.types.beta.realtime.session_update_event import (
-                SessionTracingTracingConfiguration,
+            from openai.types.realtime.session_update_event import (
                 SessionUpdateEvent,
             )
 
@@ -111,9 +115,10 @@ async def async_websocket(*args, **kwargs):
             call_args = mock_send_raw_message.call_args[0][0]
             assert isinstance(call_args, SessionUpdateEvent)
             assert call_args.type == "session.update"
-            assert isinstance(call_args.session.tracing, SessionTracingTracingConfiguration)
-            assert call_args.session.tracing.workflow_name == "test_workflow"
-            assert call_args.session.tracing.group_id == "group_123"
+            session_req = cast(RealtimeSessionCreateRequest, call_args.session)
+            assert isinstance(session_req.tracing, TracingConfiguration)
+            assert session_req.tracing.workflow_name == "test_workflow"
+            assert session_req.tracing.group_id == "group_123"
 
     @pytest.mark.asyncio
     async def test_send_tracing_config_auto_mode(self, model, mock_websocket):
@@ -136,20 +141,21 @@ async def async_websocket(*args, **kwargs):
         session_created_event = {
             "type": "session.created",
             "event_id": "event_123",
-            "session": {"id": "session_456"},
+            "session": {"id": "session_456", "type": "realtime", "model": "gpt-realtime"},
         }
 
         with patch.object(model, "_send_raw_message") as mock_send_raw_message:
             await model._handle_ws_event(session_created_event)
 
             # Should send session.update with "auto"
-            from openai.types.beta.realtime.session_update_event import SessionUpdateEvent
+            from openai.types.realtime.session_update_event import SessionUpdateEvent
 
             mock_send_raw_message.assert_called_once()
             call_args = mock_send_raw_message.call_args[0][0]
             assert isinstance(call_args, SessionUpdateEvent)
             assert call_args.type == "session.update"
-            assert call_args.session.tracing == "auto"
+            session_req = cast(RealtimeSessionCreateRequest, call_args.session)
+            assert session_req.tracing == "auto"
 
     @pytest.mark.asyncio
     async def test_tracing_config_none_skips_session_update(self, model, mock_websocket):
@@ -160,7 +166,7 @@ async def test_tracing_config_none_skips_session_update(self, model, mock_websocket):
         session_created_event = {
             "type": "session.created",
             "event_id": "event_123",
-            "session": {"id": "session_456"},
+            "session": {"id": "session_456", "type": "realtime", "model": "gpt-realtime"},
"gpt-realtime"}, } with patch.object(model, "send_event") as mock_send_event: @@ -199,15 +205,14 @@ async def async_websocket(*args, **kwargs): session_created_event = { "type": "session.created", "event_id": "event_123", - "session": {"id": "session_456"}, + "session": {"id": "session_456", "type": "realtime", "model": "gpt-realtime"}, } with patch.object(model, "_send_raw_message") as mock_send_raw_message: await model._handle_ws_event(session_created_event) # Should send session.update with complete tracing config including metadata - from openai.types.beta.realtime.session_update_event import ( - SessionTracingTracingConfiguration, + from openai.types.realtime.session_update_event import ( SessionUpdateEvent, ) @@ -215,9 +220,10 @@ async def async_websocket(*args, **kwargs): call_args = mock_send_raw_message.call_args[0][0] assert isinstance(call_args, SessionUpdateEvent) assert call_args.type == "session.update" - assert isinstance(call_args.session.tracing, SessionTracingTracingConfiguration) - assert call_args.session.tracing.workflow_name == "complex_workflow" - assert call_args.session.tracing.metadata == complex_metadata + session_req = cast(RealtimeSessionCreateRequest, call_args.session) + assert isinstance(session_req.tracing, TracingConfiguration) + assert session_req.tracing.workflow_name == "complex_workflow" + assert session_req.tracing.metadata == complex_metadata @pytest.mark.asyncio async def test_tracing_disabled_prevents_tracing(self, mock_websocket): diff --git a/tests/test_session.py b/tests/test_session.py index d249e900d..5e96d3f25 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -483,9 +483,7 @@ async def test_sqlite_session_special_characters_and_sql_injection(): items: list[TResponseInputItem] = [ {"role": "user", "content": "O'Reilly"}, {"role": "assistant", "content": "DROP TABLE sessions;"}, - {"role": "user", "content": ( - '"SELECT * FROM users WHERE name = \"admin\";"' - )}, + {"role": "user", "content": ('"SELECT * FROM users WHERE name = "admin";"')}, {"role": "assistant", "content": "Robert'); DROP TABLE students;--"}, {"role": "user", "content": "Normal message"}, ] @@ -496,17 +494,19 @@ async def test_sqlite_session_special_characters_and_sql_injection(): assert len(retrieved) == len(items) assert retrieved[0].get("content") == "O'Reilly" assert retrieved[1].get("content") == "DROP TABLE sessions;" - assert retrieved[2].get("content") == '"SELECT * FROM users WHERE name = \"admin\";"' + assert retrieved[2].get("content") == '"SELECT * FROM users WHERE name = "admin";"' assert retrieved[3].get("content") == "Robert'); DROP TABLE students;--" assert retrieved[4].get("content") == "Normal message" session.close() + @pytest.mark.asyncio async def test_sqlite_session_concurrent_access(): """ Test concurrent access to the same session to verify data integrity. 
""" import concurrent.futures + with tempfile.TemporaryDirectory() as temp_dir: db_path = Path(temp_dir) / "test_concurrent.db" session_id = "concurrent_test" @@ -523,6 +523,7 @@ def add_item(item): asyncio.set_event_loop(loop) loop.run_until_complete(session.add_items([item])) loop.close() + with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: executor.map(add_item, items) diff --git a/tests/voice/test_input.py b/tests/voice/test_input.py index fbef84c1b..fa3951eab 100644 --- a/tests/voice/test_input.py +++ b/tests/voice/test_input.py @@ -55,8 +55,7 @@ def test_buffer_to_audio_file_invalid_dtype(): buffer = np.array([1.0, 2.0, 3.0], dtype=np.float64) with pytest.raises(UserError, match="Buffer must be a numpy array of int16 or float32"): - # Purposely ignore the type error - _buffer_to_audio_file(buffer) # type: ignore + _buffer_to_audio_file(buffer=buffer) class TestAudioInput: diff --git a/tests/voice/test_openai_stt.py b/tests/voice/test_openai_stt.py index f1ec04fdc..12c58a22c 100644 --- a/tests/voice/test_openai_stt.py +++ b/tests/voice/test_openai_stt.py @@ -112,7 +112,7 @@ async def test_session_connects_and_configures_successfully(): assert "wss://api.openai.com/v1/realtime?intent=transcription" in args[0] headers = kwargs.get("additional_headers", {}) assert headers.get("Authorization") == "Bearer FAKE_KEY" - assert headers.get("OpenAI-Beta") == "realtime=v1" + assert headers.get("OpenAI-Beta") is None assert headers.get("OpenAI-Log-Session") == "1" # Check that we sent a 'transcription_session.update' message diff --git a/uv.lock b/uv.lock index 94a8ca9c0..6c2d7556f 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.11'", @@ -1797,7 +1797,7 @@ wheels = [ [[package]] name = "openai" -version = "1.104.1" +version = "1.107.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1809,9 +1809,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/47/55/7e0242a7db611ad4a091a98ca458834b010639e94e84faca95741ded4050/openai-1.104.1.tar.gz", hash = "sha256:8b234ada6f720fa82859fb7dcecf853f8ddf3892c3038c81a9cc08bcb4cd8d86", size = 557053, upload-time = "2025-09-02T19:59:37.818Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/e0/a62daa7ff769df969cc1b782852cace79615039630b297005356f5fb46fb/openai-1.107.1.tar.gz", hash = "sha256:7c51b6b8adadfcf5cada08a613423575258b180af5ad4bc2954b36ebc0d3ad48", size = 563671, upload-time = "2025-09-10T15:04:40.288Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/64/de/af0eefab4400d2c888cea4f9b929bd5208d98aa7619c38b93554b0699d60/openai-1.104.1-py3-none-any.whl", hash = "sha256:153f2e9c60d4c8bb90f2f3ef03b6433b3c186ee9497c088d323028f777760af4", size = 928094, upload-time = "2025-09-02T19:59:36.155Z" }, + { url = "https://files.pythonhosted.org/packages/d4/12/32c19999a58eec4a695e8ce334442b6135df949f0bb61b2ceaa4fa60d3a9/openai-1.107.1-py3-none-any.whl", hash = "sha256:168f9885b1b70d13ada0868a0d0adfd538c16a02f7fd9fe063851a2c9a025e72", size = 945177, upload-time = "2025-09-10T15:04:37.782Z" }, ] [[package]] @@ -1882,7 +1882,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.11.0,<2" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 
'voice'", specifier = ">=2.2.0,<3" }, - { name = "openai", specifier = ">=1.104.1,<2" }, + { name = "openai", specifier = ">=1.107.1,<2" }, { name = "pydantic", specifier = ">=2.10,<3" }, { name = "requests", specifier = ">=2.0,<3" }, { name = "sqlalchemy", marker = "extra == 'sqlalchemy'", specifier = ">=2.0" },