diff --git a/async-openai/src/responses.rs b/async-openai/src/responses.rs index 5c2689a3..9160b7be 100644 --- a/async-openai/src/responses.rs +++ b/async-openai/src/responses.rs @@ -1,13 +1,13 @@ use crate::{ config::Config, error::OpenAIError, - types::responses::{CreateResponse, Response}, + types::responses::{CreateResponse, Response, ResponseStream}, Client, }; /// Given text input or a list of context items, the model will generate a response. /// -/// Related guide: [Responses API](https://platform.openai.com/docs/guides/responses) +/// Related guide: [Responses](https://platform.openai.com/docs/api-reference/responses) pub struct Responses<'c, C: Config> { client: &'c Client, } @@ -26,4 +26,30 @@ impl<'c, C: Config> Responses<'c, C> { pub async fn create(&self, request: CreateResponse) -> Result { self.client.post("/responses", request).await } + + /// Creates a model response for the given input with streaming. + /// + /// Response events will be sent as server-sent events as they become available, + #[crate::byot( + T0 = serde::Serialize, + R = serde::de::DeserializeOwned, + stream = "true", + where_clause = "R: std::marker::Send + 'static" + )] + #[allow(unused_mut)] + pub async fn create_stream( + &self, + mut request: CreateResponse, + ) -> Result { + #[cfg(not(feature = "byot"))] + { + if matches!(request.stream, Some(false)) { + return Err(OpenAIError::InvalidArgument( + "When stream is false, use Responses::create".into(), + )); + } + request.stream = Some(true); + } + Ok(self.client.post_stream("/responses", request).await) + } } diff --git a/async-openai/src/types/responses.rs b/async-openai/src/types/responses.rs index 4e0eeec7..78735cdb 100644 --- a/async-openai/src/types/responses.rs +++ b/async-openai/src/types/responses.rs @@ -4,9 +4,11 @@ pub use crate::types::{ ResponseFormatJsonSchema, }; use derive_builder::Builder; +use futures::Stream; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; +use std::pin::Pin; /// Role of messages in the API. #[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq)] @@ -1434,3 +1436,722 @@ pub enum Status { InProgress, Incomplete, } + +/// Event types for streaming responses from the Responses API +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(tag = "type")] +#[non_exhaustive] // Future-proof against breaking changes +pub enum ResponseEvent { + /// Response creation started + #[serde(rename = "response.created")] + ResponseCreated(ResponseCreated), + /// Processing in progress + #[serde(rename = "response.in_progress")] + ResponseInProgress(ResponseInProgress), + /// Response completed (different from done) + #[serde(rename = "response.completed")] + ResponseCompleted(ResponseCompleted), + /// Response failed + #[serde(rename = "response.failed")] + ResponseFailed(ResponseFailed), + /// Response incomplete + #[serde(rename = "response.incomplete")] + ResponseIncomplete(ResponseIncomplete), + /// Response queued + #[serde(rename = "response.queued")] + ResponseQueued(ResponseQueued), + /// Output item added + #[serde(rename = "response.output_item.added")] + ResponseOutputItemAdded(ResponseOutputItemAdded), + /// Content part added + #[serde(rename = "response.content_part.added")] + ResponseContentPartAdded(ResponseContentPartAdded), + /// Text delta update + #[serde(rename = "response.output_text.delta")] + ResponseOutputTextDelta(ResponseOutputTextDelta), + /// Text output completed + #[serde(rename = "response.output_text.done")] + ResponseOutputTextDone(ResponseOutputTextDone), + /// Refusal delta update + #[serde(rename = "response.refusal.delta")] + ResponseRefusalDelta(ResponseRefusalDelta), + /// Refusal completed + #[serde(rename = "response.refusal.done")] + ResponseRefusalDone(ResponseRefusalDone), + /// Content part completed + #[serde(rename = "response.content_part.done")] + ResponseContentPartDone(ResponseContentPartDone), + /// Output item completed + #[serde(rename = "response.output_item.done")] + ResponseOutputItemDone(ResponseOutputItemDone), + /// Function call arguments delta + #[serde(rename = "response.function_call_arguments.delta")] + ResponseFunctionCallArgumentsDelta(ResponseFunctionCallArgumentsDelta), + /// Function call arguments completed + #[serde(rename = "response.function_call_arguments.done")] + ResponseFunctionCallArgumentsDone(ResponseFunctionCallArgumentsDone), + /// File search call in progress + #[serde(rename = "response.file_search_call.in_progress")] + ResponseFileSearchCallInProgress(ResponseFileSearchCallInProgress), + /// File search call searching + #[serde(rename = "response.file_search_call.searching")] + ResponseFileSearchCallSearching(ResponseFileSearchCallSearching), + /// File search call completed + #[serde(rename = "response.file_search_call.completed")] + ResponseFileSearchCallCompleted(ResponseFileSearchCallCompleted), + /// Web search call in progress + #[serde(rename = "response.web_search_call.in_progress")] + ResponseWebSearchCallInProgress(ResponseWebSearchCallInProgress), + /// Web search call searching + #[serde(rename = "response.web_search_call.searching")] + ResponseWebSearchCallSearching(ResponseWebSearchCallSearching), + /// Web search call completed + #[serde(rename = "response.web_search_call.completed")] + ResponseWebSearchCallCompleted(ResponseWebSearchCallCompleted), + /// Reasoning summary part added + #[serde(rename = "response.reasoning_summary_part.added")] + ResponseReasoningSummaryPartAdded(ResponseReasoningSummaryPartAdded), + /// Reasoning summary part done + #[serde(rename = "response.reasoning_summary_part.done")] + ResponseReasoningSummaryPartDone(ResponseReasoningSummaryPartDone), + /// Reasoning summary text delta + #[serde(rename = "response.reasoning_summary_text.delta")] + ResponseReasoningSummaryTextDelta(ResponseReasoningSummaryTextDelta), + /// Reasoning summary text done + #[serde(rename = "response.reasoning_summary_text.done")] + ResponseReasoningSummaryTextDone(ResponseReasoningSummaryTextDone), + /// Reasoning summary delta + #[serde(rename = "response.reasoning_summary.delta")] + ResponseReasoningSummaryDelta(ResponseReasoningSummaryDelta), + /// Reasoning summary done + #[serde(rename = "response.reasoning_summary.done")] + ResponseReasoningSummaryDone(ResponseReasoningSummaryDone), + /// Image generation call in progress + #[serde(rename = "response.image_generation_call.in_progress")] + ResponseImageGenerationCallInProgress(ResponseImageGenerationCallInProgress), + /// Image generation call generating + #[serde(rename = "response.image_generation_call.generating")] + ResponseImageGenerationCallGenerating(ResponseImageGenerationCallGenerating), + /// Image generation call partial image + #[serde(rename = "response.image_generation_call.partial_image")] + ResponseImageGenerationCallPartialImage(ResponseImageGenerationCallPartialImage), + /// Image generation call completed + #[serde(rename = "response.image_generation_call.completed")] + ResponseImageGenerationCallCompleted(ResponseImageGenerationCallCompleted), + /// MCP call arguments delta + #[serde(rename = "response.mcp_call_arguments.delta")] + ResponseMcpCallArgumentsDelta(ResponseMcpCallArgumentsDelta), + /// MCP call arguments done + #[serde(rename = "response.mcp_call_arguments.done")] + ResponseMcpCallArgumentsDone(ResponseMcpCallArgumentsDone), + /// MCP call completed + #[serde(rename = "response.mcp_call.completed")] + ResponseMcpCallCompleted(ResponseMcpCallCompleted), + /// MCP call failed + #[serde(rename = "response.mcp_call.failed")] + ResponseMcpCallFailed(ResponseMcpCallFailed), + /// MCP call in progress + #[serde(rename = "response.mcp_call.in_progress")] + ResponseMcpCallInProgress(ResponseMcpCallInProgress), + /// MCP list tools completed + #[serde(rename = "response.mcp_list_tools.completed")] + ResponseMcpListToolsCompleted(ResponseMcpListToolsCompleted), + /// MCP list tools failed + #[serde(rename = "response.mcp_list_tools.failed")] + ResponseMcpListToolsFailed(ResponseMcpListToolsFailed), + /// MCP list tools in progress + #[serde(rename = "response.mcp_list_tools.in_progress")] + ResponseMcpListToolsInProgress(ResponseMcpListToolsInProgress), + /// Code interpreter call in progress + #[serde(rename = "response.code_interpreter_call.in_progress")] + ResponseCodeInterpreterCallInProgress(ResponseCodeInterpreterCallInProgress), + /// Code interpreter call interpreting + #[serde(rename = "response.code_interpreter_call.interpreting")] + ResponseCodeInterpreterCallInterpreting(ResponseCodeInterpreterCallInterpreting), + /// Code interpreter call completed + #[serde(rename = "response.code_interpreter_call.completed")] + ResponseCodeInterpreterCallCompleted(ResponseCodeInterpreterCallCompleted), + /// Code interpreter call code delta + #[serde(rename = "response.code_interpreter_call_code.delta")] + ResponseCodeInterpreterCallCodeDelta(ResponseCodeInterpreterCallCodeDelta), + /// Code interpreter call code done + #[serde(rename = "response.code_interpreter_call_code.done")] + ResponseCodeInterpreterCallCodeDone(ResponseCodeInterpreterCallCodeDone), + /// Output text annotation added + #[serde(rename = "response.output_text.annotation.added")] + ResponseOutputTextAnnotationAdded(ResponseOutputTextAnnotationAdded), + /// Error occurred + #[serde(rename = "error")] + ResponseError(ResponseError), + + /// Unknown event type + #[serde(untagged)] + Unknown(serde_json::Value), +} + +/// Stream of response events +pub type ResponseStream = Pin> + Send>>; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseCreated { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseInProgress { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseOutputItemAdded { + pub sequence_number: u64, + pub output_index: u32, + pub item: OutputItem, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseContentPartAdded { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub part: ContentPart, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseOutputTextDelta { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub delta: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub logprobs: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseContentPartDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub part: ContentPart, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseOutputItemDone { + pub sequence_number: u64, + pub output_index: u32, + pub item: OutputItem, +} + +/// Response completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseCompleted { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +/// Response failed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFailed { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +/// Response incomplete event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseIncomplete { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +/// Response queued event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseQueued { + pub sequence_number: u64, + pub response: ResponseMetadata, +} + +/// Text output completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseOutputTextDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub text: String, + pub logprobs: Option>, +} + +/// Refusal delta event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseRefusalDelta { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub delta: String, +} + +/// Refusal done event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseRefusalDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub refusal: String, +} + +/// Function call arguments delta event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFunctionCallArgumentsDelta { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub delta: String, +} + +/// Function call arguments done event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFunctionCallArgumentsDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub arguments: String, +} + +/// Error event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseError { + pub sequence_number: u64, + pub code: Option, + pub message: String, + pub param: Option, +} + +/// File search call in progress event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFileSearchCallInProgress { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// File search call searching event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFileSearchCallSearching { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// File search call completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseFileSearchCallCompleted { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Web search call in progress event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseWebSearchCallInProgress { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Web search call searching event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseWebSearchCallSearching { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Web search call completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseWebSearchCallCompleted { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Reasoning summary part added event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseReasoningSummaryPartAdded { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub summary_index: u32, + pub part: serde_json::Value, // Could be more specific but using Value for flexibility +} + +/// Reasoning summary part done event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseReasoningSummaryPartDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub summary_index: u32, + pub part: serde_json::Value, +} + +/// Reasoning summary text delta event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseReasoningSummaryTextDelta { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub summary_index: u32, + pub delta: String, +} + +/// Reasoning summary text done event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseReasoningSummaryTextDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub summary_index: u32, + pub text: String, +} + +/// Reasoning summary delta event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseReasoningSummaryDelta { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub summary_index: u32, + pub delta: serde_json::Value, +} + +/// Reasoning summary done event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseReasoningSummaryDone { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub summary_index: u32, + pub text: String, +} + +/// Image generation call in progress event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseImageGenerationCallInProgress { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Image generation call generating event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseImageGenerationCallGenerating { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Image generation call partial image event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseImageGenerationCallPartialImage { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, + pub partial_image_index: u32, + pub partial_image_b64: String, +} + +/// Image generation call completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseImageGenerationCallCompleted { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// MCP call arguments delta event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseMcpCallArgumentsDelta { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, + pub delta: String, +} + +/// MCP call arguments done event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseMcpCallArgumentsDone { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, + pub arguments: String, +} + +/// MCP call completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseMcpCallCompleted { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// MCP call failed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseMcpCallFailed { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// MCP call in progress event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseMcpCallInProgress { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// MCP list tools completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseMcpListToolsCompleted { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// MCP list tools failed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseMcpListToolsFailed { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// MCP list tools in progress event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseMcpListToolsInProgress { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Code interpreter call in progress event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseCodeInterpreterCallInProgress { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Code interpreter call interpreting event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseCodeInterpreterCallInterpreting { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Code interpreter call completed event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseCodeInterpreterCallCompleted { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, +} + +/// Code interpreter call code delta event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseCodeInterpreterCallCodeDelta { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, + pub delta: String, +} + +/// Code interpreter call code done event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseCodeInterpreterCallCodeDone { + pub sequence_number: u64, + pub output_index: u32, + pub item_id: String, + pub code: String, +} + +/// Response metadata +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseMetadata { + pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + pub created_at: u64, + pub status: Status, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub incomplete_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + /// Whether the model was run in background mode + #[serde(skip_serializing_if = "Option::is_none")] + pub background: Option, + /// The service tier that was actually used + #[serde(skip_serializing_if = "Option::is_none")] + pub service_tier: Option, + /// The effective value of top_logprobs parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs: Option, + /// The effective value of max_tool_calls parameter + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_response_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub truncation: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option>, + /// Prompt cache key for improved performance + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_cache_key: Option, + /// Safety identifier for content filtering + #[serde(skip_serializing_if = "Option::is_none")] + pub safety_identifier: Option, +} + +/// Output item +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct OutputItem { + pub id: String, + #[serde(rename = "type")] + pub item_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + /// For reasoning items - summary paragraphs + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option>, +} + +/// Content part +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ContentPart { + #[serde(rename = "type")] + pub part_type: String, + pub text: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub annotations: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub logprobs: Option>, +} + +// ===== RESPONSE COLLECTOR ===== + +/// Collects streaming response events into a complete response + +/// Output text annotation added event +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct ResponseOutputTextAnnotationAdded { + pub sequence_number: u64, + pub item_id: String, + pub output_index: u32, + pub content_index: u32, + pub annotation_index: u32, + pub annotation: TextAnnotation, +} + +/// Text annotation object for output text +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[non_exhaustive] +pub struct TextAnnotation { + #[serde(rename = "type")] + pub annotation_type: String, + pub text: String, + pub start: u32, + pub end: u32, +} diff --git a/async-openai/tests/responses.rs b/async-openai/tests/responses.rs new file mode 100644 index 00000000..03e60012 --- /dev/null +++ b/async-openai/tests/responses.rs @@ -0,0 +1,148 @@ +use async_openai::types::responses::*; +use serde_json; + +#[test] +fn test_response_event_deserialization() { + // Test basic streaming events - using actual JSON from responses API + let created_json = r#"{ + "type": "response.created", + "sequence_number": 0, + "response": { + "id": "resp_68819584a96881a082f70cbb524d3b6c00c0da7b27b3d5bd", + "object": "response", + "created_at": 1753322884, + "status": "in_progress", + "model": "o3-2025-04-16", + "background": false, + "service_tier": "auto", + "top_logprobs": 0, + "output": [], + "parallel_tool_calls": true, + "reasoning": { + "effort": "medium" + }, + "store": true, + "temperature": 1.0, + "text": { + "format": { + "type": "text" + } + }, + "tool_choice": "auto", + "tools": [], + "top_p": 1.0, + "truncation": "disabled", + "metadata": {} + } + }"#; + + let delta_json = r#"{ + "type": "response.output_text.delta", + "sequence_number": 6, + "item_id": "msg_688195877b9081a088b67ef1d8707db800c0da7b27b3d5bd", + "output_index": 1, + "content_index": 0, + "delta": "Silent", + "logprobs": [] + }"#; + + let completed_json = r#"{ + "type": "response.completed", + "sequence_number": 26, + "response": { + "id": "resp_68819584a96881a082f70cbb524d3b6c00c0da7b27b3d5bd", + "object": "response", + "created_at": 1753322884, + "status": "completed", + "model": "o3-2025-04-16", + "usage": { + "input_tokens": 13, + "input_tokens_details": { + "audio_tokens": null, + "cached_tokens": 0 + }, + "output_tokens": 151, + "output_tokens_details": { + "accepted_prediction_tokens": null, + "audio_tokens": null, + "reasoning_tokens": 128, + "rejected_prediction_tokens": null + }, + "total_tokens": 164 + }, + "background": false, + "service_tier": "auto", + "top_logprobs": 0, + "output": [ + { + "id": "rs_68819585260c81a0b001a62df6c4164000c0da7b27b3d5bd", + "type": "reasoning", + "summary": [] + }, + { + "id": "msg_688195877b9081a088b67ef1d8707db800c0da7b27b3d5bd", + "type": "message", + "status": "completed", + "content": [ + { + "type": "output_text", + "text": "Silent lines of code \nLogic weaves through glowing night \nBugs flee with sunrise", + "annotations": [], + "logprobs": [] + } + ], + "role": "assistant" + } + ], + "parallel_tool_calls": true, + "reasoning": { + "effort": "medium" + }, + "store": true, + "temperature": 1.0, + "text": { + "format": { + "type": "text" + } + }, + "tool_choice": "auto", + "tools": [], + "top_p": 1.0, + "truncation": "disabled", + "metadata": {} + } + }"#; + + // Test deserialization + let created: ResponseEvent = serde_json::from_str(created_json).unwrap(); + let delta: ResponseEvent = serde_json::from_str(delta_json).unwrap(); + let completed: ResponseEvent = serde_json::from_str(completed_json).unwrap(); + + assert!(matches!(created, ResponseEvent::ResponseCreated(_))); + assert!(matches!(delta, ResponseEvent::ResponseOutputTextDelta(_))); + assert!(matches!(completed, ResponseEvent::ResponseCompleted(_))); + + // Test serialization round-trip + let created_serialized = serde_json::to_string(&created).unwrap(); + let _: ResponseEvent = serde_json::from_str(&created_serialized).unwrap(); +} + +#[test] +fn test_response_event_unknown() { + // Test Unknown event handling for completely unknown event types + let unknown_json = r#"{ + "type": "response.future_feature", + "sequence_number": 42, + "some_new_field": "value" + }"#; + + let event: ResponseEvent = serde_json::from_str(unknown_json).unwrap(); + match event { + ResponseEvent::Unknown(value) => { + assert_eq!(value.get("type").unwrap().as_str().unwrap(), "response.future_feature"); + assert_eq!(value.get("sequence_number").unwrap().as_u64().unwrap(), 42); + assert_eq!(value.get("some_new_field").unwrap().as_str().unwrap(), "value"); + } + _ => panic!("Expected Unknown event"), + } +} diff --git a/examples/responses-stream/Cargo.toml b/examples/responses-stream/Cargo.toml new file mode 100644 index 00000000..82eb90a7 --- /dev/null +++ b/examples/responses-stream/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "responses-stream" +version = "0.1.0" +edition = "2024" + +[dependencies] +async-openai = { path = "../../async-openai" } +tokio = { version = "1.0", features = ["full"] } +futures = "0.3" +serde_json = "1.0" diff --git a/examples/responses-stream/src/main.rs b/examples/responses-stream/src/main.rs new file mode 100644 index 00000000..fcf25a77 --- /dev/null +++ b/examples/responses-stream/src/main.rs @@ -0,0 +1,50 @@ +use async_openai::{ + Client, + types::responses::{ + CreateResponseArgs, Input, InputContent, InputItem, InputMessageArgs, ResponseEvent, Role, + }, +}; +use futures::StreamExt; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::new(); + + let request = CreateResponseArgs::default() + .model("o3") + .stream(true) + .input(Input::Items(vec![InputItem::Message( + InputMessageArgs::default() + .role(Role::User) + .content(InputContent::TextInput( + "Write a haiku about programming.".to_string(), + )) + .build()?, + )])) + .build()?; + + let mut stream = client.responses().create_stream(request).await?; + + while let Some(result) = stream.next().await { + match result { + Ok(response_event) => match &response_event { + ResponseEvent::ResponseOutputTextDelta(delta) => { + print!("{}", delta.delta); + } + ResponseEvent::ResponseCompleted(_) + | ResponseEvent::ResponseIncomplete(_) + | ResponseEvent::ResponseFailed(_) => { + break; + } + _ => {} + }, + Err(_) => { + // When a stream ends, it returns Err(OpenAIError::StreamError("Stream ended")) + // Without this, the stream will never end + break; + } + } + } + + Ok(()) +}