diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index fe2ed232..36b3cfb3 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -1,7 +1,7 @@ use std::pin::Pin; use bytes::Bytes; -use futures::{stream::StreamExt, Stream}; +use futures::{stream::StreamExt, Stream, TryStreamExt}; use reqwest::multipart::Form; use reqwest_eventsource::{Event, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Serialize}; @@ -479,13 +479,20 @@ pub(crate) async fn stream( where O: DeserializeOwned + std::marker::Send + 'static, { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::>(); tokio::spawn(async move { while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) { + let send_reuslt = match e { + reqwest_eventsource::Error::StreamEnded => { + tx.send(Err((true, OpenAIError::StreamError(e.to_string())))) + } + _ => tx.send(Err((false, OpenAIError::StreamError(e.to_string())))), + }; + + if let Err(_e) = send_reuslt { // rx dropped break; } @@ -497,7 +504,9 @@ where } let response = match serde_json::from_str::(&message.data) { - Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())), + Err(e) => { + Err((false, map_deserialization_error(e, message.data.as_bytes()))) + } Ok(output) => Ok(output), }; @@ -514,7 +523,17 @@ where event_source.close(); }); - Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)) + Box::pin( + tokio_stream::wrappers::UnboundedReceiverStream::new(rx) + .scan((), |_, item| async { + if let Err((true, _)) = item { + None + } else { + Some(item) + } + }) + .map_err(|e| e.1), + ) } pub(crate) async fn stream_mapped_raw_events(